ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

深度学习项目中_5-训练模块的编写

2022-03-02 18:59:00  阅读:186  来源: 互联网

标签:loss val args epoch train 模块 深度 编写 model


1.训练模块中的步骤

训练模块一般是保存在train.py 的文件中, 该模块中一般包含以下步骤:

  1. 导入各类模块, (标准库, 三方库, cv, torch, torchvision), 如果在model.py ( 自己定义网络模型文件), loss.py (自定义的损失函数), utils.py( 自定义的各种方法), config.py(整体项目的配置文件), 这些模块都需要导入

  2. 命令行解析

  3. 数据集加载

  4. 检测模型保存地址是否存在, 如果不存在则创建;

  5. 实例化网络模型;

  6. 实例化 损失函数和优化器

  7. 准备事件 文件, 方便 Tensorboard --logdir=“run”, 可视化训练过程;

  8. 检查是否采用,接着上一次的检查点checkpoint 训练, 若是加载chepoint. 模型;

  9. 开始训练, 循环epochs:
    – 将梯度置零;
    – 求 loss;
    – 反向传播;
    – 更新权重参数;
    – 更新优化器中的学习率(可选)

  10. 可视化指标;

  11. 验证valid 模型, (根据模型在验证集上损失和度量, 调整模型的超参数)

2. 训练模块的代码实例

2.1 训练模块 示范一

import os
import torch
from torch.utils.data import DataLoader
from torch import nn
import argparse
from tensorboardX import SummaryWriter

from data_preparation.data_preparation import FileDateset
from model.Baseline import Base_model
from model.ops import pytorch_LSD


def parse_args():
    parser = argparse.ArgumentParser()
    # 重头开始训练 defaule=None, 继续训练defaule设置为'/**.pth'
    parser.add_argument("--model_name", type=str, default=None, help="是否加载模型继续训练 '/50.pth' None")
    parser.add_argument("--batch-size", type=int, default=16, help="")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument('--lr', type=float, default=3e-4, help='学习率 (default: 0.01)')
    parser.add_argument('--train_data', default="./data_preparation/Synthetic/TRAIN", help='数据集的path')
    parser.add_argument('--val_data', default="./data_preparation/Synthetic/VAL", help='验证样本的path')
    parser.add_argument('--checkpoints_dir', default="./checkpoints/AEC_baseline", help='模型检查点文件的路径(以继续培训)')
    parser.add_argument('--event_dir', default="./event_file/AEC_baseline", help='tensorboard事件文件的地址')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    print("GPU是否可用:", torch.cuda.is_available())  # True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 实例化 Dataset
    train_set = FileDateset(dataset_path=args.train_data)  # 实例化训练数据集
    val_set = FileDateset(dataset_path=args.val_data)  # 实例化验证数据集

    # 数据加载器
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True)

    # ###########    保存检查点的地址(如果检查点不存在,则创建)   ############
    if not os.path.exists(args.checkpoints_dir):
        os.makedirs(args.checkpoints_dir)

    ################################
    #          实例化模型          #
    ################################
    model = Base_model().to(device)  # 实例化模型
    # summary(model, input_size=(322, 999))  # 模型输出 torch.Size([64, 322, 999])
    # ###########    损失函数   ############
    criterion = nn.MSELoss(reduce=True, size_average=True, reduction='mean')

    ###############################
    # 创建优化器 Create optimizers #
    ###############################
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, )
    # lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20], gamma=0.1)

    # ###########    TensorBoard可视化 summary  ############
    writer = SummaryWriter(args.event_dir)  # 创建事件文件

    # ###########    加载模型检查点   ############
    start_epoch = 0
    if args.model_name:
        print("加载模型:", args.checkpoints_dir + args.model_name)
        checkpoint = torch.load(args.checkpoints_dir + args.model_name)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint['epoch']
        # lr_schedule.load_state_dict(checkpoint['lr_schedule'])  # 加载lr_scheduler

    for epoch in range(start_epoch, args.epochs):
        model.train()  # 训练模型
        for batch_idx, (train_X, train_mask, train_nearend_mic_magnitude, train_nearend_magnitude) in enumerate(
                train_loader):
            train_X = train_X.to(device)  # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T)
            train_mask = train_mask.to(device)  # IRM [batch_size 161, 999]
            train_nearend_mic_magnitude = train_nearend_mic_magnitude.to(device)
            train_nearend_magnitude = train_nearend_magnitude.to(device)

            # 前向传播
            pred_mask = model(train_X)  # [batch_size, 322, 999]--> [batch_size, 161, 999]
            train_loss = criterion(pred_mask, train_mask)

            # 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999]
            pred_near_spectrum = pred_mask * train_nearend_mic_magnitude
            train_lsd = pytorch_LSD(train_nearend_magnitude, pred_near_spectrum)

            # 反向传播
            optimizer.zero_grad()  # 将梯度清零
            train_loss.backward()  # 反向传播
            optimizer.step()  # 更新参数

            # ###########    可视化打印   ############
        print('Train Epoch: {} Loss: {:.6f} LSD: {:.6f}'.format(epoch + 1, train_loss.item(), train_lsd.item()))

        # ###########    TensorBoard可视化 summary  ############
        # lr_schedule.step()  # 学习率衰减
        # writer.add_scalar(tag="lr", scalar_value=model.state_dict()['param_groups'][0]['lr'], global_step=epoch + 1)
        writer.add_scalar(tag="train_loss", scalar_value=train_loss.item(), global_step=epoch + 1)
        writer.add_scalar(tag="train_lsd", scalar_value=train_lsd.item(), global_step=epoch + 1)
        writer.flush()

        # 神经网络在验证数据集上的表现
        model.eval()  # 测试模型
        # 测试的时候不需要梯度
        with torch.no_grad():
            for val_batch_idx, (val_X, val_mask, val_nearend_mic_magnitude, val_nearend_magnitude) in enumerate(
                    val_loader):
                val_X = val_X.to(device)  # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T)
                val_mask = val_mask.to(device)  # IRM [batch_size 161, 999]
                val_nearend_mic_magnitude = val_nearend_mic_magnitude.to(device)
                val_nearend_magnitude = val_nearend_magnitude.to(device)

                # 前向传播
                val_pred_mask = model(val_X)
                val_loss = criterion(val_pred_mask, val_mask)

                # 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999]
                val_pred_near_spectrum = val_pred_mask * val_nearend_mic_magnitude
                val_lsd = pytorch_LSD(val_nearend_magnitude, val_pred_near_spectrum)

            # ###########    可视化打印   ############
            print('  val Epoch: {} \tLoss: {:.6f}\tlsd: {:.6f}'.format(epoch + 1, val_loss.item(), val_lsd.item()))
            ######################
            # 更新tensorboard    #
            ######################
            writer.add_scalar(tag="val_loss", scalar_value=val_loss.item(), global_step=epoch + 1)
            writer.add_scalar(tag="val_lsd", scalar_value=val_lsd.item(), global_step=epoch + 1)
            writer.flush()

        # # ###########    保存模型   ############
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch + 1,
                # 'lr_schedule': lr_schedule.state_dict()
            }
            torch.save(checkpoint, '%s/%d.pth' % (args.checkpoints_dir, epoch + 1))


