ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

搭建网络模型的笔记

2021-08-26 16:03:52  阅读:174  来源: 互联网

标签:loss 模型 torch 笔记 targets net data model 搭建


搭建网络模型

1. 导入模块

  • import 模块

2. 选择设备

  • device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

3. 准备数据集

  • 训练集:
    • train_data = torchvision.datasets.CIFAR10(root="./data_CIFAR10", train=True,transform=torchvision.transforms.ToTensor(),download=True)
  • 测试集:
    • test_data = torchvision.datasets.CIFAR10(root="./data_CIFAR10", train=False,transform=torchvision.transforms.ToTensor(),download=True)

4. 加载数据集

  • train_dataloader = DataLoader(train_data, batch_size=64)
  • test_dataloader = DataLoader(test_data, batch_size=64)

5. 创建网络模型

  • class MyModel(nn.Module):
    • def init(self):
      • xxxxxxx
      • xxxxxxx
    • def forward(self, x):
      • x = self.model1(x)
      • return x

6. 实例化网络模型

  • net_model = MyModel()
  • net_model = net_model.to(device)

7. 定义损失函数

  • loss_fn = nn.xxxxxxxLoss()
  • if torch.cuda.is_available():
    • loss_fn = loss_fn.cuda()

8. 定义优化器

  • optimizer = torch.optim.SGD(net_model.parameters(), lr=learning_rate)
  • optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9) # 选择优化器
  • scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略

9. 训练部分

  • for i in range(epoch):
    • 训练步骤

    • for data in train_dataloader:

      • 1.获取训练数据
      • imgs, targets = data
      • 2.选择设备
      • imgs = imgs.to(device)
      • 3.把图片传入网络模型进行训练,返回10个结果
      • targets = targets.to(device)
      • outputs = net_model(imgs)
      • 4.进行损失函数处理
      • loss = loss_fn(outputs, targets)
      • 5.梯度清零
      • optimizer.zero_grad()
      • 6.反向传播
      • loss.backward()
      • 7.优化器,更新权重
      • optimizer.step()
    • 测试步骤

    • with torch.no_grad():

      • for data in test_dataloader:
        • imgs, targets = data # 1.获取测试数据

        • imgs = imgs.to(device) # 2.选择设备

        • targets = targets.to(device)

        • outputs = net_model(imgs) # 3.将测试图片传入训练模型

        • loss = loss_fn(outputs, targets) # 4.计算损失值

        • total_test_loss = total_test_loss + loss.item() # 5.计算总的损失值

        • accuracy = (outputs.argmax(1) == targets).sum() # 6.计算准确率

        • total_accuarcy = total_accuarcy + accuracy # 7.计算总准确率

    • 保存训练好的模型

    • torch.save(net_model, "net_model{}.path".format(i))

10. 验证数据

  • 待验证图片预处理

    • 转换为和测试集相同格式的图片,输入为同类型
      1. 加载
    • transfrom = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])
    • image = transfrom(image)
      1. 转换
    • image = torch.reshape(image, (1, 3, 32, 32))
  • 加载保存的模型

    • model = torch.load("net_model25.path")
  • 验证

  • model.eval()

  • with torch.no_grad():

    • image = image.to(device)# 转换成相同类型数据集
    • output = model(image)
  • print(output)

  • print(output.argmax(1))

标签:loss,模型,torch,笔记,targets,net,data,model,搭建
来源: https://www.cnblogs.com/cmn-note/p/15190053.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有