mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices. The previous behavior can be approximated using the "clip" or "wrap" fill modes. PiperOrigin-RevId: 445130143
This commit is contained in:
parent
611759d0ce
commit
0b470361da
@ -24,7 +24,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
|
||||
previous versions of JAX, invalid indices were clamped into range. The
|
||||
previous behavior can be restored by passing `mode="clip"`.
|
||||
* Scatter operations, such as `x.at[...].set(...)`, now have "drop" semantics.
|
||||
* {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns
|
||||
invalid values (e.g., NaN) for out-of-bounds indices.
|
||||
* Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics.
|
||||
This has no effect on the scatter operation itself, but it means that when
|
||||
differentiated the gradient of a scatter will yield zero cotangents for
|
||||
out-of-bounds indices. Previously out-of-bounds indices were clamped into
|
||||
|
@ -2995,7 +2995,7 @@ mlir.register_lowering(squeeze_p, _squeeze_lower)
|
||||
|
||||
|
||||
|
||||
def _shape_as_value(shape: core.Shape):
|
||||
def shape_as_value(shape: core.Shape):
|
||||
"""Converts a shape that may contain Poly values into a JAX value."""
|
||||
if len(shape) == 0:
|
||||
return full((0,), np.array(0, np.int64))
|
||||
|
@ -1162,12 +1162,13 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
|
||||
"""Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
|
||||
dnums = dimension_numbers
|
||||
intarray = partial(np.array, dtype=np.int64)
|
||||
operand_dims = lax._shape_as_value(operand.shape)
|
||||
operand_dims = lax.shape_as_value(operand.shape)
|
||||
indices = lax.convert_element_type(indices, np.int64)
|
||||
num_batch_dims = len(indices.shape) - 1
|
||||
|
||||
upper_bound = (operand_dims[intarray(dnums.start_index_map)] -
|
||||
intarray(slice_sizes)[intarray(dnums.start_index_map)])
|
||||
upper_bound = (
|
||||
operand_dims[intarray(dnums.start_index_map)] -
|
||||
lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)])
|
||||
mask = lax.bitwise_and(
|
||||
lax.ge(indices, np.int64(0)),
|
||||
lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
|
||||
@ -1466,7 +1467,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
|
||||
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
|
||||
for i in dnums.scatter_dims_to_operand_dims)
|
||||
# Stack upper_bounds into a DeviceArray[n]
|
||||
upper_bound = lax._shape_as_value(upper_bounds)
|
||||
upper_bound = lax.shape_as_value(upper_bounds)
|
||||
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
|
||||
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
||||
(len(indices.shape) - 1,))
|
||||
|
@ -3352,9 +3352,10 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
|
||||
|
||||
|
||||
@_wraps(np.take, skip_params=['out'], lax_description="""\
|
||||
In the JAX version, the ``mode`` argument defaults to a special version of ``"clip"``
|
||||
that handles negative indices. See :attr:`jax.numpy.ndarray.at` for more discussion
|
||||
of out-of-bounds indexing in JAX.
|
||||
In the JAX version, the ``mode`` argument defaults to a special mode
|
||||
(``"fill"``) that returns invalid values (e.g., NaN) for out-of-bounds indices.
|
||||
See :attr:`jax.numpy.ndarray.at` for more discussion of out-of-bounds indexing
|
||||
in JAX.
|
||||
""")
|
||||
def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
return _take(a, indices, None if axis is None else operator.index(axis), out,
|
||||
@ -3374,23 +3375,16 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
else:
|
||||
axis_idx = _canonicalize_axis(axis, ndim(a))
|
||||
|
||||
if mode is None:
|
||||
# TODO(phawkins): change default mode to "fill" and delete this case.
|
||||
if mode is None or mode == "fill":
|
||||
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
|
||||
# lax.gather() does not support negative indices, so we wrap them here
|
||||
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
|
||||
gather_mode = lax.GatherScatterMode.CLIP
|
||||
elif mode == "raise":
|
||||
# TODO(phawkins): we have no way to report out of bounds errors yet.
|
||||
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
|
||||
elif mode == "wrap":
|
||||
indices = mod(indices, _lax_const(indices, a.shape[axis_idx]))
|
||||
gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
|
||||
elif mode == "fill":
|
||||
# Undocumented non-standard mode corresponding to the fill_or_drop mode on
|
||||
# lax.gather()
|
||||
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
|
||||
# lax.gather() does not support negative indices, so we wrap them here
|
||||
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
|
||||
elif mode == "clip":
|
||||
gather_mode = lax.GatherScatterMode.CLIP
|
||||
else:
|
||||
@ -3439,7 +3433,7 @@ TAKE_ALONG_AXIS_DOC = """
|
||||
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
|
||||
an optional ``mode`` parameter controlling how out-of-bounds indices should be
|
||||
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
|
||||
See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds
|
||||
See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
|
||||
indexing in JAX.
|
||||
"""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user