Change the default jnp.take mode to "fill".

Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, by default, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change is intended to surface latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated by passing `mode="clip"` or `mode="wrap"`.

PiperOrigin-RevId: 445130143
This commit is contained in:
Peter Hawkins 2022-04-28 06:01:22 -07:00 committed by jax authors
parent 611759d0ce
commit 0b470361da
4 changed files with 16 additions and 19 deletions

View File

@ -24,7 +24,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
previous versions of JAX, invalid indices were clamped into range. The
previous behavior can be restored by passing `mode="clip"`.
* Scatter operations, such as `x.at[...].set(...)`, now have "drop" semantics.
* {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns
invalid values (e.g., NaN) for out-of-bounds indices.
* Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics.
This has no effect on the scatter operation itself, but it means that when
differentiated the gradient of a scatter will yield zero cotangents for
out-of-bounds indices. Previously out-of-bounds indices were clamped into

View File

@ -2995,7 +2995,7 @@ mlir.register_lowering(squeeze_p, _squeeze_lower)
def _shape_as_value(shape: core.Shape):
def shape_as_value(shape: core.Shape):
"""Converts a shape that may contain Poly values into a JAX value."""
if len(shape) == 0:
return full((0,), np.array(0, np.int64))

View File

@ -1162,12 +1162,13 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
"""Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
dnums = dimension_numbers
intarray = partial(np.array, dtype=np.int64)
operand_dims = lax._shape_as_value(operand.shape)
operand_dims = lax.shape_as_value(operand.shape)
indices = lax.convert_element_type(indices, np.int64)
num_batch_dims = len(indices.shape) - 1
upper_bound = (operand_dims[intarray(dnums.start_index_map)] -
intarray(slice_sizes)[intarray(dnums.start_index_map)])
upper_bound = (
operand_dims[intarray(dnums.start_index_map)] -
lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)])
mask = lax.bitwise_and(
lax.ge(indices, np.int64(0)),
lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
@ -1466,7 +1467,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
for i in dnums.scatter_dims_to_operand_dims)
# Stack upper_bounds into a DeviceArray[n]
upper_bound = lax._shape_as_value(upper_bounds)
upper_bound = lax.shape_as_value(upper_bounds)
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
(len(indices.shape) - 1,))

View File

@ -3352,9 +3352,10 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
@_wraps(np.take, skip_params=['out'], lax_description="""\
In the JAX version, the ``mode`` argument defaults to a special version of ``"clip"``
that handles negative indices. See :attr:`jax.numpy.ndarray.at` for more discussion
of out-of-bounds indexing in JAX.
In the JAX version, the ``mode`` argument defaults to a special mode
(``"fill"``) that returns invalid values (e.g., NaN) for out-of-bounds indices.
See :attr:`jax.numpy.ndarray.at` for more discussion of out-of-bounds indexing
in JAX.
""")
def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
return _take(a, indices, None if axis is None else operator.index(axis), out,
@ -3374,23 +3375,16 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
else:
axis_idx = _canonicalize_axis(axis, ndim(a))
if mode is None:
# TODO(phawkins): change default mode to "fill" and delete this case.
if mode is None or mode == "fill":
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
# lax.gather() does not support negative indices, so we wrap them here
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
gather_mode = lax.GatherScatterMode.CLIP
elif mode == "raise":
# TODO(phawkins): we have no way to report out of bounds errors yet.
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
elif mode == "wrap":
indices = mod(indices, _lax_const(indices, a.shape[axis_idx]))
gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
elif mode == "fill":
# Undocumented non-standard mode corresponding to the fill_or_drop mode on
# lax.gather()
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
# lax.gather() does not support negative indices, so we wrap them here
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
elif mode == "clip":
gather_mode = lax.GatherScatterMode.CLIP
else:
@ -3439,7 +3433,7 @@ TAKE_ALONG_AXIS_DOC = """
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
an optional ``mode`` parameter controlling how out-of-bounds indices should be
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds
See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
indexing in JAX.
"""