Back to jax

Lorentz ODE Solver in JAX

cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb

0.3.257.5 KB
Original Source

Lorentz ODE Solver in JAX

Alex Alemi

Imports

import io
import os
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap, jit, grad, ops, lax, config
from jax import random as jr

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython.display import display_png

# Remove padding around saved figures so rendered images have no border.
mpl.rcParams['savefig.pad_inches'] = 0
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names; use whichever this install provides.
plt.style.use('seaborn-v0_8-dark'
              if 'seaborn-v0_8-dark' in plt.style.available
              else 'seaborn-dark')
%matplotlib inline

Plotting Utilities

These utilities provide faster, better-antialiased line plotting than the typical matplotlib plotting routines.

@jit
def drawline(im, x0, y0, x1, y1):
  """Add an antialiased line segment to `im` using Wu's line algorithm.

  This functional version was adapted from here:
    https://en.wikipedia.org/wiki/Xiaolin_Wu's_line_algorithm

  Pixel intensities are accumulated (added) into the image rather than
  overwritten, so segments that overlap build up brightness.

  Args:
    im: 2D array indexed as `im[x, y]` that receives the intensities.
    x0, y0: coordinates of the first endpoint, in pixel units.
    x1, y1: coordinates of the second endpoint, in pixel units.

  Returns:
    A new array shaped like `im` with the segment's intensities added.
  """

  ipart = lambda v: jnp.floor(v).astype('int32')
  round_ = lambda v: ipart(v + 0.5)
  fpart = lambda v: v - jnp.floor(v)
  rfpart = lambda v: 1 - fpart(v)

  def plot(im, x, y, c):
    # `jax.ops.index_add` no longer exists; `.at[...].add` is the supported
    # functional scatter-add.
    return im.at[x, y].add(c)

  def cond_swap(pred, pair):
    # Modern `lax.cond` signature: (pred, true_fun, false_fun, *operands).
    return lax.cond(pred, lambda p: (p[1], p[0]), lambda p: p, pair)

  def plot_column(im, x, y, weight, transposed):
    """Split `weight` between the two pixels straddling real-valued `y`.

    `transposed` is a Python bool chosen at trace time; when True the pixel
    coordinates are written as (y, x), undoing the steep-line axis swap.
    """
    lo = ipart(y)
    if transposed:
      im = plot(im, lo, x, rfpart(y) * weight)
      return plot(im, lo + 1, x, fpart(y) * weight)
    im = plot(im, x, lo, rfpart(y) * weight)
    return plot(im, x, lo + 1, fpart(y) * weight)

  def plot_endpoint(im, x, y, weight):
    return lax.cond(
        steep,
        lambda im: plot_column(im, x, y, weight, True),
        lambda im: plot_column(im, x, y, weight, False),
        im)

  steep = jnp.abs(y1 - y0) > jnp.abs(x1 - x0)

  # For steep lines, swap the roles of x and y so we always step along the
  # longer axis.
  (x0, y0) = cond_swap(steep, (x0, y0))
  (x1, y1) = cond_swap(steep, (x1, y1))

  # Make the segment run in increasing x. The predicate is evaluated once,
  # before either pair is swapped, so both swaps agree.
  backwards = x0 > x1
  (y0, y1) = cond_swap(backwards, (y0, y1))
  (x0, x1) = cond_swap(backwards, (x0, x1))

  dx = x1 - x0
  dy = y1 - y0
  gradient = jnp.where(dx == 0.0, 1.0, dy/dx)

  # handle first endpoint
  xpxl1 = round_(x0)  # this will be used in the main loop
  yend1 = y0 + gradient * (xpxl1 - x0)
  im = plot_endpoint(im, xpxl1, yend1, rfpart(x0 + 0.5))

  # first y-intersection for the main loop
  intery = yend1 + gradient

  # handle second endpoint
  xpxl2 = round_(x1)  # this will be used in the main loop
  yend2 = y1 + gradient * (xpxl2 - x1)
  im = plot_endpoint(im, xpxl2, yend2, fpart(x1 + 0.5))

  def make_span(transposed):
    """Build the main-loop branch for one orientation."""
    def span(carry):
      def body_fun(x, carry):
        im, intery = carry
        im = plot_column(im, x, intery, 1.0, transposed)
        return (im, intery + gradient)
      return lax.fori_loop(xpxl1 + 1, xpxl2, body_fun, carry)
    return span

  # Choose the orientation once, outside the loop, so the loop body itself
  # contains no conditional.
  im, _ = lax.cond(steep, make_span(True), make_span(False), (im, intery))

  return im

