Update debugging docs to mention pjit

This commit is contained in:
Sharad Vikram 2022-07-28 21:47:25 -07:00
parent 4386a0f909
commit fb0cf668b8
2 changed files with 7 additions and 4 deletions

View File

@ -4,7 +4,7 @@ Do you have exploding gradients? Are nans making you gnash your teeth? Just want
## [Interactive inspection with `jax.debug`](print_breakpoint)
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions
and use {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
```python

View File

@ -5,7 +5,7 @@ inside of JIT-ted functions.
## Debugging with `jax.debug.print` and other debugging callbacks
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions:
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions:
```python
import jax
@ -65,7 +65,6 @@ Notice that the printed results are in different orders!
By revealing these inner-workings, the output of `jax.debug.print` doesn't respect JAX's usual semantics guarantees, like that `jax.vmap(f)(xs)` and `jax.lax.map(f, xs)` compute the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!
<!-- mattjj tweaked this line -->
So use `jax.debug.print` for debugging, and not when semantics guarantees are important.
### More examples of `jax.debug.print`
@ -101,7 +100,6 @@ jax.grad(f)(1.)
# Prints: x: 1.0
```
<!-- mattjj added this line -->
This behavior is similar to how Python's builtin `print` works under a `jax.grad`. But by using `jax.debug.print` here, the behavior is the same even if the caller applies a `jax.jit`.
To print on the backward pass, just use a `jax.custom_vjp`:
@ -128,6 +126,11 @@ jax.grad(f)(1.)
# Prints: x_grad: 2.0
```
#### Printing in other transformations
`jax.debug.print` also works in other transformations like `xmap` and `pjit`
(but `pjit` only works on TPUs for now).
### More control with `jax.debug.callback`
In fact, `jax.debug.print` is a thin convenience wrapper around `jax.debug.callback`, which can be used directly for greater control over string formatting, or even the kind of output.