From 0b470361dac51fb4f5ab2f720f1cf35e442db005 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 28 Apr 2022 06:01:22 -0700 Subject: [PATCH] Change the default jnp.take mode to "fill". Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices. The previous behavior can be approximated using the "clip" or "wrap" modes. PiperOrigin-RevId: 445130143 --- CHANGELOG.md | 4 +++- jax/_src/lax/lax.py | 2 +- jax/_src/lax/slicing.py | 9 +++++---- jax/_src/numpy/lax_numpy.py | 20 +++++++------------- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb0656c79..844f26bb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`. - * Scatter operations, such as `x.at[...].set(...)`, now have "drop" semantics. + * {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns + invalid values (e.g., NaN) for out-of-bounds indices. + * Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. 
Previously out-of-bounds indices were clamped into diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0bc5af10b..bdb4b943f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2995,7 +2995,7 @@ mlir.register_lowering(squeeze_p, _squeeze_lower) -def _shape_as_value(shape: core.Shape): +def shape_as_value(shape: core.Shape): """Converts a shape that may contain Poly values into a JAX value.""" if len(shape) == 0: return full((0,), np.array(0, np.int64)) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 220425995..6f75f6cc0 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1162,12 +1162,13 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes, """Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking.""" dnums = dimension_numbers intarray = partial(np.array, dtype=np.int64) - operand_dims = lax._shape_as_value(operand.shape) + operand_dims = lax.shape_as_value(operand.shape) indices = lax.convert_element_type(indices, np.int64) num_batch_dims = len(indices.shape) - 1 - upper_bound = (operand_dims[intarray(dnums.start_index_map)] - - intarray(slice_sizes)[intarray(dnums.start_index_map)]) + upper_bound = ( + operand_dims[intarray(dnums.start_index_map)] - + lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)]) mask = lax.bitwise_and( lax.ge(indices, np.int64(0)), lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims))))) @@ -1466,7 +1467,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i] for i in dnums.scatter_dims_to_operand_dims) # Stack upper_bounds into a DeviceArray[n] - upper_bound = lax._shape_as_value(upper_bounds) + upper_bound = lax.shape_as_value(upper_bounds) upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max) upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, (len(indices.shape) - 1,)) diff --git a/jax/_src/numpy/lax_numpy.py 
b/jax/_src/numpy/lax_numpy.py index 83025e6ac..59ccec0bb 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3352,9 +3352,10 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'): @_wraps(np.take, skip_params=['out'], lax_description="""\ -In the JAX version, the ``mode`` argument defaults to a special version of ``"clip"`` -that handles negative indices. See :attr:`jax.numpy.ndarray.at` for more discussion -of out-of-bounds indexing in JAX. +In the JAX version, the ``mode`` argument defaults to a special mode +(``"fill"``) that returns invalid values (e.g., NaN) for out-of-bounds indices. +See :attr:`jax.numpy.ndarray.at` for more discussion of out-of-bounds indexing +in JAX. """) def take(a, indices, axis: Optional[int] = None, out=None, mode=None): return _take(a, indices, None if axis is None else operator.index(axis), out, @@ -3374,23 +3375,16 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None): else: axis_idx = _canonicalize_axis(axis, ndim(a)) - if mode is None: - # TODO(phawkins): change default mode to "fill" and delete this case. + if mode is None or mode == "fill": + gather_mode = lax.GatherScatterMode.FILL_OR_DROP # lax.gather() does not support negative indices, so we wrap them here indices = where(indices < 0, indices + a.shape[axis_idx], indices) - gather_mode = lax.GatherScatterMode.CLIP elif mode == "raise": # TODO(phawkins): we have no way to report out of bounds errors yet. 
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.") elif mode == "wrap": indices = mod(indices, _lax_const(indices, a.shape[axis_idx])) gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS - elif mode == "fill": - # Undocumented non-standard mode corresponding to the fill_or_drop mode on - # lax.gather() - gather_mode = lax.GatherScatterMode.FILL_OR_DROP - # lax.gather() does not support negative indices, so we wrap them here - indices = where(indices < 0, indices + a.shape[axis_idx], indices) elif mode == "clip": gather_mode = lax.GatherScatterMode.CLIP else: @@ -3439,7 +3433,7 @@ TAKE_ALONG_AXIS_DOC = """ Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes an optional ``mode`` parameter controlling how out-of-bounds indices should be handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``). -See :attr:`jax.numpy.ndarray.at` for futrher discussion of out-of-bounds +See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds indexing in JAX. """