ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

TorchVision中通过AlexNet网络进行图像分类

2021-11-27 16:35:12  阅读:354  来源: 互联网

标签:tensor name height width shape 图像 images AlexNet TorchVision


      TorchVision中给出了AlexNet的pretrained模型,模型存放位置为https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth ,可通过models.alexnet函数下载,此函数实现在torchvision/models/alexnet.py中,下载后在Ubuntu上存放在~/.cache/torch/hub/checkpoints目录下,在Windows上存放在C:\Users\spring\.cache\torch\hub\checkpoints目录下,其中spring为用户名。

      AlexNet的介绍参考:https://blog.csdn.net/fengbingchun/article/details/112709281

      在推理(inference)过程中,模型的输入是一个tensor,shape需要是[1,c,h,w],原始图像进行预处理操作包括:

      (1).resize到短边为256,长边等比缩放。

      (2).在中心裁剪图像大小到224*224。

      (3).将数据从numpy.ndarray转换到tensor;原数据shape为[h,w,c],转换后tensor shape为[c,h,w];原数据值范围为[0,255],转换后值范围为[0.0,1.0]。

      (4).使用均值和标准差对tensor图像进行归一化。

      (5).将tensor的shape从[c,h,w]转换到[1,c,h,w]。

      模型是通过ImageNet数据集训练获得的,它的图像分类数是1000,ImageNet数据集的介绍参考:https://blog.csdn.net/fengbingchun/article/details/88606621

      以下为测试代码:

import torch
from torchvision import models
from torchvision import transforms
import cv2
from PIL import Image
import math
import numpy as np

#print(dir(models))

images_path = "../../data/image/"
images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]
images_data = [] # opencv
tensor_data = [] # pytorch tensor

def images_stitch(images, cols=3, name="result.jpg"): # 图像简单拼接
    '''images: list, opencv image data; cols: number of images per line; name: save image result name'''
    width_total = 660
    width, height = width_total // cols, width_total // cols
    number = len(images)
    height_total = height * math.ceil(number / cols)

    mat1 = np.zeros((height_total, width_total, 3), dtype="uint8") # in Python images are represented as NumPy arrays

    for idx in range(number):
        height_, width_, _ = images[idx].shape
        if height_ != width_:
            if height_ > width_:
                width_ = math.floor(width_ / height_ * width)
                height_ = height
            else:
                height_ = math.floor(height_ / width_ * height)
                width_ = width
        else:
            height_, width_ = height, width

        mat2 = cv2.resize(images[idx], (width_, height_))
        offset_y, offset_x = (height - height_) // 2, (width - width_) // 2
        start_y, start_x = idx // cols * height, idx % cols * width
        mat1[start_y + offset_y:start_y + height_+offset_y, start_x + offset_x:start_x + width_+offset_x, :] = mat2

    cv2.imwrite(images_path+name, mat1)

for name in images_name:
    img = cv2.imread(images_path + name)
    print(f"name: {images_path+name}, opencv image shape: {img.shape}") # (h,w,c)
    images_data.append(img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_pil = Image.fromarray(img)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    tensor = transform(img_pil)
    print(f"tensor shape: {tensor.shape}, max: {torch.max(tensor)}, min: {torch.min(tensor)}") # (c,h,w)
    tensor = torch.unsqueeze(tensor, 0) # 返回一个新的tensor,对输入的既定位置插入维度1
    print(f"tensor shape: {tensor.shape}, max: {torch.max(tensor)}, min: {torch.min(tensor)}") # (1,c,h,w)
    tensor_data.append(tensor)

images_stitch(images_data)

model = models.alexnet(pretrained=True) # AlexNet网络
#print(model) # 可查看模型结构,与torchvision/models/alexnet.py中一致
model.eval() # AlexNet is required to be put in evaluation mode in order to do prediction/evaluation

with open("imagenet_classes.txt") as f:
    classes = [line.strip() for line in f.readlines()] # the line number specified the class number

for x in range(len(tensor_data)):
    prediction = model(tensor_data[x])
    #print(prediction.shape) # [1,1000]
    _, index = torch.max(prediction, 1)
    percentage = torch.nn.functional.softmax(prediction, dim=1)[0] * 100
    print(f"result: {classes[index[0]]}, {percentage[index[0]].item()}")

print("test finish")

      执行结果如下:以下原始测试图像来自网络,每张图像仅输出可信度值最高的一个类别。从上往下,从左往右,每张图像的分类结果依次是:goldfish(金鱼)、hen(母鸡)、ostrich(鸵鸟)、African crocodile(非洲鳄鱼)、goose(鹅)、hartebeest(羚羊)。

 

      GitHubhttps://github.com/fengbingchun/PyTorch_Test 

标签:tensor,name,height,width,shape,图像,images,AlexNet,TorchVision
来源: https://blog.csdn.net/fengbingchun/article/details/121579039

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有