ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

Person_reID_baseline_pytorch 源码解析之 test.py

2022-01-06 11:59:37  阅读:370  来源: 互联网

标签:torch baseline img py label 源码 ff query gallery


源码中有两个用于测试的脚本: test.py 和 evaluate_gpu.py 。其中, test.py 加载通过脚本 train.py 训练好的模型,实现对 query 和 gallery 图片的特征提取;本文对脚本 test.py 进行解析。

1. 加载模型和数据

首先需要载入训练好的模型,这里以基于 Resnet50 输出类别为 751 类的行人重识别模型 ft_net 为例。

model_structure = ft_net(751)
model = load_network(model_structure)

然后需要载入经过预处理的 gallery 和 query 数据集

data_transforms = transforms.Compose([
        transforms.Resize((256,128), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=0) for x in ['gallery','query']}

加载预处理过的数据集和训练好的模型,然后使用函数 extract_feature 进行特征提取

with torch.no_grad():
    gallery_feature = extract_feature(model,dataloaders['gallery'])
    query_feature = extract_feature(model,dataloaders['query'])

2. 完成特征提取

extract_feature 是 test.py 中非常重要的一个函数,用于提取图片的特征,下面对它逐行解析

def extract_feature(model,dataloaders):
    features = torch.FloatTensor()
    count = 0
    # 加载数据集
    for data in dataloaders:
        img, label = data
        n, c, h, w = img.size()
        count += n
        # 统计数据集图片数量
        print(count)
        ff = torch.FloatTensor(n,512).zero_().cuda()
        for i in range(2):
            if(i==1):
            	# 翻转图片
                img = fliplr(img)
            # 将图片变成 Variable,准备加载到网络中
            input_img = Variable(img.cuda())
            # 缩放尺寸 multiple_scale
            for scale in ms:
                if scale != 1:
                    # bicubic is only  available in pytorch>= 1.1
                    input_img = nn.functional.interpolate(input_img, scale_factor=scale, mode='bicubic', align_corners=False)
                # 模型推理
                outputs = model(input_img) 
                # 拼接多尺度预测结果
                ff += outputs
        # norm feature 特征归一化
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))
		# 返回提取到的特征
        features = torch.cat((features,ff.data.cpu()), 0)
    return features

3. 实现特征归一化

fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)

这里是在输入张量 ff 的第 1 维进行 L2-norm,即 2 范数归一化。特征向量中每个元素均除以向量的L2范数。

在这里插入图片描述
pytorch 中使用 torch.norm 计算张量的范数。

fnorm = torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
  • input 输入张量
  • p 是范数计算中的幂指数值,p = 2 时即为 2 范数
  • dim 指定计算的维度,如果 dim 是整数值,则计算向量范数。当输入张量 input 超过2维,将在最后一维计算向量范数
  • keepdim 指明是否保留输出张量的维度dim
  • out 输出张量
  • dtype 返回张量的期待数据类型

令特征向量除以向量的L2范数,expand_as 函数将范数 fnorm 扩展成张量 ff 相同的维度。

 ff = ff.div(fnorm.expand_as(ff))

然后使用 tensor.div 完成除法。

Tensor.div(value, *, rounding_mode=None)

最后,使用 torch.cat 在第 0 维上拼接输入张量

features = torch.cat((features,ff.data.cpu()), 0)

4. 生成 Matlab 文件

通过上述步骤实现了 query 和 gallery 图片特征的提取,将特征矩阵存储到 pytorch_result.mat 文件中。

# Save to Matlab for check
result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam}
scipy.io.savemat('pytorch_result.mat',result)

为了评估模型效果,还要记录图片的 label 和 camera 。
这里使用 get_id 函数通过图片名称获取 label 和 camera 信息。

def get_id(img_path):
    camera_id = []
    labels = []
    for path, v in img_path:
        #filename = path.split('/')[-1]
        filename = os.path.basename(path)
        label = filename[0:4]
        camera = filename.split('c')[1]
        if label[0:2]=='-1':
            labels.append(-1)
        else:
            labels.append(int(label))
        camera_id.append(int(camera[0]))
    return camera_id, labels

gallery_path = image_datasets['gallery'].imgs
query_path = image_datasets['query'].imgs

gallery_cam,gallery_label = get_id(gallery_path)
query_cam,query_label = get_id(query_path)

生成的 Matlab 文件将被脚本 evaluate_gpu.py 使用,用于计算模型的评估指标。

参考链接

  1. pytorch求范数函数——torch.norm
  2. pytorch torch.norm 文档
  3. Pytorch expand_as()函数
  4. torch.cat()函数的官方解释,详解以及例子
  5. torch.stack()的官方解释,详解以及例子

标签:torch,baseline,img,py,label,源码,ff,query,gallery
来源: https://blog.csdn.net/qq_39220334/article/details/121630259

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有