def img_adjust(data):
  """Histogram-equalize `data`, mapping each value to its empirical CDF.

  A 65536-bin histogram of all values is accumulated into a normalized
  cumulative distribution, and every input value is replaced by the CDF
  linearly interpolated at that value.

  Args:
    data: array-like of numeric values.

  Returns:
    A NumPy array with the same shape as `data`.
  """
  pixels = np.array(data)
  counts, edges = np.histogram(pixels.flat, bins=256*256)
  # Interpolate against bin midpoints rather than edges.
  centers = (edges[1:] + edges[:-1]) / 2
  running = counts.cumsum()
  normalized = running / float(running[-1])
  equalized = np.interp(pixels.flat, centers, normalized)
  return equalized.reshape(pixels.shape)

def imify(arr, vmin=None, vmax=None, cmap=None, origin=None):
  """Convert an array to RGBA bytes after histogram equalization.

  Args:
    arr: array of values; equalized with `img_adjust` before colormapping.
    vmin, vmax: color limits passed to the mappable's `set_clim`.
    cmap: colormap passed to `ScalarMappable`.
    origin: image origin; when None, matplotlib's `image.origin` rcParam is
      used. With "lower" the array is flipped along its first axis.

  Returns:
    The result of `ScalarMappable.to_rgba(..., bytes=True)`.
  """
  if origin is None:
    origin = mpl.rcParams["image.origin"]

  equalized = img_adjust(arr)
  if origin == "lower":
    equalized = equalized[::-1]

  mapper = cm.ScalarMappable(cmap=cmap)
  mapper.set_clim(vmin, vmax)
  return mapper.to_rgba(equalized, bytes=True)

def plot_image(array, **kwargs):
  """Render `array` as a PNG and display it through IPython.

  Args:
    array: image data, converted to RGBA by `imify`.
    **kwargs: forwarded unchanged to `imify` (vmin, vmax, cmap, origin).
  """
  imarray = imify(array, **kwargs)
  # The context manager closes the buffer even if encoding fails.
  with io.BytesIO() as f:
    plt.imsave(f, imarray, format="png")
    dat = f.getvalue()
  display_png(dat, raw=True)

def pack_images(images, rows, cols):
  """Tile a batch of images into one grid image.

  The leading axes of `images` are flattened into a batch axis. At most
  `rows * cols` images are kept, and image `r * cols + c` is placed at grid
  cell (r, c).

  Args:
    images: array whose last three axes are (width, height, depth).
    rows: maximum number of grid rows; capped at the batch size.
    cols: maximum number of grid columns; capped at `batch // rows`.

  Returns:
    An array of shape (rows * width, cols * height, depth).
  """
  width, height, depth = np.shape(images)[-3:]
  stack = np.reshape(images, (-1, width, height, depth))
  batch = np.shape(stack)[0]

  # Shrink the grid when there are too few images to fill it.
  rows = np.minimum(rows, batch)
  cols = np.minimum(batch // rows, cols)

  grid = np.reshape(stack[:rows * cols], (rows, cols, width, height, depth))
  # Interleave grid and pixel axes: (rows, width, cols, height, depth).
  grid = np.transpose(grid, [0, 2, 1, 3, 4])
  return np.reshape(grid, [rows * width, cols * height, depth])

Lorentz Dynamics

Implement the Lorenz attractor: dx/dt = σ(y − x), dy/dt = x(ρ − z) − y, dz/dt = xy − βz.

# Parameters of the system integrated by `f`.
sigma = 10.0
beta = 8.0 / 3.0
rho = 28.0

@jit
def f(state, t):
  """Return the time derivative of the three-component `state`.

  Args:
    state: length-3 array (x, y, z).
    t: time; accepted for the integrator's interface but not used.

  Returns:
    Array [dx/dt, dy/dt, dz/dt].
  """
  x, y, z = state
  x_dot = sigma * (y - x)
  y_dot = x * (rho - z) - y
  z_dot = x * y - beta * z
  return jnp.array([x_dot, y_dot, z_dot])

Runge Kutta Integrator

@jit
def rk4(ys, dt, N):
  """Fill `ys` with a classic fourth-order Runge-Kutta solution of `f`.

  Args:
    ys: array whose row 0 holds the initial state; rows 1..N-1 are
      overwritten with successive states.
    dt: step size.
    N: upper bound of the step loop; rows with index i in [1, N) are
      computed.

  Returns:
    A new array like `ys` with rows 1..N-1 filled in.
  """
  def step(i, ys):
    h = dt
    y_prev = ys[i-1]
    # The state being advanced is row i-1, so (assuming row 0 is t = 0) the
    # stage times start at dt * (i - 1).
    t = dt * (i - 1)
    k1 = h * f(y_prev, t)
    k2 = h * f(y_prev + k1/2., t + h/2.)
    k3 = h * f(y_prev + k2/2., t + h/2.)
    k4 = h * f(y_prev + k3, t + h)

    ysi = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)
    # `jax.ops.index_update` no longer exists; `.at[i].set` is the supported
    # functional update.
    return ys.at[i].set(ysi)
  return lax.fori_loop(1, N, step, ys)

