对比PyTorch、TensorFlow、JAX、Theano,我发现都在关注两大问题

作者|王益

OneFlow社区编译

翻译|杨婷

最近,我在处理 PyTorch 分布式和 TorchRec 相关的作业,为此,我开端学习 PyTorch 2.0。在业余时间,我也在跟着Alpa作者学习JAX和XLA。如今回忆这些技术,我发现它们的关注点似乎都是如下两个问题:

  1. 包括主动求导和并行在内的函数转化,例如 vmap, pmap 和 pjit 等;

  2. 异构核算,CPU 担任操控流,GPU/TPU 担任张量核算和调集通讯。

本文档中的一切比如都支撑在 Colab 中运转:

对比PyTorch、TensorFlow、JAX、Theano,我发现都在关注两大问题

1

函数转化

“函数转化”意为将一个程序转变成另一个程序,最常见的比如是主动求导(autograd)。主动求导采用用户编写的前向进程并创立后向进程,关于用户来说,编写主动求导通常都过分杂乱。函数转化的首要难点在于:在编写函数转化算法时以何种办法表明输入和输出进程。

Theano:显式地构建 IR

Theano是最早的深度学习东西之一,也便是如今为人们所熟知的Aesara项目。Theano有一个允许用户在内存中将IR构建为数据结构的API,因而Theano可完成主动求导,并将成果输出为 Python 函数。

import aesara
from aesara import tensor as at
a = at.dscalar("a") # Define placeholders, which have no values.
b = at.dscalar("b")
c = a * b              # c now contains the IR of an expression.TT
dc = aesara.grad(c, a) # Convert the IR in c into another one, dc
f_dc = aesara.function([a, b], dc) # Convert the IR into a Python function,
assert f_dc(1.5, 2.5) == 2.5       # so we can call it.

TensorFlow 1.x:用于运转 IR 的虚拟机

TensorFlow 1.x明确保留了构建IR的想法。若在TensorFlow中运转上述示例,成果不会有什么不同;但倘若在TensorFlow 1.x中来运转,最大的不同在于:咱们不会将后向 IR 转化为 Python 函数,并运用 Python 解说器来运转。相反,咱们会在TensorFlow runtime中来运转。

import tensorflow.compat.v1 as tf # TensorFlow 1.x API
import numpy as np
tf.disable_eager_execution()
a = tf.placeholder(tf.float32, shape=())
b = tf.placeholder(tf.float32, shape=())
c = a * b
dc = tf.gradients(c, [a], stop_gradients=[a, b])
with tf.compat.v1.Session() as sess:  # TensorFlow has a runtime to execute the IR, 
  x = np.single(2)                    # so, no converting it into Python code. 
  y = np.single(3)     
  print(sess.run(dc, feed_dict={a:x, b:y}))

PyTorch 1.x:没有前向IR

PyTorch不会像Theano或TensorFlow那样将前向传达转化为IR。反之,PyTorch 运用 Python 解说器来运转前向传达。这样做的坏处在于会在运转期间生成表明后向传达的 IR,咱们称之为Eager形式(动态图形式)。

import torch
a = torch.tensor(1.0, requires_grad=True) # These are not placeholders, but values.
b = torch.tensor(2.0)
c = a * b    # Evaluates c and derives the IR of the backward in c.grad_fn_.
c.backward() # Executes c.grad_fn_.
print(c.grad)

TensorFlow 2.x: 梯度带

TensorFlow 2.x增加了一个像PyTorch API的Eager形式API。此 API 追踪前向传达怎么运转名为梯度带(GradientTape)的 IR 。TensorFlow 2.x能够从这个跟踪中找出后向传达。

import tensorflow as tf
a = tf.Variable(1.0) # Like PyTorch, these are values, not placehodlers. 
b = tf.Variable(2.0)
with tf.GradientTape() as tape:
  c = a * b
dcda = tape.gradient(c, a)
print(dcda)

JAX

JAX 不会向用户公开诸如梯度带等方面的初等级细节。简略说来,JAX的思想办法为:将输入和输出都用Python函数来表明。

import jax
a = 2.0
b = 3.0
jax.grad(jax.lax.mul)(a, b)  # Compute c = a * b w.r.t. a.  The result is b=3. 
jax.jit(jax.grad(jax.lax.mul))(a,b)
jax.experimental.pjit(jax.grad(jax.lax.mul), 
                      device_mesh(ntpus))(a,b)

