ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Numpy 写3层神经网络拟合sinx

2022-02-25 22:33:03  阅读:203  来源: 互联网

标签:sinx pred self random lr 拟合 np Numpy out


代码

# -*- coding: utf-8 -*-
"""
Created on Wed Feb 23 20:37:01 2022

@author: koneko
"""
import numpy as np
import matplotlib.pyplot as plt


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def mean_squared_error(y, t):
    return 0.5 * np.sum((y-t)**2)


class Sigmoid:
    def __init__(self):
        self.out = None
        
    def forward(self, x):
        out = sigmoid(x)
        self.out = out
        return out
    
    def backward(self, dout):
        dx = dout * (1.0 - self.out) * self.out 
        return dx
    
 
    
x = np.linspace(-np.pi, np.pi, 1000)
y = np.sin(x)
plt.plot(x,y)
x = x.reshape(1, x.size)
y = y.reshape(1, y.size)

# 初始化权重
W1 = np.random.randn(3,1)
b1 = np.random.randn(3,1)

W2 = np.random.randn(2,3)
b2 = np.random.randn(2,1)

W3 = np.random.randn(1,2)
b3 = np.random.randn(1,1)


sig1 = Sigmoid()

sig2 = Sigmoid()

lr = 0.001


for i in range(30000):
    a1 = W1 @ x + b1
    c1 = sig1.forward(a1)
    
    a2 = W2 @ c1 + b2
    c2 = sig2.forward(a2)
    
    y_pred = W3 @ c2 + b3
    
    #y_pred = W2 @ c1 + b2
    
    Loss = mean_squared_error(y, y_pred)
    print(f"Loss[{i}]: {Loss}")
    
    dy_pred = y_pred - y
    
    dc2 = W3.T @ dy_pred
    da2 = sig2.backward(dc2)
    
    dc1 = W2.T @ da2
    da1 = sig1.backward(dc1)
    
    # 计算Loss对各层参数的偏导数

    dW3 = dy_pred @ c2.T
    db3 = np.sum(dy_pred)
    
    dW2 = da2 @ c1.T
    db2 = np.sum(da2, axis=1)
    db2 = db2.reshape(db2.size, 1)
    
    dW1 = da1 @ x.T
    db1 = np.sum(da1, axis=1)
    db1 = db1.reshape(db1.size, 1)
    
    W3 -= lr*dW3
    b3 -= lr*db3
    W2 -= lr*dW2
    b2 -= lr*db2
    W1 -= lr*dW1
    b1 -= lr*db1
    
    if i % 100 == 99:
        plt.cla()
        plt.plot(x.T,y.T)
        plt.plot(x.T,y_pred.T)
    

    

效果



标签:sinx,pred,self,random,lr,拟合,np,Numpy,out
来源: https://www.cnblogs.com/urahyou/p/15937995.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有