cloud_tpu_colabs/JAX_demo.ipynb
# Imports and PRNG setup. JAX uses explicit, splittable PRNG keys rather
# than hidden global random state.
import jax
import jax.numpy as jnp
from jax import random
key = random.key(0)
key, subkey = random.split(key)  # NOTE(review): subkey is unused below
# Draw a large matrix on the default backend (TPU/GPU/CPU).
x = random.normal(key, (5000, 5000))
print(x.shape)
print(x.dtype)
# Matrix multiply runs on the accelerator.
y = jnp.dot(x, x)
print(y[0, 0])
x
import matplotlib.pyplot as plt
plt.plot(x[0])
jnp.dot(x, x.T)
# NumPy-style fancy indexing works on device arrays too.
print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])
import numpy as np
# Copy to host memory for the NumPy baseline timing.
x_cpu = np.array(x)
%timeit -n 1 -r 1 np.dot(x_cpu, x_cpu)
# block_until_ready() forces async dispatch to finish before timing stops.
%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()
from jax import grad
def f(x):
    """Piecewise function: 2*x**3 for positive x, otherwise 3*x.

    The Python-level branch is allowed here because grad evaluates f
    with a concrete value.
    """
    return 2 * x ** 3 if x > 0 else 3 * x
# Differentiate the piecewise f at a random scalar; composing grad with
# itself yields higher-order derivatives. Which branch of f is active
# depends on the sign of the sampled x.
key = random.key(0)
x = random.normal(key, ())
print(grad(f)(x))
print(grad(f)(-x))
print(grad(grad(f))(-x))
print(grad(grad(grad(f)))(-x))
def predict(params, inputs):
    """Forward pass of an MLP.

    params is a sequence of (W, b) pairs; inputs has trailing dimension
    matching the first layer's W. Hidden layers use tanh; the final
    layer's pre-activation is returned unchanged.
    """
    activations = inputs
    for weights, bias in params:
        outputs = jnp.dot(activations, weights) + bias
        activations = jnp.tanh(outputs)  # feeds the next layer
    return outputs  # last layer: no activation
def loss(params, batch):
    """Sum-of-squares error of the network's predictions on a batch.

    batch is an (inputs, targets) pair; predict is defined above.
    """
    inputs, targets = batch
    residual = predict(params, inputs) - targets
    return jnp.sum(residual ** 2)
def init_layer(key, n_in, n_out):
    """Sample a dense layer's parameters (W, b) from standard normals.

    The key is split so weights and biases use independent streams.
    """
    w_key, b_key = random.split(key)
    return (random.normal(w_key, (n_in, n_out)),
            random.normal(b_key, (n_out,)))
# Build a small MLP (5 -> 2 -> 3), make random data, and train with plain
# full-batch gradient descent.
layer_sizes = [5, 2, 3]
key = random.key(0)
# split(key, 3) yields three keys: one replaces `key`, two init the layers.
key, *keys = random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = random.split(key, 3)
inputs = random.normal(keys[0], (8, 5))
targets = random.normal(keys[1], (8, 3))
batch = (inputs, targets)
print(loss(params, batch))
step_size = 1e-2
for _ in range(20):
    # grad differentiates w.r.t. the first argument (params) by default.
    grads = grad(loss)(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]
print(loss(params, batch))
Other JAX autodiff highlights:
For much more, see the JAX Autodiff Cookbook (Part 1).
from jax import jit
# Fresh key and a large operand for the jit timing comparison below.
key = random.key(0)
x = random.normal(key, (5000, 5000))
def f(x):
    """Apply ten steps of y <- y - 0.1*y + 3, then slice out the
    top-left (at most) 100x100 corner of the result."""
    acc = x
    for _step in range(10):
        acc = acc - 0.1 * acc + 3.
    return acc[:100, :100]
# Run once eagerly, then compile with jit and compare timings.
f(x)
g = jit(f)
g(x)  # first call compiles; subsequent calls reuse the compiled program
%timeit f(x).block_until_ready()
%timeit g(x).block_until_ready()
# Transformations compose arbitrarily: grad-of-jit-of-grad still works.
grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)
def f(x):
    """Return 2*x**2 for positive x, otherwise 3*x.

    The Python-level branch depends on the *value* of x, so this function
    works eagerly but cannot be traced by jit with an abstract argument
    (demonstrated below).
    """
    # Fix: the original line was garbled as "jitdef f(x):" — a markdown
    # "jit" section header fused onto the def during notebook export.
    if x > 0:
        return 2 * x ** 2
    else:
        return 3 * x
g = jit(f)
f(2)  # fine eagerly: x is a concrete value
# Under jit, tracing uses abstract values, so the value-dependent branch
# in f raises a concretization error, printed here.
try:
    g(2)
except Exception as e:
    print(e)
    pass