关于想要自己编写的函数转化的高档用户,他们能够调用make_jaxpr等初级 API 来访问 IR,称为 JAXPR。

jax.make_jaxpr(jax.lax.mul)(2.0, 3.0)  # Returns the IR representing jax.lax.mul(2,3)
jax.make_jaxpr(jax.grad(jax.lax.mul))(2.0, 3.0)  # Returns the IR of grad(mul)(2,3)

FuncTorch

FuncTorch和JAX类似,都是依据PyTorch的函数转化。

import torch, functorch
a = torch.tensor([2.0])
b = torch.tensor([3.0])
functorch.grad(torch.dot)(a, b)

JAX的make_jaxpr类似于functorch的make_fx。

def f(a, b):
  return torch.dot(a, b) # Have to wrap the builtin function dot into f. # 必须将内置函数dot转化成f.
print(functorch.make_fx(f)(a, b).code)
print(functorch.make_fx(functorch.grad(f))(a, b).code)

TensorFlow 2.x、JAX 和 functorch 都为前向传递构建了一个 IR,但 PyTorch Eager形式没有。IR 不只可用于主动求导,还可用于其他类型的函数转化。在下列比如中,
functorch.compile.aot_function调用了回调函数print_compile_fn两次,别离用于前向和后向传达。

from functorch.compile import aot_function
import torch.fx as fx
def print_compile_fn(fx_module, args):
    print(fx_module)
    return fx_module
aot_fn = aot_function(torch.dot, print_compile_fn)
aot_fn(a, b)

2

高阶导数

PyTorch

import torch
from torch import autograd
x = torch.tensor(1., requires_grad = True)
y = 2*x**3 + 8
first_derivative = autograd.grad(y, x, create_graph=True)
print(first_derivative)
second_derivative = autograd.grad(first_derivative, x)
print(second_derivative)

TensorFlow 2.x

import tensorflow as tf
x = tf.Variable(1.0)
with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as tape:
        y = 2*x**3 + 8
        dy_dx = tape.gradient(y, x)
        print(dy_dx)
    d2y_dx2 = outer_tape.gradient(dy_dx, x)
    print(d2y_dx2)

JAX

def f(a):
  return 2*a**3 + 8
print(jax.grad(f)(1.0))
print(jax.grad(jax.grad(f))(1.0))

3

动态操控流

动态操控流(dynamic control flows)有两个层级:在 CPU 上运转的粗粒度等级和在 GPU /TPU 上运转的细粒度等级。本部分首要介绍在 CPU 上运转的粗粒度等级的动态操控流。下面咱们将用(if/else)条件句子作为比如检验深度学习东西。

TensorFlow 1.x

在 TensorFlow 1.x 中,咱们需要将条件句子显式构建到 IR 中。此刻条件句子是一个特别的运算符tf.cond。

def f1(): return tf.multiply(a, 17)
def f2(): return tf.add(b, 23)
r = tf.cond(tf.less(a, b), f1, f2)
with tf.compat.v1.Session() as sess:  # TensorFlow has a runtime to execute the IR,
  print(sess.run(r, feed_dict={a:x, b:y}))

TensorFlow 2.x

TensorFlow 2.x 支撑运用tf.cond和tf.while_loop显式构建操控流。此外,实验项目google/tangent中有AutoGraph功用,它能够将Python操控流通化为tf.cond或tf.while_loop。此功用利用了 Python 解说器支撑的函数和函数源代码。例如下面的g函数调用了 Python 的规范库将源代码解析为 AST,然后调用 SSA 表单来了解操控流。

def g(x, y):
    if tf.reduce_any(x < y):
        return tf.multiply(x, 17)
    return tf.add(y, 23)
converted_g = tf.autograph.to_graph(g)
import inspect
print(inspect.getsource(converted_g))

JAX

因为部分Python语法很杂乱,所以经过解析源代码来了解操控流就显得很困难,这就导致AutoGraph经常犯错。但如果这种办法很简略,那么Python开发者社区也不会在构建Python编译器时失利这么多次了。正是因为有这种应战的存在,必须要明确地将操控流构建到 IR 中。为此,JAX 供给了jax.lax.cond和jax.lax.for_loop函数。

jax.lax.cond(a < b, lambda : a*17, lambda: b+23)

考虑到这一点,你或许会觉得咱们能够运用递归算法。但是下面用于核算阶乘的递归无法用JAX跟踪。

def factorial(r, x):
  return jax.lax.cond(x <= 1.0, lambda: r, lambda: factorial(r*x, x-1))
