DOC: add info about repeated indices to jax.ops docs

This commit is contained in:
Jake VanderPlas 2022-01-26 15:44:53 -08:00
parent 18a661497c
commit 928087ada0

View File

@ -35,6 +35,12 @@ None of these expressions modify the original `x`; instead they return
a modified copy of `x`. However, inside a :py:func:`jit` compiled function,
expressions like ``x = x.at[idx].set(y)`` are guaranteed to be applied in-place.
Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple
indices refer to the same location, all updates will be applied (NumPy would
only apply the last update, rather than applying all updates.) The order
in which conflicting updates are applied is implementation-defined and may be
nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. There is experimental
support for giving more precise semantics to out-of-bounds indexed accesses,
via the ``mode`` parameter to functions such as ``get`` and ``set``. Valid