def f(x, n):
    """Square x repeatedly, n times in total (identity when n <= 0)."""
    result = x
    steps_done = 0
    while steps_done < n:
        result = result * result
        steps_done += 1
    return result
g = jit(f)
f(jnp.array([1., 2., 3.]), 5)  # fine eagerly
# The Python while-loop bound depends on n, so jit fails when n is traced
# as an abstract value; the error is printed here.
try:
    g(jnp.array([1., 2., 3.]), 5)
except Exception as e:
    print(e)
    pass
# Marking n static makes jit specialize (and recompile) per concrete n.
g = jit(f, static_argnums=(1,))
g(jnp.array([1., 2., 3.]), 5)
from jax import vmap
# vmap vectorizes a function over a new batch axis without rewriting it.
print(vmap(lambda x: x**2)(jnp.arange(8)))
from jax import make_jaxpr
# jaxprs show the traced program; each vmap layer adds a batch dimension.
make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))
make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))
make_jaxpr(vmap(vmap(jnp.dot)))(jnp.ones((10, 10, 8)), jnp.ones((10, 10, 8)))
# Per-example gradients: share params (in_axes=None), map over the batch (0).
perex_grads = vmap(grad(loss), in_axes=(None, 0))
make_jaxpr(perex_grads)(params, batch)
jax.devices()
from jax import pmap
# pmap runs one shard per device, mapping over the leading axis.
y = pmap(lambda x: x ** 2)(jnp.arange(8))
print(y)
y
z = y / 2
print(z)
import matplotlib.pyplot as plt
plt.plot(y)
# One PRNG key per device; each device builds and multiplies its own matrix.
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)
result = pmap(jnp.dot)(mats, mats)
print(pmap(jnp.mean)(result))
%timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()
from functools import partial
from jax.lax import psum
# Naming the mapped axis ('i') enables collective ops across devices.
@partial(pmap, axis_name='i')
def normalize(x):
    # Each device divides its value by the sum across all devices.
    return x / psum(x, 'i')
print(normalize(jnp.arange(8.)))
# Nested pmap: the outer map names the leading axis 'rows', the inner one
# names the next axis 'cols'.
@partial(pmap, axis_name='rows')
@partial(pmap, axis_name='cols')
def f(x):
    # psum can reduce over either named axis individually, or both at once.
    row_sum = psum(x, 'rows')
    col_sum = psum(x, 'cols')
    total_sum = psum(x, ('rows', 'cols'))
    return row_sum, col_sum, total_sum
x = jnp.arange(8.).reshape((4, 2))
a, b, c = f(x)
print("input:\n", x)
print("row sum:\n", a)
print("col sum:\n", b)
print("total sum:\n", c)
For more, see the pmap cookbook.
# pmap composes with itself and with grad: the inner pmap'd function g
# closes over per-device values (y, x) from the outer computation, and the
# whole thing is differentiated both inside and outside the outer pmap.
@pmap
def f(x):
    y = jnp.sin(x)
    @pmap
    def g(z):
        return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
    return grad(lambda w: jnp.sum(g(w)))(x)
f(x)
grad(lambda x: jnp.sum(f(x)))(x)
from jax import jvp, vjp # forward and reverse-mode
curry = lambda f: partial(partial, f)
@curry
def jacfwd(fun, x):
    """Jacobian of fun at x via forward mode: one jvp per basis vector,
    batched with vmap. Result shape is shape(y) + shape(x)."""
    push = partial(jvp, fun, (x,))  # jvp! primal point fixed, tangents vary
    # One one-hot tangent per scalar entry of x. jvp's tangents argument is
    # a tuple pytree, hence the explicit single-element tuple here.
    basis = (jnp.eye(np.size(x)).reshape((-1,) + jnp.shape(x)),)
    y, flat_jac = vmap(push, out_axes=(None, -1))(basis)  # vmap!
    return flat_jac.reshape(jnp.shape(y) + jnp.shape(x))
@curry
def jacrev(fun, x):
    """Jacobian of fun at x via reverse mode: one vjp pullback per output
    basis covector, batched with vmap. Result shape is shape(y) + shape(x)."""
    y, pullback = vjp(fun, x)  # vjp!
    # One one-hot cotangent per scalar entry of y.
    cotangents = jnp.eye(np.size(y)).reshape((-1,) + jnp.shape(y))
    flat_jac, = vmap(pullback)(cotangents)  # vmap!
    return flat_jac.reshape(jnp.shape(y) + jnp.shape(x))
def hessian(fun):
    """Hessian of fun as jit-compiled forward-over-reverse differentiation."""
    second_derivative = jacfwd(jacrev(fun))
    return jit(second_derivative)  # jit!
# Hessian of the loss w.r.t. the inputs (params and targets held fixed),
# then one per-example Hessian per device via pmap.
input_hess = hessian(lambda inputs: loss(params, (inputs, targets)))
per_example_hess = pmap(input_hess)  # pmap!
per_example_hess(inputs)