ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

实战 迁移学习 VGG19、ResNet50、InceptionV3 实践 猫狗大战 问题

2019-07-23 15:00:42  阅读:375  来源: 互联网

标签:plot ResNet50 generator img VGG19 base InceptionV3 model size


实战 迁移学习 VGG19、ResNet50、InceptionV3 实践 猫狗大战 问题

  参考博客:::https://blog.csdn.net/pengdali/article/details/79050662     2018年01月13日 12:52:14  阅读数 10417  

一、实践流程

1、数据预处理

主要是对训练数据进行随机偏移、转动等变换图像处理,这样可以尽可能让训练数据多样化

另外处理数据方式采用分批无序读取的形式,避免了数据按目录排序训练

 

  1.   #数据准备
  2.   def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
  3.   if is_train:
  4.   datagen = ImageDataGenerator(rescale=1./255,
  5.   zoom_range=0.25, rotation_range=15.,
  6.   channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
  7.   horizontal_flip=True, fill_mode='constant')
  8.   else:
  9.   datagen = ImageDataGenerator(rescale=1./255)
  10.    
  11.   generator = datagen.flow_from_directory(
  12.   dir_path, target_size=(img_row, img_col),
  13.   batch_size=batch_size,
  14.   shuffle=is_train)
  15.    
  16.   return generator
2、载入现有模型

 

这个部分是核心工作,目的是使用ImageNet训练出的权重来做我们的特征提取器,注意这里后面的分类层去掉

 

  1.   base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
  2.   input_shape=(img_rows, img_cols, color),
  3.   classes=nb_classes)

然后是冻结这些层,因为是训练好的

 

  1.   for layer in base_model.layers:
  2.   layer.trainable = False
而分类部分,需要我们根据现有需求来新定义的,这里可以根据实际情况自己进行调整,比如这样
  1.   x = base_model.output
  2.   # 添加自己的全链接分类层
  3.   x = GlobalAveragePooling2D()(x)
  4.   x = Dense(1024, activation='relu')(x)
  5.   predictions = Dense(nb_classes, activation='softmax')(x)
或者

 

  1.   x = base_model.output
  2.   #添加自己的全链接分类层
  3.   x = Flatten()(x)
  4.   predictions = Dense(nb_classes, activation='softmax')(x)
3、训练模型

这里我们用fit_generator函数,它可以避免了一次性加载大量的数据,并且生成器与模型将并行执行以提高效率。比如可以在CPU上进行实时的数据提升,同时在GPU上进行模型训练

 

  1.   history_ft = model.fit_generator(
  2.   train_generator,
  3.   steps_per_epoch=steps_per_epoch,
  4.   epochs=epochs,
  5.   validation_data=validation_generator,
  6.   validation_steps=validation_steps)

二、猫狗大战数据集

 

训练数据540M,测试数据270M,大家可以去官网下载

https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

下载后把数据分成dog和cat两个目录来存放

三、训练

训练的时候会自动去下权值,比如vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5,但是如果我们已经下载好了的话,可以改源代码,让他直接读取我们的下载好的权值,比如在resnet50.py中

1、VGG19

vgg19的深度有26层,参数达到了549M,原模型最后有3个全连接层做分类器所以我还是加了一个1024的全连接层,训练10轮的情况达到了89%

2、ResNet50

ResNet50的深度达到了168层,但是参数只有99M,分类模型我就简单点,一层直接分类,训练10轮的达到了96%的准确率

3、inception_v3

InceptionV3的深度159层,参数92M,训练10轮的结果

这是一层直接分类的结果

这是加了一个512全连接的,大家可以随意调整测试

 

