ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

机器学习库JAX教程—基于python的高性能数值计算库JAX介绍

2023-01-15 13:58:53  阅读:389  来源: 互联网

标签:机器学习库JAX python高性能数值计算库JAX 基于numpy的机器学习库


目前我们用于机器学习的数值计算库主要还是numpy,但是numpy的函数太过于庞杂,需要大量的时间来熟悉;而这款JAX是基于numpy的高性能数值计算库,目前已经在机器学习领域崭露头角,其雄心是让机器学习变得简单而高效。今天我们就来系统的学习一下这个高性能的python数值计算库JAX。今天就跟随icode9的小编一起来系统的学习下jax这个库吧

什么是用于机器学习的 JAX?
JAX 是一个专为高性能数值计算,尤其是机器学习研究而设计的 Python 库。它的数值函数 API 基于NumPy,这是一个用于科学计算的函数集合。JAX专注于通过使用 XLA 在 GPU 上编译 NumPy 函数来加速机器学习过程,并使用 autograd 来区分 Python 和 NumPy 函数以及基于梯度的优化。JAX 能够通过循环、分支、递归和闭包进行微分,并使用 GPU 加速轻松地对导数的导数进行导数。JAX 还支持反向传播和前向模式微分。

JAX 在使用 GPU 运行代码时提供卓越的性能,并提供实时 (JIT) 编译选项以轻松加速大型项目,我们将在本文后面深入探讨。 

将 JAX 视为一个 Python 库,它通过函数转换修改 NumPy 和 Python 代码以实现加速机器学习。作为一般规则,只要计划使用 GPU 进行训练、计算梯度(autograd)或使用 JIT 代码编译,就应该使用 JAX。

为什么使用 JAX?
除了与普通 CPU 一起工作之外,JAX 的主要功能是能够与 GPU 等不同的处理单元完全兼容。这为 JAX 提供了优于类似包的巨大优势,因为在图像和矢量处理方面,使用 GPU 并行化可实现比 CPU 更快的性能。 

这一点非常重要,因为当使用 NumPy 库时,用户可以构建特殊大小的矩阵,从而使 GPU 在处理此类数据格式时更加高效。 

这种时间差异使 JAX 库能够通过几个关键实现在速度和性能上超过 NumPy 100 多倍:

向量化——将多个数据作为单个指令处理,为线性代数计算和机器学习提供了极大的加速
代码并行化——获取在单个处理器上运行的串行代码并将其分发的过程。GPU 在这里是首选,因为它们有许多专门用于计算的处理器。

自动微分 - 非常简单直接的微分,可以多次链接以轻松评估高阶导数。
如何安装 JAX
要安装仅 CPU 版本的 JAX,这可能对在笔记本电脑上进行本地开发很有用,您可以运行


pip install --upgrade pip
pip install --upgrade "jax[cpu]"


在Linux上,往往需要先将pip更新到支持manylinux2014 wheels的版本。

pip 安装:GPU (CUDA)
要安装同时支持 CPU 和 NVIDIA GPU 的 JAX,您必须先安装CUDA和CuDNN(如果尚未安装)。与许多其他流行的深度学习系统不同,JAX 没有将 CUDA 或 CuDNN 捆绑为 pip 包的一部分。

JAX 为 Linux 提供预构建的 CUDA 兼容轮子,支持 CUDA 11.1 或更新版本,以及 CuDNN 8.0.5 或更新版本。操作系统、CUDA 和 CuDNN 的其他组合是可能的,但需要从源代码构建。

需要 CUDA 11.1 或更新版本
如果您从源代码构建,您也许可以使用较旧的 CUDA 版本,但所有 11.1 之前的 CUDA 版本都存在已知的 CUDA 错误,因此我们不会为较旧的 CUDA 版本提供预构建的二进制文件。

预制车轮支持的 cuDNN 版本是:
cuDNN 8.2 或更新版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 wheel,因为它支持附加功能。
cuDNN 8.0.5 或更新版本。

您必须使用至少与CUDA 工具包对应的驱动程序版本一样新的 NVIDIA 驱动程序版本。例如,如果您安装了 CUDA 11.4 update 4,则在 Linux 上必须使用 NVIDIA 驱动程序 470.82.01 或更新版本。这是一项严格的要求,因为 JAX 依赖于 JIT 编译代码;较旧的驱动程序可能会导致故障。

如果您需要将较新的 CUDA 工具包与较旧的驱动程序一起使用,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容包。

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda]" https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


jaxlib 版本必须与您要使用的现有 CUDA 安装版本相对应。您可以为 jaxlib 明确指定特定的 CUDA 和 CuDNN 版本:


pip install --upgrade pip
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


您可以使用以下命令找到您的 CUDA 版本:


nvcc --version

某些 GPU 功能要求 CUDA 安装在 /usr/local/cuda-XX,其中 XX 应替换为 CUDA 版本号(例如 cuda-11.1)。如果 CUDA 安装在系统的其他地方,您可以创建一个符号链接:


sudo ln -s /path/to/cuda /usr/local/cuda-X.X


比较 JAX 和 NumPy
由于 JAX 是增强型 NumPy,因此它们的语法非常相似,使用户能够在 NumPy 或 JAX 不执行的项目中互换使用这两者。这通常适用于较小的项目,在这些项目中,加速量在节省的时间上可以忽略不计。但是,随着模型变得越来越大,您应该越多地考虑 JAX。

