mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices. The previous behavior can be approximated using the "clip" or "wrap" modes. PiperOrigin-RevId: 445130143
This commit is contained in:
parent
611759d0ce
commit
0b470361da
@@ -24,7 +24,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
|
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
|
||||||
previous versions of JAX, invalid indices were clamped into range. The
|
previous versions of JAX, invalid indices were clamped into range. The
|
||||||
previous behavior can be restored by passing `mode="clip"`.
|
previous behavior can be restored by passing `mode="clip"`.
|
||||||
* Scatter operations, such as `x.at[...].set(...)`, now have "drop" semantics.
|
* {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns
|
||||||
|
invalid values (e.g., NaN) for out-of-bounds indices.
|
||||||
|
* Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics.
|
||||||
This has no effect on the scatter operation itself, but it means that when
|
This has no effect on the scatter operation itself, but it means that when
|
||||||
differentiated the gradient of a scatter will yield zero cotangents for
|
differentiated the gradient of a scatter will yield zero cotangents for
|
||||||
out-of-bounds indices. Previously out-of-bounds indices were clamped into
|
out-of-bounds indices. Previously out-of-bounds indices were clamped into
|
||||||
|
@@ -2995,7 +2995,7 @@ mlir.register_lowering(squeeze_p, _squeeze_lower)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _shape_as_value(shape: core.Shape):
|
def shape_as_value(shape: core.Shape):
|
||||||
"""Converts a shape that may contain Poly values into a JAX value."""
|
"""Converts a shape that may contain Poly values into a JAX value."""
|
||||||
if len(shape) == 0:
|
if len(shape) == 0:
|
||||||
return full((0,), np.array(0, np.int64))
|
return full((0,), np.array(0, np.int64))
|
||||||
|
@@ -1162,12 +1162,13 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
|
|||||||
"""Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
|
"""Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
|
||||||
dnums = dimension_numbers
|
dnums = dimension_numbers
|
||||||
intarray = partial(np.array, dtype=np.int64)
|
intarray = partial(np.array, dtype=np.int64)
|
||||||
operand_dims = lax._shape_as_value(operand.shape)
|
operand_dims = lax.shape_as_value(operand.shape)
|
||||||
indices = lax.convert_element_type(indices, np.int64)
|
indices = lax.convert_element_type(indices, np.int64)
|
||||||
num_batch_dims = len(indices.shape) - 1
|
num_batch_dims = len(indices.shape) - 1
|
||||||
|
|
||||||
upper_bound = (operand_dims[intarray(dnums.start_index_map)] -
|
upper_bound = (
|
||||||
intarray(slice_sizes)[intarray(dnums.start_index_map)])
|
operand_dims[intarray(dnums.start_index_map)] -
|
||||||
|
lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)])
|
||||||
mask = lax.bitwise_and(
|
mask = lax.bitwise_and(
|
||||||
lax.ge(indices, np.int64(0)),
|
lax.ge(indices, np.int64(0)),
|
||||||
lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
|
lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
|
||||||
@@ -1466,7 +1467,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
|
|||||||
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
|
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
|
||||||
for i in dnums.scatter_dims_to_operand_dims)
|
for i in dnums.scatter_dims_to_operand_dims)
|
||||||
# Stack upper_bounds into a DeviceArray[n]
|
# Stack upper_bounds into a DeviceArray[n]
|
||||||
upper_bound = lax._shape_as_value(upper_bounds)
|
upper_bound = lax.shape_as_value(upper_bounds)
|
||||||
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
|
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
|
||||||
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
||||||
(len(indices.shape) - 1,))
|
(len(indices.shape) - 1,))
|
||||||
|
@@ -3352,9 +3352,10 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
|
|||||||
|
|
||||||
|
|
||||||
@_wraps(np.take, skip_params=['out'], lax_description="""\
|
@_wraps(np.take, skip_params=['out'], lax_description="""\
|
||||||
In the JAX version, the ``mode`` argument defaults to a special version of ``"clip"``
|
In the JAX version, the ``mode`` argument defaults to a special mode
|
||||||
that handles negative indices. See :attr:`jax.numpy.ndarray.at` for more discussion
|
(``"fill"``) that returns invalid values (e.g., NaN) for out-of-bounds indices.
|
||||||
of out-of-bounds indexing in JAX.
|
See :attr:`jax.numpy.ndarray.at` for more discussion of out-of-bounds indexing
|
||||||
|
in JAX.
|
||||||
""")
|
""")
|
||||||
def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||||
return _take(a, indices, None if axis is None else operator.index(axis), out,
|
return _take(a, indices, None if axis is None else operator.index(axis), out,
|
||||||
@@ -3374,23 +3375,16 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
|||||||
else:
|
else:
|
||||||
axis_idx = _canonicalize_axis(axis, ndim(a))
|
axis_idx = _canonicalize_axis(axis, ndim(a))
|
||||||
|
|
||||||
if mode is None:
|
if mode is None or mode == "fill":
|
||||||
# TODO(phawkins): change default mode to "fill" and delete this case.
|
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
|
||||||
# lax.gather() does not support negative indices, so we wrap them here
|
# lax.gather() does not support negative indices, so we wrap them here
|
||||||
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
|
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
|
||||||
gather_mode = lax.GatherScatterMode.CLIP
|
|
||||||
elif mode == "raise":
|
elif mode == "raise":
|
||||||
# TODO(phawkins): we have no way to report out of bounds errors yet.
|
# TODO(phawkins): we have no way to report out of bounds errors yet.
|
||||||
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
|
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
|
||||||
elif mode == "wrap":
|
elif mode == "wrap":
|
||||||
indices = mod(indices, _lax_const(indices, a.shape[axis_idx]))
|
indices = mod(indices, _lax_const(indices, a.shape[axis_idx]))
|
||||||
gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
|
gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
|
||||||
elif mode == "fill":
|
|
||||||
# Undocumented non-standard mode corresponding to the fill_or_drop mode on
|
|
||||||
# lax.gather()
|
|
||||||
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
|
|
||||||
# lax.gather() does not support negative indices, so we wrap them here
|
|
||||||
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
|
|
||||||
elif mode == "clip":
|
elif mode == "clip":
|
||||||
gather_mode = lax.GatherScatterMode.CLIP
|
gather_mode = lax.GatherScatterMode.CLIP
|
||||||
else:
|
else:
|
||||||
@@ -3439,7 +3433,7 @@ TAKE_ALONG_AXIS_DOC = """
|
|||||||
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
|
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
|
||||||
an optional ``mode`` parameter controlling how out-of-bounds indices should be
|
an optional ``mode`` parameter controlling how out-of-bounds indices should be
|
||||||
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
|
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
|
||||||
See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds
|
See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
|
||||||
indexing in JAX.
|
indexing in JAX.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user