ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

优化算法篇

2022-09-06 23:32:46  阅读:192  来源: 互联网

标签:优化 epoch data 算法 cost lis xs grad


 梯度下降与随机梯度下降:

import torch
import matplotlib.pyplot as plt
import numpy as np
x_data = [5,6,7,8.5,9,10,11.5,12]
y_data = [1,2,8,4,5,6.5,7.5,8]

w = 1
#初始权重

def forward(x):
    return x * w

#MSE
def cost(xs,ys):
    cost = 0
    for x,y in zip(xs,ys):
        y_pred = forward(x)
        cost += (y-y_pred)**2
    return cost/len(xs)

def SGD_loss(xs,ys):
    y_pred = forward(xs)
    return (y_pred - ys)**2

def SGD_gradient(xs,ys):
    return 2*xs*(xs*w-ys)

def gradient(xs,ys):
    grad = 0
    for x,y in zip(xs,ys):
        grad += 2*x*(x*w-y)
    return grad/len(xs)

def draw(x,y):
    fig = plt.figure(num=1, figsize=(4, 4))
    ax = fig.add_subplot(111)
    ax.plot(x,y)
    plt.show()

# epoch_lis  =[]
# loss_lis = []
# learning_rate = 0.012
#
# for epoch in range(100):
#     cost_val = cost(x_data,y_data)
#     grad_val = gradient(x_data,y_data)
#     w -= learning_rate*grad_val
#     print("Epoch = {} w = {} loss = {} ".format(epoch,w,cost_val))
#     epoch_lis.append(epoch)
#     loss_lis.append(cost_val)
# print(forward(4))
# draw(epoch_lis,loss_lis)
# draw(x_data,y_data)


l_lis= []
epoch = []
learning_rate = 0.009
#SGD
for epoch in range(10):
    for x,y in zip(x_data,y_data):
        grad = SGD_gradient(x,y)
        w -= learning_rate*grad
        print(" x:{}  y:{}   grad:{}".format(x,y,grad))
        l = SGD_loss(x,y)
        print("loss: ",l)
        l_lis.append(l)

X = [int(i) for i in range(len(l_lis))]
draw(X,l_lis)

 

标签:优化,epoch,data,算法,cost,lis,xs,grad
来源: https://www.cnblogs.com/MrMKG/p/16660115.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有