Change the default jnp.take mode to "fill".

Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated using the "clip" or "wrap" fill modes.

PiperOrigin-RevId: 445130143
This commit is contained in:
Peter Hawkins 2022-04-28 06:01:22 -07:00 committed by jax authors
parent 611759d0ce
commit 0b470361da
4 changed files with 16 additions and 19 deletions

View File

@ -24,7 +24,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
previous versions of JAX, invalid indices were clamped into range. The previous versions of JAX, invalid indices were clamped into range. The
previous behavior can be restored by passing `mode="clip"`. previous behavior can be restored by passing `mode="clip"`.
* Scatter operations, such as `x.at[...].set(...)`, now have "drop" semantics. * {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns
invalid values (e.g., NaN) for out-of-bounds indices.
* Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics.
This has no effect on the scatter operation itself, but it means that when This has no effect on the scatter operation itself, but it means that when
differentiated the gradient of a scatter will yield zero cotangents for differentiated the gradient of a scatter will yield zero cotangents for
out-of-bounds indices. Previously out-of-bounds indices were clamped into out-of-bounds indices. Previously out-of-bounds indices were clamped into

View File

@ -2995,7 +2995,7 @@ mlir.register_lowering(squeeze_p, _squeeze_lower)
def _shape_as_value(shape: core.Shape): def shape_as_value(shape: core.Shape):
"""Converts a shape that may contain Poly values into a JAX value.""" """Converts a shape that may contain Poly values into a JAX value."""
if len(shape) == 0: if len(shape) == 0:
return full((0,), np.array(0, np.int64)) return full((0,), np.array(0, np.int64))

View File

@ -1162,12 +1162,13 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
"""Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking.""" """Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
dnums = dimension_numbers dnums = dimension_numbers
intarray = partial(np.array, dtype=np.int64) intarray = partial(np.array, dtype=np.int64)
operand_dims = lax._shape_as_value(operand.shape) operand_dims = lax.shape_as_value(operand.shape)
indices = lax.convert_element_type(indices, np.int64) indices = lax.convert_element_type(indices, np.int64)
num_batch_dims = len(indices.shape) - 1 num_batch_dims = len(indices.shape) - 1
upper_bound = (operand_dims[intarray(dnums.start_index_map)] - upper_bound = (
intarray(slice_sizes)[intarray(dnums.start_index_map)]) operand_dims[intarray(dnums.start_index_map)] -
lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)])
mask = lax.bitwise_and( mask = lax.bitwise_and(
lax.ge(indices, np.int64(0)), lax.ge(indices, np.int64(0)),
lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims))))) lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
@ -1466,7 +1467,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i] upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
for i in dnums.scatter_dims_to_operand_dims) for i in dnums.scatter_dims_to_operand_dims)
# Stack upper_bounds into a DeviceArray[n] # Stack upper_bounds into a DeviceArray[n]
upper_bound = lax._shape_as_value(upper_bounds) upper_bound = lax.shape_as_value(upper_bounds)
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max) upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
(len(indices.shape) - 1,)) (len(indices.shape) - 1,))

View File

@ -3352,9 +3352,10 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
@_wraps(np.take, skip_params=['out'], lax_description="""\ @_wraps(np.take, skip_params=['out'], lax_description="""\
In the JAX version, the ``mode`` argument defaults to a special version of ``"clip"`` In the JAX version, the ``mode`` argument defaults to a special mode
that handles negative indices. See :attr:`jax.numpy.ndarray.at` for more discussion (``"fill"``) that returns invalid values (e.g., NaN) for out-of-bounds indices.
of out-of-bounds indexing in JAX. See :attr:`jax.numpy.ndarray.at` for more discussion of out-of-bounds indexing
in JAX.
""") """)
def take(a, indices, axis: Optional[int] = None, out=None, mode=None): def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
return _take(a, indices, None if axis is None else operator.index(axis), out, return _take(a, indices, None if axis is None else operator.index(axis), out,
@ -3374,23 +3375,16 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
else: else:
axis_idx = _canonicalize_axis(axis, ndim(a)) axis_idx = _canonicalize_axis(axis, ndim(a))
if mode is None: if mode is None or mode == "fill":
# TODO(phawkins): change default mode to "fill" and delete this case. gather_mode = lax.GatherScatterMode.FILL_OR_DROP
# lax.gather() does not support negative indices, so we wrap them here # lax.gather() does not support negative indices, so we wrap them here
indices = where(indices < 0, indices + a.shape[axis_idx], indices) indices = where(indices < 0, indices + a.shape[axis_idx], indices)
gather_mode = lax.GatherScatterMode.CLIP
elif mode == "raise": elif mode == "raise":
# TODO(phawkins): we have no way to report out of bounds errors yet. # TODO(phawkins): we have no way to report out of bounds errors yet.
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.") raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
elif mode == "wrap": elif mode == "wrap":
indices = mod(indices, _lax_const(indices, a.shape[axis_idx])) indices = mod(indices, _lax_const(indices, a.shape[axis_idx]))
gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
elif mode == "fill":
# Undocumented non-standard mode corresponding to the fill_or_drop mode on
# lax.gather()
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
# lax.gather() does not support negative indices, so we wrap them here
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
elif mode == "clip": elif mode == "clip":
gather_mode = lax.GatherScatterMode.CLIP gather_mode = lax.GatherScatterMode.CLIP
else: else:
@ -3439,7 +3433,7 @@ TAKE_ALONG_AXIS_DOC = """
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
an optional ``mode`` parameter controlling how out-of-bounds indices should be an optional ``mode`` parameter controlling how out-of-bounds indices should be
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``). handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
indexing in JAX. indexing in JAX.
""" """