DOC: improve documentation of OOB indices in jnp.take

This commit is contained in:
Jake VanderPlas 2023-01-31 15:59:06 -08:00
parent 957adbd5ea
commit 14a0fe08c8

View File

@ -3632,23 +3632,24 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
@_wraps(np.take, skip_params=['out'],
lax_description="""
The JAX version adds several extra parameters, described below, which are forwarded
to :func:`jax.lax.gather` for finer control over indexing.""",
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
index semantics can be specified via the ``mode`` parameter (see below).
""",
extra_params="""
mode : string, default="fill"
Out-of-bounds indexing mode. The default mode="fill" returns invalid values
(e.g. NaN) for out-of bounds indices. See :attr:`jax.numpy.ndarray.at` for
more discussion of out-of-bounds indexing in JAX.
(e.g. NaN) for out-of bounds indices (see also ``fill_value`` below).
For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`.
fill_value : optional
The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
otherwise. Defaults to NaN for inexact types, the largest negative value for
signed types, the largest positive value for unsigned types, and True for booleans.
unique_indices : bool, default=False
If True, the implementation will assume that the indices are unique,
which can result in more efficient execution on some backends.
indices_are_sorted : bool, default=False
If True, the implementation will assume that the indices are sorted in
ascending order, which can lead to more efficient execution on some backends.
fill_value : optional
The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
otherwise. Defaults to NaN for inexact types, the largest negative value for
signed types, the largest positive value for unsigned types, and True for booleans.
""")
def take(a, indices, axis: Optional[int] = None, out=None, mode=None,
unique_indices=False, indices_are_sorted=False, fill_value=None):
@ -5168,9 +5169,8 @@ class _IndexUpdateHelper:
in which conflicting updates are applied is implementation-defined and may be
nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. There is experimental
support for giving more precise semantics to out-of-bounds indexed accesses,
via the ``mode`` parameter (see below).
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
index semantics can be specified via the ``mode`` parameter (see below).
Arguments
---------