ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

第五讲 卷积神经网路-- Inception10 --cifar10

2020-05-10 09:08:31  阅读:309  来源: 互联网

标签:ch cifar10 卷积 self strides init plt Inception10 history


  1 import tensorflow as tf
  2 import os
  3 import numpy as np
  4 from matplotlib import pyplot as plt
  5 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D
  6 from tensorflow.keras import Model
  7 
  8 np.set_printoptions(threshold=np.inf)
  9 
 10 ciar10 = tf.keras.datasets.cifar10
 11 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
 12 x_train, x_test = x_train/255.0, x_test/255.0
 13 
 14 class ConvBNRelu(Model):
 15     def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
 16         super(ConvBNRelu, self).__init__()
 17         self.model = tf.keras.models.Sequential([
 18             Conv2D(ch, kernelsz, strides=strides, padding=padding),
 19             BatchNormalization(),
 20             Activation('relu')
 21         ])
 22 
 23     def call(self, x):
 24         x = self.model(x, training=False)
 25         #在training=False时,BN通过整个训练集计算均值、方差去做批归一化,training=True时,通过当前batch的均值、方差去做批归一化。推理时 training=False效果好
 26         return x
 27 
 28 
 29 
 30 class InceptionBlk(Model):
 31     def __init__(self, ch, strides=1):
 32         super(InceptionBlk, self).__init__()
 33         self.ch = ch
 34         self.strides = strides
 35         self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 36         self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 37         self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
 38         self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 39         self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
 40         self.p4_1 = MaxPooling2D(3, strides=1, padding='same')
 41         self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 42 
 43     def call(self, x):
 44         x1 = self.c1(x)
 45         x2_1 = self.c2_1(x)
 46         x2_2 = self.c2_2(x2_1)
 47         x3_1 = self.c3_1(x)
 48         x3_2 = self.c3_2(x3_1)
 49         x4_1 = self.p4_1(x)
 50         x4_2 = self.c4_2(x4_1)
 51         # concat along axis=channel
 52         x = tf.concat([x1, x2_2, x3_2, x4_2], axis=1)
 53         return x
 54 
 55 class Inception10(Model):
 56     def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
 57         super(Inception10, self).__init__(**kwargs)
 58         self.in_channels = init_ch
 59         self.out_channels = init_ch
 60         self.num_blocks = num_blocks
 61         self.init_ch = init_ch
 62         self.c1 = ConvBNRelu(init_ch)
 63         self.blocks = tf.keras.models.Sequential()
 64         for block_id in range(num_blocks):
 65             for layer_id in range(2):
 66                 if layer_id == 0:
 67                     block = InceptionBlk(self.out_channels, strides=1)
 68                 else:
 69                     block = InceptionBlk(self.out_channels, strides=1)
 70                 self.blocks.add(block)
 71             # enlarger out_channels per block
 72             self.out_channels *=2
 73         self.p1 = GlobalAveragePooling2D()
 74         self.f1 = Dense(num_classes, activation='softmax')
 75 
 76     def call(self, x):
 77         x = self.c1(x)
 78         x = self.blocks(x)
 79         x = self.p1(x)
 80         y = self.f1(x)
 81         return y
 82 
 83 model = Inception10(num_blocks=2, num_classes=10)
 84 
 85 model.compile(optimizer='adam',
 86               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
 87               metrics=['sparse_categorical_accuracy'])
 88 
 89 
 90 checkpoint_save_path = "./checkpoint/Inception10.ckpt"
 91 if os.path.exists(checkpoint_save_path + '.index'):
 92     print('-------------load the model---------------')
 93     model.load_weights(checkpoint_save_path)
 94 
 95 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_save_path,
 96                                                 save_weights_only = True,
 97                                                 save_best_only = True)
 98 
 99 history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test),validation_freq=1,
100                     callbacks=[cp_callback])
101 model.summary()
102 
103 
104 with open('./weights.txt', 'w') as f:
105     for v in model.trainable_variables:
106         f.write(str(v.name) + '\n')
107         f.write(str(v.shape) + '\n')
108         f.wrtte(str(v.numpy() + '\n')
109 
110 
111 
112 def plot_acc_loss_curve(history):
113     # 显示训练集和验证集的acc和loss曲线
114     from matplotlib import pyplot as plt
115     acc = history.history['sparse_categorical_accuracy']
116     val_acc = history.history['val_sparse_categorical_accuracy']
117     loss = history.history['loss']
118     val_loss = history.history['val_loss']
119     
120     plt.figure(figsize=(15, 5))
121     plt.subplot(1, 2, 1)
122     plt.plot(acc, label='Training Accuracy')
123     plt.plot(val_acc, label='Validation Accuracy')
124     plt.title('Training and Validation Accuracy')
125     #plt.legend()
126     plt.grid()
127     
128     plt.subplot(1, 2, 2)
129     plt.plot(loss, label='Training Loss')
130     plt.plot(val_loss, label='Validation Loss')
131     plt.title('Training and Validation Loss')
132     plt.legend()
133     #plt.grid()
134     plt.show()
135 
136 plot_acc_loss_curve(history)
137     

 

标签:ch,cifar10,卷积,self,strides,init,plt,Inception10,history
来源: https://www.cnblogs.com/wbloger/p/12862091.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有