mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Update debugging docs to mention pjit
This commit is contained in:
parent
4386a0f909
commit
fb0cf668b8
@ -4,7 +4,7 @@ Do you have exploding gradients? Are nans making you gnash your teeth? Just want
|
||||
|
||||
## [Interactive inspection with `jax.debug`](print_breakpoint)
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions
|
||||
and use {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
|
||||
|
||||
```python
|
||||
|
@ -5,7 +5,7 @@ inside of JIT-ted functions.
|
||||
|
||||
## Debugging with `jax.debug.print` and other debugging callbacks
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions:
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions:
|
||||
|
||||
```python
|
||||
import jax
|
||||
@ -65,7 +65,6 @@ Notice that the printed results are in different orders!
|
||||
|
||||
By revealing these inner-workings, the output of `jax.debug.print` doesn't respect JAX's usual semantics guarantees, like that `jax.vmap(f)(xs)` and `jax.lax.map(f, xs)` compute the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!
|
||||
|
||||
<!-- mattjj tweaked this line -->
|
||||
So use `jax.debug.print` for debugging, and not when semantics guarantees are important.
|
||||
|
||||
### More examples of `jax.debug.print`
|
||||
@ -101,7 +100,6 @@ jax.grad(f)(1.)
|
||||
# Prints: x: 1.0
|
||||
```
|
||||
|
||||
<!-- mattjj added this line -->
|
||||
This behavior is similar to how Python's builtin `print` works under a `jax.grad`. But by using `jax.debug.print` here, the behavior is the same even if the caller applies a `jax.jit`.
|
||||
|
||||
To print on the backward pass, just use a `jax.custom_vjp`:
|
||||
@ -128,6 +126,11 @@ jax.grad(f)(1.)
|
||||
# Prints: x_grad: 2.0
|
||||
```
|
||||
|
||||
#### Printing in other transformations
|
||||
|
||||
`jax.debug.print` also works in other transformations like `xmap` and `pjit`
|
||||
(but `pjit` only works on TPUs for now).
|
||||
|
||||
### More control with `jax.debug.callback`
|
||||
|
||||
In fact, `jax.debug.print` is a thin convenience wrapper around `jax.debug.callback`, which can be used directly for greater control over string formatting, or even the kind of output.
|
||||
|
Loading…
x
Reference in New Issue
Block a user