sync common gotchas notebook

This commit is contained in:
Jake VanderPlas 2021-05-21 12:24:28 -07:00
parent 8c0aa3880b
commit b04d0c75a9

View File

@ -531,7 +531,7 @@
"id": "eoXrGARWypdR"
},
"source": [
"However, raising an error on other accelerators can be more difficult. Therefore, JAX does not raise an error, instead the index is clamped to the bounds of the array, meaning that for this example the last value of the array will be returned."
"However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out of bounds indexing (akin to how invalid floating point arithmetic results in `NaN`). When the indexing operation is an array index update (e.g. `index_add` or `scatter`-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or `gather`-like primitives) the index is clamped to the bounds of the array since __something__ must be returned. For example, the last value of the array will be returned from this indexing operation:"
]
},
{
@ -563,7 +563,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that due to this behavior jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs whereas Numpy would throw an error."
"Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n",
"\n",
"Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
]
},
{