if __name__ == "__main__":
    main()

2.2 训练模块 示范二

作者:石郎
链接:https://www.zhihu.com/question/406133826/answer/1334319004
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

# 定义网络
net = Net()

# 定义数据
#数据预处理,1.转为tensor,2.归一化
transform = transforms.Compose(    
     [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
# 验证集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# 定义损失函数和优化器 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 开始训练
net.train()
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # 将梯度置为0
        # zero the parameter gradients
        optimizer.zero_grad()
        # 求loss
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # 梯度反向传播
        loss.backward()
        # 由梯度,更新参数
        optimizer.step()

        # 可视化
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

# 查看在验证集上的效果
dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

net.eval()
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))

3 训练模块 中输入的超参数

3.1 超参数定义

在机器学习的过程中,
超参= 在开始机器学习之前,就人为设置好的参数。

模型参数=通过训练得到的参数数据。
通常情况下,需要对超参数进行优化,给学习机选择一组最优超参数,以提高学习的性能和效果

3.2 深度学习中的常见 超参数

一个深度学习网络有很多的参数可以配置,一般分成以下三类:

  1. 数据集参数(文件路径、batch_size等)
  2. 训练参数(学习率、训练epoch等)
  3. 模型参数(输入的大小,输出的大小)

这些参数可以写一个类保存,也可以写一个字典,然后使用json保存,这些都是需要自己去实现的,但是这些都是一些细枝末节东西,
多写几次,找到一个自己最喜欢的方式就可以,不是深度学习项目中必要的部分。

标签:loss,val,args,epoch,train,模块,深度,编写,model
来源: https://blog.csdn.net/chumingqian/article/details/123236656

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有