标签:tensorflow2 Cos warmup max self bat step warm lr
class WarmUpCos(keras.callbacks.Callback): def __init__(self, lr_max,lr_min, warm_step,sum_step,bat): super(WarmUpCos, self).__init__() self.lr_max = lr_max self.lr_min = lr_min self.warm_step = warm_step self.sum_step = sum_step self.bat = bat def on_train_begin(self, batch, logs=None): self.init_lr = self.lr_max self.step = 0 def on_epoch_begin(self, epoch, logs=None): self.epoch = epoch def on_batch_end(self,batch, logs=None): self.step += 1 print('step:',self.step) # learning_decay_steps = 1 # learning_decay_rate = 0.999 warm_lr = self.lr_max * (self.step / self.warm_step) # decay_lr = max(self.init_lr * tf.pow(learning_decay_rate , ((step-self.warm_step) / learning_decay_steps)),self.lr_min) decay_lr = self.lr_max * ( 1 + math.cos( (self.step - self.warm_step) * math.pi / ( self.sum_step - self.warm_step) ) ) / 2 if self.step < self.warm_step: lr = warm_lr else: lr =decay_lr K.set_value(self.model.optimizer.lr, lr) warm_up = WarmUpCos(lr_rate, lr_min, warm_step=warm_epoch*int(train_x.shape[0]//bat),bat=bat,sum_step=epochs*int(train_x.shape[0]//bat)) s_model.fit(train_db, epochs=epochs, validation_data=test_db, callbacks=[warm_up])
搜索
复制
标签:tensorflow2,Cos,warmup,max,self,bat,step,warm,lr 来源: https://www.cnblogs.com/cxhzy/p/16496377.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。