factorial(1.0, 3.0)

或许你还想调用factorial来核算3!=6。但这会让递归深度超越最大值,因为递归不只依靠于条件,还依靠于函数界说和调用。

PyTorch

PyTorch开始是Python-native。正如前文所说,因为多功用调度机制,grad和vamp的函数转化都是即时的。值得注意的是:

  1. 相比Theano 和 TensorFlow构建IR后的函数转化,即时函数转化功率更高。

  2. 在进行grad和vmap时,JAX也是即时函数转化。但是像pamp和pjit等更杂乱的函数转化需要对整个核算进程进行概述,在这个进程中IR是必不可少的。

因为IR在pmap和pjit中的必要性,PyTorch社区最近添加了torch.condpytorch/pytorch#83154

4

分布式核算

依据履行代码或 IR 的不同办法,在运用 Python 解说器或runtime时,有两种分布式核算办法。

Python-Native

Theano和PyTorch采用了Python-native分布式核算办法。这种分布式练习作业包括多个Python解说器进程。这导致出现了以下成果。

  1. 打包和运转(Pack and run)。因为这些 Python 进程在不同的host上运转,因而咱们需要打包用户程序和依靠项,并将它们发送到这些host上去运转。一直以来TorchX担任了这个打包进程。它支撑例如Docker和torch.package等各种打包格局,并且能够与各种集群管理器合作运用,如Kubernetes和SLURM。

  2. 单程序多数据(SPMD)。因为将用户程序发送到各种host上要依靠于打包,与其他权重较轻的办法(如经过 RPC 发送代码)相比,这种办法不太灵活,因而,咱们通常只发送一个程序。当一切这些进程运转同一程序时,这个作业就变成了单程序多数据(SPMD)作业。

Python-native SPMD

下面是一个简略的SPMD PyTorch程序,咱们能够在相同或不同的host上运用进程运转这个程序。在这个进程中,咱们只需要调用all_gather。真实的分布式练习程序会调用更高档别的API,例如
torch.nn.parallel.DistributedDataParallel和
torchrec.DistributedModelParallel, 然后再调用初级 API,例如all_gather和all_reduce。

import os
import torch
from torch import distributed as dist
def main():
    use_gpu = torch.cuda.is_available()
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "0"))
    device = torch.device(f"cuda:{local_rank}" if use_gpu else "cpu")
    dist.init_distributed(backend="nccl")
    lst = torch.tensor([local_rank + 100]).to(device)
    # placeholder 
    rlt_lst = [torch.zeros_like(lst) for _ in range(local_world_size)]
    dist.all_gather(rlt_lst, lst, async_op=False)
    print("After broadcasting:", rlt_lst)

Python-native Non-SPMD

PyTorch 不只限于 SPMD 式的分布式练习。它还经过
torch.distributed.pipeline.sync.Pipe和PiPPy project供给流水并行,其中流水并行的各个阶段在不同的设备上运转不同的程序。这些阶段常经过torch.rpc包来交流。

分布式运转时机制

分布式 TensorFlow 作业由运转 TensorFlow runtime 程序的进程组成,而不是由 Python 解说器组成。此分布式运转时作业履行 TensorFlow graph (IR),它是由履行用户程序的 Python 解说器生成。

用户程序能够运用初级API(如tf.device)去指定作业要运转什么操作、在哪台设备和主机上运转等等。因为API有runtime,所以能够做到这一点。

with tf.device('/job:bar/task:0/device:gpu:2'):
    # ops created here have the fully specified device above

与PyTorch相同,TensorFlow也为分布式练习供给了高档APItf.distributed.strategy,Keras和DTensor。

strategy = tf.distribute.MirroredStrategy() \
           if tf.config.list_physical_devices('GPU') \
           else tf.distribute.get_strategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mse', optimizer='sgd')

分布式运转时极大地方便了练习服务的维护,因为咱们不再将用户程序打包到集群上运转。相反,咱们打包运转时程序,因为相比用户程序,运转时程序更加一致。

混合理念

JAX 支撑 Python-native 和分布式运转时。

JAX 供给例如vmap、pmap和pjit的函数转化,这能够将 Python 函数转化为分布式程序。

(本文经授权后由OneFlow社区编译,译文转载请联络取得授权。原文:
quip.com/Y8qtAyV4EXR…

欢迎 Star、试用 OneFlow 最新版本:
github.com/Oneflow-Inc…