ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

加载模型及对测试数据进行预测p41

2022-02-05 18:59:24  阅读:153  来源: 互联网

标签:img image 测试数据 np ax model margin p41 加载


基于resnet训练flower图像分类模型(p31-p37)上一篇,我改成别的笔记本跑完了。按照老师的步骤,进行加载模型及测试数据预测。

我们之前是冻住了,只训练一层,也可以全部训练,我的显卡太低跑时间太长了,这部分没跑。

1加载训练好的模型

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU模式
model_ft = model_ft.to(device)

# 保存文件的名字
filename='checkpoint.pth'

# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

2测试数据预处理

测试数据处理方法需要跟训练时一致才可以。输入的大小是一致,标准化用跟训练数据相同的mean和std,PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换。


def process_image(image_path):
    # 读取测试数据
    img = Image.open(image_path)
    # Resize,thumbnail方法只能进行缩小,所以进行了判断
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))
    # Crop操作
    left_margin = (img.width - 224) / 2
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224
    top_margin = bottom_margin + 224
    img = img.crop((left_margin, bottom_margin, right_margin,
                    top_margin))
    # 相同的预处理方法
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])  # provided mean
    std = np.array([0.229, 0.224, 0.225])  # provided std
    img = (img - mean) / std

    # 注意颜色通道应该放在第一个位置
    img = img.transpose((2, 0, 1))

    return img


def imshow(image, ax=None, title=None):
    """展示数据"""
    if ax is None:
        fig, ax = plt.subplots()

    # 颜色通道还原
    image = np.array(image).transpose((1, 2, 0))

    # 预处理还原
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = './data/flower_data/train/3/image_06620.jpg'
img = process_image(image_path)
imshow(img)

在验证集随便选一张花的图片进行测试

3对一个batch的数据进行测试

# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()

model_ft.eval()

if train_on_gpu:
    output = model_ft(images.cuda())  #utput表示对一个batch中每一个数据得到其属于各个类别的可能性
else:
    output = model_ft(images)

output,有8张图片,每个图片有102种分类结果

得到概率最大的那个

_, preds_tensor = torch.max(output, 1)  #得到概率最大的那个

preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())

展示预测结果:

fig=plt.figure(figsize=(20, 12))
columns =4
rows = 2

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()  #绿色名字为预测正确,红色名字为预测错误

 看这个图,效果还不错,跟之前模型的85%的准确率比较符合。

标签:img,image,测试数据,np,ax,model,margin,p41,加载
来源: https://blog.csdn.net/bohu83/article/details/122791814

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有