Solve and plot a single ODE Solution using jitted solver and plotter

N = 40000

# set initial condition: row 0 of the trajectory buffer holds the start state
state0 = jnp.array([1., 1., 1.])
ys = jnp.zeros((N,) + state0.shape)
# `jax.ops.index_update` no longer exists; use the `.at[...]` update syntax.
ys = ys.at[0].set(state0)

# solve for N steps; block so the computation has finished before continuing
ys = rk4(ys, 0.004, N).block_until_ready()
# plotting size and region:
xlim, zlim = (-20, 20), (0, 50)
xN, zN = 800, 600

# fast, jitted plotting function; the limits and image size are static
@partial(jax.jit, static_argnums=(2,3,4,5))
def jplotter(xs, zs, xlim, zlim, xN, zN):
  """Rasterize the polyline through (xs[i], zs[i]) into an (xN, zN) image.

  Args:
    xs, zs: 1D coordinate arrays of the points to connect.
    xlim, zlim: (low, high) data ranges mapped onto the image axes.
    xN, zN: image size along each axis, in pixels.

  Returns:
    The image with every consecutive pair of points joined by `drawline`.
  """
  x_span = 1.0 * (xlim[1] - xlim[0])
  z_span = 1.0 * (zlim[1] - zlim[0])
  # Map data coordinates to pixel coordinates.
  px = (xs - xlim[0]) / x_span * xN
  pz = (zs - zlim[0]) / z_span * zN

  def body_fun(i, canvas):
    return drawline(canvas, px[i-1], pz[i-1], px[i], pz[i])

  blank = jnp.zeros((xN, zN))
  return lax.fori_loop(1, px.shape[0], body_fun, blank)

# rasterize the (x, z) components of the trajectory, then display the image
# transposed and flipped so that z runs up the vertical axis
im = jplotter(ys[..., 0], ys[..., 2], xlim, zlim, xN, zN)
plot_image(im.T[::-1], cmap='magma')

Parallel ODE Solutions with Pmap

N_dev = jax.device_count()
N = 4000

# set some initial conditions for each replicate
ys = jnp.zeros((N_dev, N, 3))
state0 = jr.uniform(jr.key(1), 
                    minval=-1., maxval=1.,
                    shape=(N_dev, 3))
# scale and shift the uniform samples to spread out the starting points
state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))
# `jax.ops.index_update` no longer exists; use the `.at[...]` update syntax.
ys = ys.at[:, 0].set(state0)

# solve each replicate in parallel using `pmap` of rk4 solver; dt and N are
# replicated so that every device receives its own scalar copy:
ys = jax.pmap(rk4)(ys, 
                   0.004 * jnp.ones(N_dev), 
                   N * jnp.ones(N_dev, dtype=np.int32)
                  ).block_until_ready()
# parallel plotter: the pmap'd inner function closes over the plot limits
# and image size
def pplotter(_xs, _zs, xlim, zlim, xN, zN):
  """Rasterize one polyline per device, in parallel, with `jax.pmap`.

  Args:
    _xs, _zs: coordinate arrays whose leading axis indexes devices.
    xlim, zlim: (low, high) data ranges mapped onto the image axes.
    xN, zN: per-image size along each axis, in pixels.

  Returns:
    An array of images, one (xN, zN) image per entry of the leading axis.
  """
  x_span = 1.0 * (xlim[1] - xlim[0])
  z_span = 1.0 * (zlim[1] - zlim[0])

  @jax.pmap
  def plotfn(canvas, xs, zs):
    # Map data coordinates to pixel coordinates.
    px = (xs - xlim[0]) / x_span * xN
    pz = (zs - zlim[0]) / z_span * zN

    def body_fun(i, canvas):
      return drawline(canvas, px[i-1], pz[i-1], px[i], pz[i])

    return lax.fori_loop(1, px.shape[0], body_fun, canvas)

  n_replicas = _xs.shape[0]
  blank = jnp.zeros((n_replicas, xN, zN))
  return plotfn(blank, _xs, _zs)
xlim = (-20, 20)
zlim = (0, 50)
xN, zN = 200, 150

# first image: each replicate's trace in its own tile of a grid of at most
# 4 x 2 tiles
ims = pplotter(ys[..., 0], ys[..., 2], xlim, zlim, xN, zN)
im = pack_images(ims[..., None], 4, 2)[..., 0]
plot_image(im.T[::-1], cmap='magma')

# second image: all traces summed into one image at 4x the resolution
ims = pplotter(ys[..., 0], ys[..., 2], xlim, zlim, xN*4, zN*4)
plot_image(jnp.sum(ims, axis=0).T[::-1], cmap='magma')