ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

pytorch——linear model2

2022-02-04 11:35:45  阅读:193  来源: 互联网

标签:plt linear val model2 list pytorch pred mse data


#模型x*W+b,三维图象横坐标是w,纵坐标是b,竖坐标是损失函数
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from modulefinder import *
from mpl_toolkits.mplot3d import Axes3D
x_data=[1,2,3]
y_data=[2,4,6]
def forward(x,b):
return x*w+b
def loss(x,y):
y_pred = forward(x,b)
return (y_pred - y) * (y_pred - y)
w_list=[] #随机w
mse_list=[] #mean square error=每个w对应的损失函数
for w in np.arange(0,4,0.1):
for b in np.arange(-2.0,2.0,0.1):
print('w=',w)
print('b=',b)
l_sum=0
for x_val,y_val in zip(x_data,y_data):#将x_data和y_data用zip拼成x_val y_val
y_pred_val=forward(x_val,b) #求y尖
loss_val=loss(x_val,y_val) #预测值y^和真实值y之间的平方差,损失函数
l_sum+=loss_val #求每个样本损失函数之和
print('x=',x_val,'y=',y_val,'y^=',y_pred_val,'每个样本的损失函数:',loss_val)

print('dataset数据集的平均损失函数mse:', l_sum / 3)
w_list.append(w)#w[]列表追加元素w
mse_list.append(l_sum / 3)#mse[]列表追加元素新的平均损失函数
fig=plt.figure()
ax=Axes3D(fig)
ax.plot_surface(w, b, mse_list,rstride=1,cstride=1, cmap=plt.get_cmap('rainbow'))
plt.xlabel(r'w',fontsize=20,color='cyan')
plt.ylabel(r'b',fontsize=20,color='cyan')
ax.plot_surface(w,b=1,mse_list,)
plt.show()

标签:plt,linear,val,model2,list,pytorch,pred,mse,data
来源: https://www.cnblogs.com/xinrui-wang/p/15862582.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有