四、完整的代码

 

  1.   # -*- coding: utf-8 -*-
  2.   import os
  3.   from keras.utils import plot_model
  4.   from keras.applications.resnet50 import ResNet50
  5.   from keras.applications.vgg19 import VGG19
  6.   from keras.applications.inception_v3 import InceptionV3
  7.   from keras.layers import Dense,Flatten,GlobalAveragePooling2D
  8.   from keras.models import Model,load_model
  9.   from keras.optimizers import SGD
  10.   from keras.preprocessing.image import ImageDataGenerator
  11.   import matplotlib.pyplot as plt
  12.    
  13.   class PowerTransferMode:
  14.   #数据准备
  15.   def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
  16.   if is_train:
  17.   datagen = ImageDataGenerator(rescale=1./255,
  18.   zoom_range=0.25, rotation_range=15.,
  19.   channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
  20.   horizontal_flip=True, fill_mode='constant')
  21.   else:
  22.   datagen = ImageDataGenerator(rescale=1./255)
  23.    
  24.   generator = datagen.flow_from_directory(
  25.   dir_path, target_size=(img_row, img_col),
  26.   batch_size=batch_size,
  27.   #class_mode='binary',
  28.   shuffle=is_train)
  29.    
  30.   return generator
  31.    
  32.   #ResNet模型
  33.   def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
  34.   color = 3 if RGB else 1
  35.   base_model = ResNet50(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
  36.   classes=nb_classes)
  37.    
  38.   #冻结base_model所有层,这样就可以正确获得bottleneck特征
  39.   for layer in base_model.layers:
  40.   layer.trainable = False
  41.    
  42.   x = base_model.output
  43.   #添加自己的全链接分类层
  44.   x = Flatten()(x)
  45.   #x = GlobalAveragePooling2D()(x)
  46.   #x = Dense(1024, activation='relu')(x)
  47.   predictions = Dense(nb_classes, activation='softmax')(x)
  48.    
  49.   #训练模型
  50.   model = Model(inputs=base_model.input, outputs=predictions)
  51.   sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
  52.   model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  53.    
  54.   #绘制模型
  55.   if is_plot_model:
  56.   plot_model(model, to_file='resnet50_model.png',show_shapes=True)
  57.    
  58.   return model
  59.    
  60.    
  61.   #VGG模型
  62.   def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
  63.   color = 3 if RGB else 1
  64.   base_model = VGG19(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
  65.   classes=nb_classes)
  66.    
  67.   #冻结base_model所有层,这样就可以正确获得bottleneck特征
  68.   for layer in base_model.layers:
  69.   layer.trainable = False
  70.    
  71.   x = base_model.output
  72.   #添加自己的全链接分类层
  73.   x = GlobalAveragePooling2D()(x)
  74.   x = Dense(1024, activation='relu')(x)
  75.   predictions = Dense(nb_classes, activation='softmax')(x)
  76.    
  77.   #训练模型
  78.   model = Model(inputs=base_model.input, outputs=predictions)
  79.   sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
  80.   model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  81.    
  82.   # 绘图
  83.   if is_plot_model:
  84.   plot_model(model, to_file='vgg19_model.png',show_shapes=True)
  85.    
  86.   return model
  87.    
  88.   # InceptionV3模型
  89.   def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,
  90.   is_plot_model=False):
  91.   color = 3 if RGB else 1
  92.   base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
  93.   input_shape=(img_rows, img_cols, color),
  94.   classes=nb_classes)
  95.    
  96.   # 冻结base_model所有层,这样就可以正确获得bottleneck特征
  97.   for layer in base_model.layers:
  98.   layer.trainable = False
  99.    
  100.   x = base_model.output
  101.   # 添加自己的全链接分类层
  102.   x = GlobalAveragePooling2D()(x)
  103.   x = Dense(1024, activation='relu')(x)
  104.   predictions = Dense(nb_classes, activation='softmax')(x)
  105.    
  106.   # 训练模型
  107.   model = Model(inputs=base_model.input, outputs=predictions)
  108.   sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
  109.   model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  110.    
  111.   # 绘图
  112.   if is_plot_model:
  113.   plot_model(model, to_file='inception_v3_model.png', show_shapes=True)
  114.    
  115.   return model
  116.    
  117.   #训练模型
  118.   def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):
  119.   # 载入模型
  120.   if is_load_model and os.path.exists(model_url):
  121.   model = load_model(model_url)
  122.    
  123.   history_ft = model.fit_generator(
  124.   train_generator,
  125.   steps_per_epoch=steps_per_epoch,
  126.   epochs=epochs,
  127.   validation_data=validation_generator,
  128.   validation_steps=validation_steps)
  129.   # 模型保存
  130.   model.save(model_url,overwrite=True)
  131.   return history_ft
  132.    
  133.   # 画图
  134.   def plot_training(self, history):
  135.   acc = history.history['acc']
  136.   val_acc = history.history['val_acc']
  137.   loss = history.history['loss']
  138.   val_loss = history.history['val_loss']
  139.   epochs = range(len(acc))
  140.   plt.plot(epochs, acc, 'b-')
  141.   plt.plot(epochs, val_acc, 'r')
  142.   plt.title('Training and validation accuracy')
  143.   plt.figure()
  144.   plt.plot(epochs, loss, 'b-')
  145.   plt.plot(epochs, val_loss, 'r-')
  146.   plt.title('Training and validation loss')
  147.   plt.show()
  148.    
  149.    
  150.   if __name__ == '__main__':
  151.   image_size = 197
  152.   batch_size = 32
  153.    
  154.   transfer = PowerTransferMode()
  155.    
  156.   #得到数据
  157.   train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)
  158.   validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)
  159.    
  160.   #VGG19
  161.   #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
  162.   #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)
  163.    
  164.   #ResNet50
  165.   model = transfer.ResNet50_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
  166.   history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'resnet50_model_weights.h5', is_load_model=False)
  167.    
  168.   #InceptionV3
  169.   #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)
  170.   #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)
  171.    
  172.   # 训练的acc_loss图
  173.   transfer.plot_training(history_ft)

标签:plot,ResNet50,generator,img,VGG19,base,InceptionV3,model,size
来源: https://www.cnblogs.com/shuimuqingyang/p/11231748.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有