docs/debugging/print_breakpoint.md
The {mod}`jax.debug` package offers some useful tools for inspecting values
inside of compiled functions.

jax.debug.print and other debugging callbacks

Summary: Use {func}`jax.debug.print` to print traced array values to
stdout in compiled (e.g. jax.jit or jax.pmap-decorated) functions:
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    jax.debug.print("🤯 {x} 🤯", x=x)
    y = jnp.sin(x)
    jax.debug.print("🤯 {y} 🤯", y=y)
    return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
With some transformations, like jax.grad and jax.vmap, you can use Python's builtin print function to print out numerical values. But print won't work with jax.jit or jax.pmap because those transformations delay numerical evaluation. So use jax.debug.print instead!
Semantically, jax.debug.print is roughly equivalent to the following Python function:
def debug_print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
    print(fmt.format(*args, **kwargs))
except that it can be staged out and transformed by JAX. See the {func}`API reference <jax.debug.print>` for more details.
Note that fmt cannot be an f-string because f-strings are formatted immediately, whereas for jax.debug.print, we'd like to delay formatting until later.
You should use jax.debug.print for dynamic (i.e. traced) array values within JAX transformations
like jit, vmap, and others.
For printing of static values (like array shapes or dtypes), you can use a normal Python print statement.
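The split between static and traced values can be sketched as follows. This is a minimal illustration, not part of the original examples; the function name `scale` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale(x):
    # Shape and dtype are static, so Python's print works here.
    # It runs at trace time, not on every call.
    print("shape:", x.shape, "dtype:", x.dtype)
    # The values themselves are traced, so they need jax.debug.print.
    jax.debug.print("values: {x}", x=x)
    return x * 2

out = scale(jnp.arange(3.0))
jax.effects_barrier()  # wait for any pending debug prints
```

Calling `scale` again with an input of the same shape and dtype should reuse the compiled function, so only the `jax.debug.print` line is expected to fire on that second call.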
In the name of debugging, jax.debug.print can reveal information about how computations are evaluated:
xs = jnp.arange(3.)

def f(x):
    jax.debug.print("x: {}", x)
    y = jnp.sin(x)
    jax.debug.print("y: {}", y)
    return y

jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974

jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974
Notice that the printed results are in different orders!
By revealing these inner workings, the output of jax.debug.print doesn't respect JAX's usual semantic guarantees, such as the guarantee that jax.vmap(f)(xs) and jax.lax.map(f, xs) compute the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!
So use jax.debug.print for debugging, and not when semantic guarantees are important.
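To make the point concrete, the sketch below checks that the two transformations still return the same values even though their prints may be ordered differently. The variable names `batched`, `sequential` and `same` are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(x):
    jax.debug.print("x: {}", x)
    return jnp.sin(x)

xs = jnp.arange(3.0)
batched = jax.vmap(f)(xs)        # vectorized over the leading axis
sequential = jax.lax.map(f, xs)  # applied to one element at a time
jax.effects_barrier()            # wait for any pending debug prints

# The printed order may differ, but the returned values should agree.
same = bool(jnp.allclose(batched, sequential))
```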
jax.debug.print

In addition to the above examples using jit and vmap, here are a few more to keep in mind.

jax.pmap

Under jax.pmap, jax.debug.print calls might be reordered!
xs = jnp.arange(2.)

def f(x):
    jax.debug.print("x: {}", x)
    return x

jax.pmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
# OR
# Prints: x: 1.0
#         x: 0.0
jax.grad

Under jax.grad, jax.debug.print calls only print on the forward pass:
def f(x):
    jax.debug.print("x: {}", x)
    return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0
This behavior is similar to how Python's builtin print works under a jax.grad. But by using jax.debug.print here, the behavior is the same even if the caller applies a jax.jit.
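A short sketch of the jitted case, assuming the same `f` as above; the name `g` is only for illustration.

```python
import jax

def f(x):
    jax.debug.print("x: {}", x)
    return x * 2.0

# Wrapping the gradient in jax.jit should not change the printing behavior:
# the forward-pass value of x is still printed.
g = jax.jit(jax.grad(f))(1.0)
jax.effects_barrier()  # wait for any pending debug prints
```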
To print on the backward pass, just use a jax.custom_vjp:
@jax.custom_vjp
def print_grad(x):
    return x

def print_grad_fwd(x):
    return x, None

def print_grad_bwd(_, x_grad):
    jax.debug.print("x_grad: {}", x_grad)
    return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)

def f(x):
    x = print_grad(x)
    return x * 2.

jax.grad(f)(1.)
# Prints: x_grad: 2.0
jax.debug.print also works in other transformations like pjit.
jax.debug.callback

In fact, jax.debug.print is a thin convenience wrapper around jax.debug.callback, which can be used directly for greater control over string formatting, or even the kind of output.
Semantically, jax.debug.callback is roughly equivalent to the following Python function:
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
    fun(*args, **kwargs)
    return None
As with jax.debug.print, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use callbacks for anything else, their behavior might surprise you under transformations. For example, it's not safe to use jax.debug.callback for timing operations, since callbacks might be reordered and asynchronous (see below).
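As an illustration of output other than printing, the sketch below uses jax.debug.callback to collect an intermediate value in a host-side list. The names `history` and `record` are hypothetical, and the pattern is meant for debugging only.

```python
import jax
import jax.numpy as jnp

history = []  # hypothetical host-side buffer, filled by the callback

def record(value):
    # Runs on the host with a concrete array and returns None.
    history.append(float(value))

@jax.jit
def f(x):
    y = jnp.sin(x)
    jax.debug.callback(record, y)
    return y

out = f(0.0)
jax.effects_barrier()  # make sure the callback has run before reading history
```

Because callbacks may run asynchronously, the list is only read after `jax.effects_barrier()`.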
Like most JAX APIs, jax.debug.print can cut you if you're not careful.
When distinct calls to jax.debug.print involve arguments which don't depend on one another, they might be reordered when staged out, e.g. by jax.jit:
@jax.jit
def f(x, y):
    jax.debug.print("x: {}", x)
    jax.debug.print("y: {}", y)
    return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0
Why? Under the hood, the compiler gets a functional representation of the staged-out computation, where the imperative order of the Python function is lost and only data dependence remains. This change is invisible to users with functionally pure code, but in the presence of side-effects like printing, it's noticeable.
To preserve the original order of jax.debug.prints as written in your Python function, you can use jax.debug.print(..., ordered=True), which will ensure the relative order of prints is preserved. But using ordered=True will raise an error under jax.pmap and other JAX transformations involving parallelism, since ordering can't be guaranteed under parallel execution.
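A minimal sketch of the ordered variant, reusing the two-argument function from the example above:

```python
import jax

@jax.jit
def f(x, y):
    # ordered=True asks JAX to keep these prints in program order.
    jax.debug.print("x: {}", x, ordered=True)
    jax.debug.print("y: {}", y, ordered=True)
    return x + y

out = f(2.0, 3.0)
jax.effects_barrier()  # wait for any pending debug prints
```

With both calls marked as ordered, the `x` line is expected to appear before the `y` line.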
Adding jax.debug.print or jax.debug.breakpoint statements will change the computation that XLA is asked to compile. This can potentially result in numeric discrepancies compared to the same code without debug statements because XLA might perform different operation fusions during compilation. Keep this in mind when debugging numerical issues, as the act of adding debug statements might affect the behavior you're trying to investigate.
Depending on the backend, jax.debug.prints may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value.
@jax.jit
def f(x):
    jax.debug.print("x: {}", x)
    return x

f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.
To block on the jax.debug.prints in a function, you can call jax.effects_barrier(), which will wait until any remaining side-effects in the function have completed as well:
@jax.jit
def f(x):
    jax.debug.print("x: {}", x)
    return x

f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>
While jax.debug.print was designed to have a minimal performance footprint, it can interfere with compiler optimizations and potentially affect the memory profile of your JAX programs.
def f(w, b, x):
    logits = w.dot(x) + b
    jax.debug.print("logits: {}", logits)
    return jax.nn.relu(logits)
In this example, we are printing intermediate values in between a linear layer and the activation function. Compilers like XLA can perform fusion optimizations, which might avoid materializing logits in memory. But when we use jax.debug.print on logits, we are forcing those intermediates to be materialized, potentially slowing down the program and increasing memory usage.
Furthermore, when using jax.debug.print with jax.pjit, a global synchronization occurs that will materialize values on a single device.
jax.debug.print inherently incurs communication between an accelerator and its host. The underlying mechanism differs from backend to backend (e.g. GPU vs TPU) but in all cases, we'll need to copy the printed values from device to host. In the CPU case, this overhead is smaller.
With jax.pjit, the global synchronization mentioned above also adds some overhead.
jax.debug.print

jax.debug.callback can be used for other innocuous side-effects.

jax.debug.breakpoint()

Summary: Use jax.debug.breakpoint() to pause the execution of your JAX program to inspect values:
@jax.jit
def f(x):
    y, z = jnp.sin(x), jnp.cos(x)
    jax.debug.breakpoint()
    return y * z

f(2.) # ==> Pauses during execution!
jax.debug.breakpoint() is actually just an application of jax.debug.callback(...) that captures information about the call stack. It has the same transformation behaviors as jax.debug.print as a result (e.g. vmap-ing jax.debug.breakpoint() unrolls it across the mapped axis).
Calling jax.debug.breakpoint() in a compiled JAX function will pause your program when it hits the breakpoint. You'll be presented with a pdb-like prompt that allows you to inspect the values in the call stack. Unlike pdb, you will not be able to step through the execution, but you are allowed to resume it.
Debugger commands:
* help - prints out available commands
* p - evaluates an expression and prints its result
* pp - evaluates an expression and pretty-prints its result
* u(p) - go up a stack frame
* d(own) - go down a stack frame
* w(here)/bt - print out a backtrace
* l(ist) - print out code context
* c(ont(inue)) - resumes the execution of the program
* q(uit)/exit - exits the program (does not work on TPU)

jax.lax.cond

When combined with jax.lax.cond, the debugger can become a useful tool for detecting nans or infs.
def breakpoint_if_nonfinite(x):
    is_finite = jnp.isfinite(x).all()

    def true_fn(x):
        pass

    def false_fn(x):
        jax.debug.breakpoint()

    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
    z = x / y
    breakpoint_if_nonfinite(z)
    return z

f(2., 0.) # ==> Pauses during execution!
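The same jax.lax.cond pattern can be sketched in a non-interactive form, which may be handy where a debugger prompt is impractical. This variant is an assumption layered on the example above, not something the text prescribes; `flagged`, `report` and `report_if_nonfinite` are hypothetical names, and it uses jax.debug.callback in place of the breakpoint.

```python
import jax
import jax.numpy as jnp

flagged = []  # hypothetical host-side log of non-finite values

def report(value):
    # Runs on the host only when the non-finite branch is taken.
    flagged.append(float(value))
    print("non-finite value:", value)

def report_if_nonfinite(x):
    is_finite = jnp.isfinite(x).all()

    def true_fn(x):
        pass

    def false_fn(x):
        jax.debug.callback(report, x)

    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
    z = x / y
    report_if_nonfinite(z)
    return z

f(2.0, 1.0)  # finite result, nothing is reported
f(2.0, 0.0)  # division by zero gives inf, which is reported
jax.effects_barrier()  # wait for the callback before reading flagged
```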
Because jax.debug.breakpoint is just an application of jax.debug.callback, it has the same sharp bits as jax.debug.print, with a few more caveats:
* jax.debug.breakpoint materializes even more intermediates than jax.debug.print because it forces materialization of all values in the call stack.
* jax.debug.breakpoint has more runtime overhead than a jax.debug.print because it has to potentially copy all the intermediate values in a JAX program from device to host.