mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve jax_debug_nans documentation (#3665)
This commit is contained in:
parent
8f93607330
commit
c485a5b04a
@ -2497,13 +2497,141 @@
|
||||
"### Debugging NaNs\n",
|
||||
"\n",
|
||||
"If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:\n",
|
||||
"- setting the `JAX_DEBUG_NANS=True` environment variable.\n",
|
||||
"- adding `from jax.config import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file\n",
|
||||
"- adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`.\n",
|
||||
"\n",
|
||||
"This will cause computations to error-out immediately on production of a NaN.\n",
|
||||
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
|
||||
"\n",
|
||||
"⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!\n"
|
||||
"* adding `from jax.config import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
||||
"\n",
|
||||
"* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
|
||||
"\n",
|
||||
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
|
||||
"\n",
|
||||
"There could be tricky situations that arise, like nans that only occur under a `@jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute.\n",
|
||||
"\n",
|
||||
"If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```\n",
|
||||
"In [1]: import jax.numpy as np\n",
|
||||
"\n",
|
||||
"In [2]: np.divide(0., 0.)\n",
|
||||
"---------------------------------------------------------------------------\n",
|
||||
"FloatingPointError Traceback (most recent call last)\n",
|
||||
"<ipython-input-2-f2e2c413b437> in <module>()\n",
|
||||
"----> 1 np.divide(0., 0.)\n",
|
||||
"\n",
|
||||
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
|
||||
" 343 return floor_divide(x1, x2)\n",
|
||||
" 344 else:\n",
|
||||
"--> 345 return true_divide(x1, x2)\n",
|
||||
" 346\n",
|
||||
" 347\n",
|
||||
"\n",
|
||||
".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n",
|
||||
" 332 x1, x2 = _promote_shapes(x1, x2)\n",
|
||||
" 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n",
|
||||
"--> 334 lax.convert_element_type(x2, result_dtype))\n",
|
||||
" 335\n",
|
||||
" 336\n",
|
||||
"\n",
|
||||
".../jax/jax/lax.pyc in div(x, y)\n",
|
||||
" 244 def div(x, y):\n",
|
||||
" 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n",
|
||||
"--> 246 return div_p.bind(x, y)\n",
|
||||
" 247\n",
|
||||
" 248 def rem(x, y):\n",
|
||||
"\n",
|
||||
"... stack trace ...\n",
|
||||
"\n",
|
||||
".../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)\n",
|
||||
" 103 py_val = device_buffer.to_py()\n",
|
||||
" 104 if onp.any(onp.isnan(py_val)):\n",
|
||||
"--> 105 raise FloatingPointError(\"invalid value\")\n",
|
||||
" 106 else:\n",
|
||||
" 107 return DeviceArray(device_buffer, *result_shape)\n",
|
||||
"\n",
|
||||
"FloatingPointError: invalid value\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The nan generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```\n",
|
||||
"In [4]: from jax import jit\n",
|
||||
"\n",
|
||||
"In [5]: @jit\n",
|
||||
" ...: def f(x, y):\n",
|
||||
" ...: a = x * y\n",
|
||||
" ...: b = (x + y) / (x - y)\n",
|
||||
" ...: c = a + 2\n",
|
||||
" ...: return a + b * c\n",
|
||||
" ...:\n",
|
||||
"\n",
|
||||
"In [6]: x = np.array([2., 0.])\n",
|
||||
"\n",
|
||||
"In [7]: y = np.array([3., 0.])\n",
|
||||
"\n",
|
||||
"In [8]: f(x, y)\n",
|
||||
"Invalid value encountered in the output of a jit function. Calling the de-optimized version.\n",
|
||||
"---------------------------------------------------------------------------\n",
|
||||
"FloatingPointError Traceback (most recent call last)\n",
|
||||
"<ipython-input-8-811b7ddb3300> in <module>()\n",
|
||||
"----> 1 f(x, y)\n",
|
||||
"\n",
|
||||
" ... stack trace ...\n",
|
||||
"\n",
|
||||
"<ipython-input-5-619b39acbaac> in f(x, y)\n",
|
||||
" 2 def f(x, y):\n",
|
||||
" 3 a = x * y\n",
|
||||
"----> 4 b = (x + y) / (x - y)\n",
|
||||
" 5 c = a + 2\n",
|
||||
" 6 return a + b * c\n",
|
||||
"\n",
|
||||
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
|
||||
" 343 return floor_divide(x1, x2)\n",
|
||||
" 344 else:\n",
|
||||
"--> 345 return true_divide(x1, x2)\n",
|
||||
" 346\n",
|
||||
" 347\n",
|
||||
"\n",
|
||||
".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n",
|
||||
" 332 x1, x2 = _promote_shapes(x1, x2)\n",
|
||||
" 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n",
|
||||
"--> 334 lax.convert_element_type(x2, result_dtype))\n",
|
||||
" 335\n",
|
||||
" 336\n",
|
||||
"\n",
|
||||
".../jax/jax/lax.pyc in div(x, y)\n",
|
||||
" 244 def div(x, y):\n",
|
||||
" 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n",
|
||||
"--> 246 return div_p.bind(x, y)\n",
|
||||
" 247\n",
|
||||
" 248 def rem(x, y):\n",
|
||||
"\n",
|
||||
" ... stack trace ...\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When this code sees a nan in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error.\n",
|
||||
"\n",
|
||||
"⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user