ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

TensorFlow 2 Keras实现线性回归

2021-11-25 11:02:45  阅读:182  来源: 互联网

标签:keras Keras tf API 线性 TensorFlow


介绍

线性回归是入门机器学习必学的算法,其也是最基础的算法之一。

接下来,我们以线性回归为例,使用 TensorFlow 2 提供的 API 和 Eager Execution 机制对其进行实现。

线性回归是一种较为简单,但十分重要的机器学习方法,它也是神经网络的基础。

如下所示,线性回归要解决的问题就是如何找到最理想的直线去拟合散点样本

image-20211012155654480

对于一个线性回归问题,一般来讲有 2 种解决方法,分别是:

  • 最小二乘法
    • 代数求解
    • 矩阵求解
  • 梯度下降法。

本次,我们将使用梯度下降方法来解决线性回归问题。

Keras 方式实现

配合 TensorFlow 提供的高阶 API,我们省去了定义线性函数,定义损失函数,以及定义优化算法等 3 个步骤。

不过,高阶 API 实现过程实际上还不够精简,我们可以完全使用 TensorFlow Keras API 来实现线性回归

Keras 本来是一个用 Python 编写的独立高阶神经网络 API,它能够以 TensorFlow, CNTK,或者 Theano 作为后端运行

目前,TensorFlow 已经吸纳 Keras,并组成了 tf.keras 模块。官方介绍,tf.keras 和单独安装的 Keras 略有不同,但考虑到未来的发展趋势,主要以学习 tf.keras 为主。

image-20211012160727469

初始化

image-20211012162409785

我们这里使用 Keras 提供的 Sequential 顺序模型结构向其中添加一个线性层。不同的地方在于,Keras 顺序模型第一层为线性层时,规定需指定输入维度,这里为 input_dim=1

image-20211012161515170

接下来,直接使用 .compile 编译模型,指定损失函数为 MSE 平方损失函数,优化器选择 SGD 随机梯度下降。然后,就可以使用 .fit 传入数据开始迭代了。

batch_size 是采用小批次训练的参数,主要用于解决一次性传入数据过多无法训练的问题。当然,由于示例数据本身较少,这里意义不大,但还是按照常规使用方法进行设置。

你会发现,完全使用 Keras 高阶 API 实际上只需要 4 行核心代码即可完成,相比于低阶 API 简化了很多。

完整代码:

import tensorflow as tf
TRUE_W = 3.0
TRUE_b = 2.0
NUM_SAMPLES = 100

X = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
noise = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
y = X * TRUE_W + TRUE_b + noise

# 模型训练
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1,input_dim=1))
model.compile(optimizer='sgd',loss='mse')
model.fit(X,y,epochs=10,batch_size=32)

标签:keras,Keras,tf,API,线性,TensorFlow
来源: https://blog.csdn.net/qq_41103843/article/details/121473640

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有