ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

从零写CRNN文字识别 —— (6)训练

2021-02-12 14:06:08  阅读:566  来源: 互联网

标签:acc batch epoch preds shape CRNN 零写 识别 size


前言

完整代码已经上传github:https://github.com/xmy0916/pytorch_crnn

训练

训练部分的代码逻辑如下:

for epoch in range(total_epoch):
  for data in dataloader:
    数据输入模型(前馈)
    根据输出计算loss
    loss反馈更新网络参数
  if epoch % eval_epoch == 0:
    评估数据输入模型(前馈)
    根据输出计算loss
    解码输出计算识别准确率
    if now_acc > best_acc:
      保存模型

对应的完整代码如下:

# 训练
    best_acc = 0.0
    for epoch in range(last_epoch,config.TRAIN.END_EPOCH):
      model.train()
      for i, (inp, idx) in enumerate(train_loader):
          # 前馈
          inp = inp.to(device)
          preds = model(inp).to(device)
          # 计算loss
          labels = get_batch_label(train_dataset, idx)
          batch_size = inp.size(0)
          text, length = encode(config.DICT,labels)
          preds_size = torch.IntTensor([preds.size(0)] * batch_size)
          loss = criterion(preds, text, preds_size, length)
          # 反馈
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          if i % config.PRINT_FREQ == 0:
            print("epoch:{} step:{} loss:{} lr:{}".format(epoch,i,loss.item(),lr_scheduler.get_lr()))
      # 每个epoch更新学习率
      lr_scheduler.step()

      # 每EVAL_FREQ评估一次并保存best模型
      if epoch % config.EVAL_FREQ == 0:
          model.eval()
          n_correct = 0
          test_num = len(val_loader) * config.TEST.BATCH_SIZE_PER_GPU
          with torch.no_grad():
              for i, (inp, idx) in enumerate(val_loader):
                  # 计算前馈
                  inp = inp.to(device)
                  preds = model(inp).cpu()
                  # 计算loss
                  labels = get_batch_label(val_dataset, idx)
                  batch_size = inp.size(0)
                  text, length = encode(config.DICT,labels)
                  preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                  loss = criterion(preds, text, preds_size, length)
                  # 后处理解码
                  print("网络输出的preds的shape:",preds.cpu().detach().shape)
                  _, preds = preds.max(2)
                  print("max(2)的shape:",preds.cpu().detach().shape)
                  preds = preds.transpose(1, 0).contiguous().view(-1)
                  print("transpose的shape:",preds.cpu().detach().shape)
                  sim_preds = decode(preds.data, preds_size.data, config.DICT,raw=False)
                  for pred, target in zip(sim_preds, labels):
                    if pred == target:
                      n_correct += 1

              
          # 抓一个batch来显示
          raw_preds = decode(preds.data, preds_size.data, config.DICT, raw=True)[:config.TEST.NUM_TEST_DISP]
          for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
              print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
          print("preds:",preds.cpu().detach().numpy())
          print("preds_shape:",preds.cpu().detach().shape)
          print("dict:",config.DICT)
          now_acc = n_correct * 1.0 / test_num
          print("best_acc:{} correct:{}".format(now_acc,n_correct))
          if now_acc >= best_acc:
              torch.save(
                    {
                        "state_dict": model.state_dict(),
                        "epoch": epoch + 1,
                        # "optimizer": optimizer.state_dict(),
                        # "lr_scheduler": lr_scheduler.state_dict(),
                        "best_acc": best_acc,
                    },  os.path.join(config.OUTPUT_DIR, "checkpoint_{}_acc_{:.4f}.pth".format(epoch, now_acc)))
              best_acc = now_acc
              print("save_model!")

看看评估过程(摘一段代码出来):

preds = model(inp).cpu()
# 计算loss
labels = get_batch_label(val_dataset, idx)
batch_size = inp.size(0)
text, length = encode(config.DICT,labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
# 后处理解码
print("网络输出的preds的shape:",preds.cpu().detach().shape)
_, preds = preds.max(2)
print("max(2)的shape:",preds.cpu().detach().shape)
preds = preds.transpose(1, 0).contiguous().view(-1)
print("transpose的shape:",preds.cpu().detach().shape)

打印结果:
在这里插入图片描述
稍微解释下:
preds的shape[41,16,109]:

  • 41是卷积后的长度
  • 16是测试时的batch_size大小
  • 109是字典的类别数

preds.max(2)得到了从属于那一类的向量,2表示在109的纬度上取所以输出的shape是[41,16]
transpose是把二维向量拉平,656=41*16
这里注意一点,测试的时候每个batch_size是16,但是我们数据集不一定是16的整数倍,所以最后一个batch的大小不一定有16,例如我们这里最后一个batch的大小是14:
在这里插入图片描述
在代码中我将最后一个batch的测试图片可视化的打印了,结果如下:
在这里插入图片描述
这是第一个epoch训练的输出,
在这里插入图片描述
上图的横杠是设置的空字符的占位符,在config/config.yml中设置这个字符BLANK_CHAR
在这里插入图片描述
上图一共574个0,574 = 41 * 14因为是最后一个batch所以不够16个,上图理论上可以解码成574个字符,因为这是第一个epoch训练的结果,网络参数基本不对所以没有输出。

第16个epoch输出如下:
在这里插入图片描述
第一行的37这个值就是dict中L的位置
在这里插入图片描述

标签:acc,batch,epoch,preds,shape,CRNN,零写,识别,size
来源: https://blog.csdn.net/qq_37668436/article/details/113794325

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有