ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

卷积神经网络 --Net in Net

2020-06-08 21:53:53  阅读:322  来源: 互联网

标签:卷积 self batch images 神经网络 train tf Net size


 1 import tensorflow as tf
 2 print(tf.__version__)
 3 
 4 
 5 for gpu in tf.config.experimental.list_physical_devices('GPU'):
 6   tf.config.experimental.set_memory_growth(gpu, True)
 7 
 8 
 9 def nin_block(num_channels, kernel_size, strides, padding):
10   blk = tf.keras.models.Sequential()
11   blk.add(tf.keras.layers.Conv2D(num_channels, kernel_size, strides=strides, 
12                                  padding=padding, activation='relu'))
13   blk.add(tf.keras.layers.Conv2D(num_channels, kernel_size=1, activation='relu'))
14   blk.add(tf.keras.layers.Conv2D(num_channels, kernel_size=1, activation='relu'))
15   return blk
16 
17 
18 net = tf.keras.models.Sequential()
19 net.add(nin_block(96, kernel_size=11, strides=4, padding='valid'))
20 net.add(tf.keras.layers.MaxPool2D(pool_size=3, strides=2))
21 net.add(nin_block(256, kernel_size=5, strides=1, padding='same'))
22 net.add(tf.keras.layers.MaxPool2D(pool_size=3, strides=2))
23 net.add(nin_block(384, kernel_size=3, strides=1, padding='same'))
24 net.add(tf.keras.layers.MaxPool2D(pool_size=3, strides=2))
25 net.add(tf.keras.layers.Dropout(0.5))
26 net.add(nin_block(10, kernel_size=3, strides=1, padding='same'))
27 net.add(tf.keras.layers.GlobalAveragePooling2D())
28 net.add(tf.keras.layers.Flatten())
29 
30 
31 X = tf.random.uniform((1, 224, 224, 1))
32 for blk in net.layers:
33   X = blk(X)
34   print(blk.name, 'output shape: \t', X.shape)
35 
36 
37 import numpy as np
38 
39 class DataLoader():
40   def __init__(self):
41     fashion_mnist = tf.keras.datasets.fashion_mnist
42     (self.train_images, self.train_labels), (self.test_images, self.test_labels) = fashion_mnist.load_data()
43     self.train_images = np.expand_dims(self.train_images.astype(np.float32)/255.0, axis=-1)
44     self.test_images = np.expand_dims(self.test_images.astype(np.float32)/255.0, axis=-1)
45     self.train_labels = self.train_labels.astype(np.int32)
46     self.test_labels = self.test_labels.astype(np.int32)
47     self.num_train, self.num_test = self.train_images.shape[0], self.test_images.shape[0]
48 
49   def get_batch_train(self, batch_size):
50     index = np.random.randint(0, np.shape(self.train_images)[0], batch_size)
51     #need to resize images to (224, 224)
52     resized_images = tf.image.resize_with_pad(self.train_images[index], 224, 224,)
53     return resized_images.numpy(), self.train_labels[index]
54 
55   def get_batch_test(self, batch_size):
56     index = np.random,randint(0, np.shape(self.test_images)[0], batch_size)
57     #need to resize to (224, 224)
58     resized_images = tf.image.resize_with_pad(self.test_images[index], 224, 224,)
59     return resized_images.numpy(), self.test_labels[index]
60 
61 batch_size = 128
62 dataLoader = DataLoader()
63 x_batch, y_batch = dataLoader.get_batch_train(batch_size)
64 print('x_batch shape:', x_batch.shape, 'y_batch shape:', y_batch.shape)
65 
66 
67 
68 def train_nin():
69   #net.load_weights('5.8_nin_weights.h5')
70   epoch = 5
71   num_iter = dataLoader.num_train//batch_size
72   for e in range(epoch):
73     for n in range(num_iter):
74       x_batch, y_batch = dataLoader.get_batch_train(batch_size)
75       net.fit(x_batch, y_batch)
76       if n%200 == 0:
77         net.save_weights('5.8_nin_weights.h5')
78 
79 
80 #optimizer = tf.keras.optimizers.SGD(learning_rate=0.06, momentum=0.3, nesterov=False)
81 optimizer = tf.keras.optimizers.Adam(lr=1e-6)
82 net.compile(optimizer=optimizer, 
83             loss='sparse_categorical_crossentropy',
84             metrics=['accuracy'])
85 x_batch, y_batch = dataLoader.get_batch_train(batch_size)
86 net.fit(x_batch, y_batch)
87 train_nin()

 

标签:卷积,self,batch,images,神经网络,train,tf,Net,size
来源: https://www.cnblogs.com/wbloger/p/13068768.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有