标签:loss 训练 batch sample item pytorch train 内存 model
代码中存在累加loss,但每步的loss没加item()。
pytorch中,.item()方法 是得到一个元素张量里面的元素值
具体就是 用于将一个零维张量转换成浮点数,比如计算loss,accuracy的值
就比如:
loss = (y_pred - y).pow(2).sum()
print(loss.item())
for epoch in range(100):
index=np.arange(train_sample.shape[0])
np.random.shuffle(index)
train_set=train_sample[index].tolist()
model.train()
loss,s=0,0
for s in tqdm(range(0,train_sample.shape[0],batch_size)):
if s+batch_size>train_sample.shape[0]:
break
batch_loss=model(train_set[s:s+batch_size])
optimizer.zero_grad()
batch_loss.backward()
optimizer.step()
# 会导致内存一直增加,需改为loss+=batch_loss.item()
loss+=batch_loss
s+=batch_size
loss/=total_batch
print(epoch,loss)
if (epoch+1) % 10 ==0:
model.eval()
model.save_embedding(epoch)
以上代码会导致内存占用越来越大,解决的方法是:loss+=batch_loss.item()。值得注意的是,要复现内存越来越大的问题,模型中需要切换model.train() 和 model.eval(),train_loss以及eval_loss的作用是保存模型的平均误差(这里是累积误差),保存到tensorboard中。
标签:loss,训练,batch,sample,item,pytorch,train,内存,model 来源: https://blog.csdn.net/u013602059/article/details/120370461
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。