Make JAX Debugging and Profiling guides more visible, move Profiling to User Guides from Notes

This commit is contained in:
8bitmp3 2023-01-19 23:05:24 +00:00
parent fae7306d88
commit d6cc2bdb22
3 changed files with 12 additions and 6 deletions

View File

@ -1,11 +1,17 @@
# Runtime value debugging in JAX
Do you have exploding gradients? Are nans making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has tl;dr summaries and you can click the "Read more" links at the bottom to learn more.
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more.
Table of contents:
* [Interactive inspection with `jax.debug`](print_breakpoint)
* [Functional error checks with jax.experimental.checkify](checkify_guide)
* [Throwing Python errors with JAXs debug flags](flags)
## [Interactive inspection with `jax.debug`](print_breakpoint)
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions
and use {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
```python
import jax

View File

@ -10,7 +10,5 @@ Notes
deprecation
concurrency
gpu_memory_allocation
profiling
device_memory_profiling
rank_promotion_warning
jax_array_migration

View File

@ -13,5 +13,7 @@ User Guides
pytrees
type_promotion
errors
debugging/index
profiling
device_memory_profiling
transfer_guard
debugging/index