
JAX Demo

cloud_tpu_colabs/JAX_demo.ipynb


The basics: interactive NumPy on GPU and TPU


python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(0)
python
key, subkey = random.split(key)
x = random.normal(subkey, (5000, 5000))  # draw from the fresh subkey; don't reuse a key after splitting it

print(x.shape)
print(x.dtype)
python
y = jnp.dot(x, x)
print(y[0, 0])
python
x
python
import matplotlib.pyplot as plt

plt.plot(x[0])
python
jnp.dot(x, x.T)
python
print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])
python
import numpy as np

x_cpu = np.array(x)
%timeit -n 1 -r 1 np.dot(x_cpu, x_cpu)
python
%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()
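
Note that jnp.dot dispatches asynchronously: the call returns as soon as the work is enqueued on the accelerator, so we call block_until_ready() to time the computation itself rather than just the dispatch.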

Automatic differentiation

python
from jax import grad
python
def f(x):
  if x > 0:
    return 2 * x ** 3
  else:
    return 3 * x
python
key = random.key(0)
x = random.normal(key, ())

print(grad(f)(x))
print(grad(f)(-x))
python
print(grad(grad(f))(-x))
print(grad(grad(grad(f)))(-x))
python
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.sum((predictions - targets)**2)



def init_layer(key, n_in, n_out):
  k1, k2 = random.split(key)
  W = random.normal(k1, (n_in, n_out))
  b = random.normal(k2, (n_out,))
  return W, b

layer_sizes = [5, 2, 3]

key = random.key(0)
key, *keys = random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

key, *keys = random.split(key, 3)
inputs = random.normal(keys[0], (8, 5))
targets = random.normal(keys[1], (8, 3))
batch = (inputs, targets)
python
print(loss(params, batch))
python
step_size = 1e-2

for _ in range(20):
  grads = grad(loss)(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]
python
print(loss(params, batch))
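
Since grad(loss) returns a pytree with the same list-of-(W, b) structure as params, the update loop can equivalently be written with jax.tree_util.tree_map, which applies the update leaf by leaf; a minimal sketch of the same loop:

python
for _ in range(20):
  grads = grad(loss)(params, batch)
  params = jax.tree_util.tree_map(
      lambda p, g: p - step_size * g, params, grads)  # one SGD step over the whole pytree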

Other JAX autodiff highlights:

  • Forward- and reverse-mode, totally composable
  • Fast Jacobians and Hessians (quick example after this list)
  • Complex number support (holomorphic and non-holomorphic)
  • Jacobian pre-accumulation for elementwise operations (like gelu)

For much more, see the JAX Autodiff Cookbook (Part 1).
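
To make the Jacobian/Hessian bullet concrete, JAX exposes jacfwd, jacrev, and hessian as ready-made transforms; a minimal sketch:

python
from jax import jacfwd, jacrev, hessian

fn = lambda v: jnp.sum(jnp.tanh(v) ** 2)
v = jnp.arange(3.)
print(jacfwd(fn)(v))   # forward-mode Jacobian (here a gradient, shape (3,))
print(jacrev(fn)(v))   # reverse-mode Jacobian, same values
print(hessian(fn)(v))  # (3, 3) matrix of second derivatives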

End-to-end compilation with XLA using jit

python
from jax import jit
python
key = random.key(0)
x = random.normal(key, (5000, 5000))
python
def f(x):
  y = x
  for _ in range(10):
    y = y - 0.1 * y + 3.
  return y[:100, :100]

f(x)
python
g = jit(f)
g(x)
python
%timeit f(x).block_until_ready()
python
%timeit g(x).block_until_ready()
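
The first call to g traces f and compiles it with XLA; later calls with the same argument shapes and dtypes reuse the cached executable, which is where the speedup comes from.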
python
grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)

Constraints that come with using jit

python
def f(x):
  if x > 0:
    return 2 * x ** 2
  else:
    return 3 * x

g = jit(f)
python
f(2)
python
try:
  g(2)
except Exception as e:
  print(e)
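
Under jit, f is traced with an abstract value standing in for x, so the Python-level comparison x > 0 has no concrete answer at trace time; that is what the error above reports.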
python
def f(x, n):
  i = 0
  while i < n:
    x = x * x
    i += 1
  return x

g = jit(f)
python
f(jnp.array([1., 2., 3.]), 5)
python
try:
  g(jnp.array([1., 2., 3.]), 5)
except Exception as e:
  print(e)
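
One fix is to declare n static with static_argnums: jit then traces with the concrete value, unrolling the loop, at the cost of recompiling whenever n changes.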
python
g = jit(f, static_argnums=(1,))
python
g(jnp.array([1., 2., 3.]), 5)
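
When per-value recompilation is undesirable, the jit-friendly alternative is a structured control-flow primitive; a minimal sketch of the same loop using jax.lax.fori_loop (f_loop is an illustrative name):

python
from jax import lax

@jit
def f_loop(x, n):
  # fori_loop keeps n dynamic, so one compilation serves every value of n
  return lax.fori_loop(0, n, lambda i, v: v * v, x)

f_loop(jnp.array([1., 2., 3.]), 5)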

Vectorization with vmap

python
from jax import vmap
python
print(vmap(lambda x: x**2)(jnp.arange(8)))
python
from jax import make_jaxpr

make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))
python
make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))
python
make_jaxpr(vmap(vmap(jnp.dot)))(jnp.ones((10, 10, 8)), jnp.ones((10, 10, 8)))
python
perex_grads = vmap(grad(loss), in_axes=(None, 0))
make_jaxpr(perex_grads)(params, batch)
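
Here in_axes=(None, 0) broadcasts params to every example while mapping over the leading batch axis of batch, so each gradient leaf picks up a leading axis of size 8; a quick check (pe_grads is an illustrative name):

python
pe_grads = perex_grads(params, batch)
print(pe_grads[0][0].shape)  # (8, 5, 2): a first-layer W gradient per example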

Parallel accelerators with pmap

python
jax.devices()
python
from jax import pmap
python
y = pmap(lambda x: x ** 2)(jnp.arange(8))
print(y)
python
y
python
z = y / 2
print(z)
python
import matplotlib.pyplot as plt
plt.plot(y)
python
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)
result = pmap(jnp.dot)(mats, mats)
print(pmap(jnp.mean)(result))
python
%timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()

Collective communication operations

python
from functools import partial
from jax.lax import psum

@partial(pmap, axis_name='i')
def normalize(x):
  return x / psum(x, 'i')

print(normalize(jnp.arange(8.)))
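
Within the pmapped function, psum sums x across every device mapped along the axis named 'i', so each device computes x / sum(x) and the printed outputs sum to 1.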
python
@partial(pmap, axis_name='rows')
@partial(pmap, axis_name='cols')
def f(x):
  row_sum = psum(x, 'rows')
  col_sum = psum(x, 'cols')
  total_sum = psum(x, ('rows', 'cols'))
  return row_sum, col_sum, total_sum

x = jnp.arange(8.).reshape((4, 2))
a, b, c = f(x)

print("input:\n", x)
print("row sum:\n", a)
print("col sum:\n", b)
print("total sum:\n", c)

For more, see the pmap cookbook.

Compose pmap with other transforms!

python
@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

f(x)
python
grad(lambda x: jnp.sum(f(x)))(x)

Compose everything

python
from jax import jvp, vjp  # forward and reverse-mode

curry = lambda f: partial(partial, f)

@curry
def jacfwd(fun, x):
  pushfwd = partial(jvp, fun, (x,))  # jvp!
  std_basis = (jnp.eye(np.size(x)).reshape((-1,) + jnp.shape(x)),)  # 1-tuple: tangents must match the (x,) primals structure
  y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis)  # vmap!
  return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))

@curry
def jacrev(fun, x):
  y, pullback = vjp(fun, x)  # vjp!
  std_basis = jnp.eye(np.size(y)).reshape((-1,) + jnp.shape(y))
  jac_flat, = vmap(pullback)(std_basis)  # vmap!
  return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))  # jit!
python
input_hess = hessian(lambda inputs: loss(params, (inputs, targets)))
per_example_hess = pmap(input_hess)  # pmap!
per_example_hess(inputs)