比较 JAX 和 NumPy
使用 JAX 与 NumPy 相乘两个矩阵
为了清楚地说明这两个库之间的速度差异,我们将使用两者将两个矩阵相互相乘,然后检查仅 CPU 和 GPU 之间的性能差异。我们还将检查由JIT 编译器引起的性能提升。 

要按照本教程进行操作,请安装并导入 JAX 和 NumPy 库(来自上一步)。您可以在Kaggle或Google Colab等网站上测试您的代码。与任何库一样,您应该通过在代码开头编写以下行来导入 JAX:

Python
import jax.numpy as jnp
from jax import random


您还可以用类似的方式导入 NumPy 库:

Python
import numpy as np


接下来,我们将通过在 Python 中将两个矩阵相乘来比较使用 CPU 和 GPU 的 JAX 和 Numpy 的性能。对于这些基准,越低越好。

CPU 上的 NumPy
首先,我们将使用 NumPy 创建一个 5,000 x 5,000 的矩阵并快速测试其性能。

Python
import numpy as np

size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)


每个循环 785 毫秒

在 NumPy 上运行的单个代码循环每个循环大约需要 750 毫秒才能运行。

CPU 上的 JAX
现在让我们运行相同的代码,但这次使用 JAX 库。

Python
import jax.numpy as jnp

size = 5000
x = jnp.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()


每个循环 1.43 秒

如您所见,比较 JAX 和 NumPy 仅 CPU 的性能表明 NumPy 是更快的选择。虽然 JAX 可能无法为普通 CPU 提供最佳性能,但它确实可以为 GPU 提供更好的性能。

带 GPU 的 JAX
现在,让我们尝试创建相同的 5,000 x 5,000 矩阵,这次使用 JAX 和 GPU 而不是常规 CPU:

Python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
size = 5000

x = random.normal(key, (size, size)).astype(jnp.float32)
%time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()


每个循环 80.6 毫秒

正如在 GPU 而不是 CPU 上运行 JAX 时所清楚显示的那样,我们实现了每个循环大约 80 毫秒的更好时间(性能的大约 15 倍)。当使用更大的矩阵或时间尺度时,这将更容易看到。

即时编译 (JIT)
使用jit 命令,我们的代码将使用特定的 XLA 编译器进行编译,从而使我们的功能能够高效地执行。

XLA是加速线性代数的缩写,JAX 和 Tensorflow 等库使用它来以更高的效率在 GPU 上编译和运行代码。所以总而言之,XLA 是一种特定的线性代数编译器,能够以更高的速度编译代码。

我们将使用selu_np函数测试我们的代码,该函数代表 Scaled Exponential Linear Unit,并检查普通 CPU 上的 NumPy 与使用 JIT 在 GPU 上运行 JAX 之间的不同时间性能。

Python
def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


CPU 上的 NumPy 
首先,我们将使用 NumPy 库创建一个大小为 1,000,000 的向量。

Python
import numpy as np

x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)


每个循环 8.3 毫秒

带有 JIT 的 GPU 上的 JAX
现在我们将在 GPU 上使用 JAX 和 JIT 测试我们的代码。

import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit

key = random.PRNGKey(0)

def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))

selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x) 
%time selu_jax_jit(x_jax).block_until_ready() 
%timeit selu_jax_jit(x_jax).block_until_ready() 


每个循环 153 µs  (每个循环 0.153 毫秒)

最后,当使用带有 GPU 的 JIT 编译器时,我们获得了比使用普通 GPU 更好的性能。您可以清楚地看到差异非常明显,从 NumPy 到 JAX 和 JIT,速度提高了近 5000% 或快了 50 倍!

将 JAX 视为对 NumPy 的修改,以使用 GPU 加速机器学习。由于 NumPy 只能在 CPU 上编译,如果您选择在 GPU 上执行代码,JAX 比 NumPy 更快。作为一般规则,只要计划将 NumPy 与 GPU 一起使用或使用 JIT 代码编译,就应该使用 JAX。

注意:要查看使用本教程示例的原始文章,请查看以下链接: 原始代码。

JAX 限制:纯函数
JAX 转换和并发症是为功能纯的 Python 函数设计的。纯函数不能通过访问外部变量来改变程序的状态,也不能对像 print() 这样的输入/输出流等函数产生副作用。

连续运行会导致这些副作用无法按预期执行。如果您不小心,未跟踪的副作用可能会影响您预期计算的准确性。

使用谷歌的 JAX 
在本文中,我们解释了 JAX 的功能以及它为 NumPy 带来的优势。我们介绍了如何安装 JAX 库及其对机器学习的优势。

然后我们继续导入 JAX 和 NumPy。此外,我们将 JAX 与 NumPy(这是目前最著名的竞争对手库)进行了比较,并通过使用常规 CPU 和 GPU 以及一些 JIT 测试揭示了这两者之间的时间和性能差异,并看到了速度的显着提高。

如果您是一名高级机器/深度学习从业者,那么将 JAX 等库及其 (GPU/TPU) 加速器及其高效的 JIT 编译器添加到您的武器库中肯定会让生活变得更加轻松。

标签:机器学习库JAX,python高性能数值计算库JAX,基于numpy的机器学习库
来源:

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有