ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

TensorFlow学习记录(四)---TensorFlow自动求导机制

2021-12-08 14:33:44  阅读:225  来源: 互联网

标签:dy Variable --- print tape dx tf 求导 TensorFlow


一元函数

import tensorflow as tf
x = tf.Variable(3.)
with tf.GradientTape(persistent=True) as tape:
    y = tf.square(x)
    z = tf.pow(x,3)
dy_dx = tape.gradient(y,x)
dz_dx = tape.gradient(z,x)
print(y)
print(dy_dx)
print(z)
print(dz_dx)
del tape #使用完之后手动释放

多元函数

import tensorflow as tf
x = tf.Variable(3.)
y = tf.Variable(4.)
with tf.GradientTape(persistent=True) as tape:
   f = tf.square(x)+2*tf.square(y)+1
dy_dx,df_dy = tape.gradient(f,[x,y])

print(f)
print(dy_dx)
print(df_dy)
del tape #使用完之后手动释放

二阶导数

import tensorflow as tf
x = tf.Variable(3.)
y = tf.Variable(4.)
with tf.GradientTape(persistent=True) as tape2:
    with tf.GradientTape(persistent=True) as tape1:
        f = tf.square(x)+2*tf.square(y)+1
    firsy_grads= tape1.gradient(f,[x,y])
second_grads=tape2.gradient(firsy_grads,[x,y])

print(f)
print(firsy_grads)
print(second_grads)
del tape1
del tape2#使用完之后手动释放

TensorFlow实现一元线性回归

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
x = np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,
              68.00,63.02,81.26,86.21])
y = np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
#设置超参数
learn_rate = 0.0001
iter = 10
display_step = 1
#设置模型参数初始值
np.random.seed(612)
w = tf.Variable(np.random.randn())
b= tf.Variable(np.random.randn())
#训练模型
mse = []
for i in range(0,iter+1):
    with tf.GradientTape() as tape:
        pred = w*x +b
        Loss = 0.5*tf.reduce_mean(tf.square(y-pred))
    mse.append(Loss)
    dL_dw,dL_db = tape.gradient(Loss,[w,b])
    w.assign_sub(learn_rate*dL_dw)#减等于
    b.assign_sub(learn_rate*dL_db)
    if i %display_step == 0:
        print("i:%i,Loss: %f,w:%f,b:%f"%(i,Loss,w.numpy(),b.numpy()))

标签:dy,Variable,---,print,tape,dx,tf,求导,TensorFlow
来源: https://blog.csdn.net/mabaiteng/article/details/121790407

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有