mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
docs: runtime debugging tweaks
Mainly make titles/headings easier to read, by swapping code for words and not using headings as links.
This commit is contained in:
parent
763d600508
commit
42de34263f
@ -10,7 +10,9 @@ Table of contents:
|
||||
* [Functional error checks with jax.experimental.checkify](checkify_guide)
|
||||
* [Throwing Python errors with JAX’s debug flags](flags)
|
||||
|
||||
## [Interactive inspection with `jax.debug`](print_breakpoint)
|
||||
## Interactive inspection with `jax.debug`
|
||||
|
||||
Complete guide [here](print_breakpoint)
|
||||
|
||||
**Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
|
||||
and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
|
||||
@ -34,9 +36,11 @@ Table of contents:
|
||||
# 🤯 0.9092974662780762 🤯
|
||||
```
|
||||
|
||||
Click [here](print_breakpoint) to learn more!
|
||||
[Read more](print_breakpoint).
|
||||
|
||||
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
|
||||
## Functional error checks with `jax.experimental.checkify`
|
||||
|
||||
Complete guide [here](checkify_guide)
|
||||
|
||||
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
|
||||
|
||||
@ -77,9 +81,11 @@ Click [here](print_breakpoint) to learn more!
|
||||
# ValueError: nan generated by primitive sin at <...>:8 (f)
|
||||
```
|
||||
|
||||
Click [here](checkify_guide) to learn more!
|
||||
[Read more](checkify_guide).
|
||||
|
||||
## [Throwing Python errors with JAX's debug flags](flags)
|
||||
## Throwing Python errors with JAX's debug flags
|
||||
|
||||
Complete guide [here](flags)
|
||||
|
||||
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
|
||||
@ -92,7 +98,7 @@ def f(x, y):
|
||||
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
|
||||
```
|
||||
|
||||
Click [here](flags) to learn more!
|
||||
[Read more](flags).
|
||||
|
||||
```{toctree}
|
||||
:caption: Read more
|
||||
|
@ -1,9 +1,9 @@
|
||||
# `jax.debug.print` and `jax.debug.breakpoint`
|
||||
# Compiled prints and breakpoints
|
||||
|
||||
<!--* freshness: { reviewed: '2024-03-13' } *-->
|
||||
|
||||
The {mod}`jax.debug` package offers some useful tools for inspecting values
|
||||
inside of JIT-ted functions.
|
||||
inside of compiled functions.
|
||||
|
||||
## Debugging with `jax.debug.print` and other debugging callbacks
|
||||
|
||||
@ -27,7 +27,6 @@ f(2.)
|
||||
# 🤯 0.9092974662780762 🤯
|
||||
```
|
||||
|
||||
<!-- mattjj added this line -->
|
||||
With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead!
|
||||
|
||||
Semantically, `jax.debug.print` is roughly equivalent to the following Python function
|
||||
|
Loading…
x
Reference in New Issue
Block a user