
Reproducing VGG16

2021-10-03 15:01:19



This was mainly practice with data loading.

This time I used CIFAR-10: each unpickled batch file is a dictionary, and I only took the first 100 samples for training.

Each row first has to be reshaped back into an image. CIFAR-10 stores each 3072-byte row channel-first (1024 R bytes, then G, then B), so reshape to (3, 32, 32) and move the channel axis to the end:

self.data = self.data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

In __getitem__, call Image.fromarray() on the array before handing it to the transforms.
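
As a quick check of what the unpickled batch actually contains, here is a minimal inspection sketch (it reuses the author's local path from the code below; point it at wherever your copy of cifar-10-batches-py lives):

import pickle

with open('H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1', 'rb') as fo:
    batch = pickle.load(fo, encoding='bytes')

print(batch.keys())            # dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
print(batch[b'data'].shape)    # (10000, 3072): per row, 1024 R bytes, then 1024 G, then 1024 B
print(batch[b'labels'][:5])    # plain Python list of int labels in [0, 9]

# Recover displayable HWC images: channel-first reshape, then move channels last
imgs = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
print(imgs.shape, imgs.dtype)  # (10000, 32, 32, 3) uint8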

 

VGG_dataset:

from torch.utils import data
from PIL import Image
import random
import torchvision.transforms as T
import matplotlib.pyplot as plt

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

# imgs = unpickle('H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1')
# print(imgs[b'data'].reshape(-1, 3, 32, 32))



class Dataset(data.Dataset):
    def __init__(self, root, train = True, test = False):
        self.test = test
        self.train = train
        imgs = unpickle(root)
        self.data = imgs[b'data'][: 100, :]
        # CIFAR-10 rows are channel-first (R, G, B planes): reshape to NCHW, then to NHWC for PIL
        self.data = self.data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        self.label = imgs[b'labels'][: 100]

        if self.train:
            self.transforms = T.Compose([
                T.Resize(random.randint(256, 384)),  # T.Scale is deprecated in torchvision; Resize behaves the same
                T.RandomCrop(224),
                T.ToTensor(),
                T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
            ])
        elif self.test:
            self.transforms = T.Compose([
                T.Resize(224),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def __getitem__(self, index):
        data = Image.fromarray(self.data[index])
        data = self.transforms(data)
        return data, self.label[index]
    def __len__(self):
        return len(self.label)
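
A quick sanity check of the Dataset class (hypothetical snippet; it assumes the same data_batch_1 path used in the config below):

from VGG_dataset import Dataset

train_dataset = Dataset('H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1')
img, label = train_dataset[0]
print(len(train_dataset))     # 100, since only the first 100 rows are kept
print(img.shape, label)       # torch.Size([3, 224, 224]) and an integer label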

 

config:

class configuration:
    train_root = 'H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1'
    test_root = 'H:/DataSet/cifar-10-python/cifar-10-batches-py/test_batch'
    label_nums = 10
    batch_size = 4
    epochs = 10
    lr = 0.01

VGG:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
from config import configuration
from VGG_dataset import Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'

con = configuration()

class vgg(nn.Module):
    def __init__(self):
        super(vgg, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size = 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64,kernel_size = 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size = 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size = 3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size = 3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size = 3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(256, 512, kernel_size = 3, stride=1, padding=1)
        self.conv9 = nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding=1)
        self.conv10 = nn.Conv2d(512, 512,  kernel_size=3, stride=1, padding=1)
        self.conv11 = nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding=1)
        self.conv12 = nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding=1)
        self.conv13 = nn.Conv2d(512, 512,  kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(512 * 7 * 7, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, con.label_nums)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(F.relu(self.conv4(x)), 2)
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.max_pool2d(F.relu(self.conv7(x)), 2)
        x = F.relu(self.conv8(x))
        x = F.relu(self.conv9(x))
        x = F.max_pool2d(F.relu(self.conv10(x)), 2)
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.max_pool2d(F.relu(self.conv13(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
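
# Shape check (sketch): with a 224x224 input, the five 2x2 max-pools leave a
# 7x7 feature map, which is why fc1 expects 512 * 7 * 7 input features.
#   dummy = torch.randn(1, 3, 224, 224)
#   print(vgg()(dummy).shape)   # expected: torch.Size([1, 10])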

# img = Image.open('H:/C5AM385_Intensity.jpg')
# print(np.array(img).shape)


if __name__ == '__main__':
    model = vgg()
    model.to(device)
    train_dataset = Dataset(con.train_root)
    test_dataset = Dataset(con.test_root, False, True)
    train_dataloader = DataLoader(train_dataset, batch_size = con.batch_size, shuffle = True, num_workers = 4)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = con.lr)

    for epoch in range(con.epochs):
        total_loss = 0
        cnt = 0
        true_label = 0
        for data, label in tqdm(train_dataloader):
            # print(np.array(data[0]).shape)
            # plt.imshow(data[0])
            # plt.show()

            optimizer.zero_grad()
            data = data.to(device)      # .to() is not in-place; the result must be reassigned
            label = label.to(device)
            output = model(data)
            loss_value = loss(output, label)
            loss_value.backward()
            optimizer.step()
            output = torch.max(output, 1)[1]
            total_loss += loss_value.item()                   # .item() detaches the scalar loss
            true_label += torch.sum(output == label).item()
            cnt += 1
        loss_mean = total_loss / float(cnt)
        accuracy = true_label / float(len(train_dataset))
        print('Loss:{:.4f}, Accuracy:{:.2f}'.format(loss_mean, accuracy))
    print('Training finished!')
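
The script builds test_dataset but never evaluates on it. A sketch of a matching evaluation loop (appended inside the __main__ block after training; it reuses the names defined above):

    # Evaluate on the held-out batch (only the first 100 rows, as in training)
    test_dataloader = DataLoader(test_dataset, batch_size = con.batch_size, shuffle = False)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, label in test_dataloader:
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            pred = torch.max(output, 1)[1]
            correct += torch.sum(pred == label).item()
    print('Test Accuracy:{:.2f}'.format(correct / float(len(test_dataset))))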

 

Source: https://www.cnblogs.com/WTSRUVF/p/15364206.html
