Merge pull request #14235 from jakevdp:take-doc

PiperOrigin-RevId: 506151220
This commit is contained in:
jax authors 2023-01-31 16:34:50 -08:00
commit b0202b6ae2

View File

@ -3632,23 +3632,24 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
@_wraps(np.take, skip_params=['out'],
lax_description="""
The JAX version adds several extra parameters, described below, which are forwarded
to :func:`jax.lax.gather` for finer control over indexing.""",
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
index semantics can be specified via the ``mode`` parameter (see below).
""",
extra_params="""
mode : string, default="fill"
Out-of-bounds indexing mode. The default mode="fill" returns invalid values
(e.g. NaN) for out-of bounds indices. See :attr:`jax.numpy.ndarray.at` for
more discussion of out-of-bounds indexing in JAX.
(e.g. NaN) for out-of bounds indices (see also ``fill_value`` below).
For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`.
fill_value : optional
The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
otherwise. Defaults to NaN for inexact types, the largest negative value for
signed types, the largest positive value for unsigned types, and True for booleans.
unique_indices : bool, default=False
If True, the implementation will assume that the indices are unique,
which can result in more efficient execution on some backends.
indices_are_sorted : bool, default=False
If True, the implementation will assume that the indices are sorted in
ascending order, which can lead to more efficient execution on some backends.
fill_value : optional
The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
otherwise. Defaults to NaN for inexact types, the largest negative value for
signed types, the largest positive value for unsigned types, and True for booleans.
""")
def take(a, indices, axis: Optional[int] = None, out=None, mode=None,
unique_indices=False, indices_are_sorted=False, fill_value=None):
@ -5168,9 +5169,8 @@ class _IndexUpdateHelper:
in which conflicting updates are applied is implementation-defined and may be
nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. There is experimental
support for giving more precise semantics to out-of-bounds indexed accesses,
via the ``mode`` parameter (see below).
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
index semantics can be specified via the ``mode`` parameter (see below).
Arguments
---------