diff --git a/CHANGELOG.md b/CHANGELOG.md index a8c75c416..c441b7aa5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list * `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays` instead. + * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is + no longer supported, after being deprecated in JAX version 0.4.7. + For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` * Breaking changes * To fix a corner case, calls to {func}`jax.lax.cond` with five diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index ebfa33b6a..cc97a1afc 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -23,7 +23,6 @@ __all__ = ['register_jax_array_methods'] import abc from functools import partial, wraps -import inspect from typing import Any, Optional, Union import warnings @@ -465,27 +464,6 @@ class _IndexUpdateHelper: return f"_IndexUpdateHelper({repr(self.array)})" -# TODO(jakevdp): remove these deprecation warnings after June 2023 -def allow_pass_by_position_with_warning(f): - @wraps(f) - def wrapped(*args, **kwargs): - sig = inspect.signature(f) - try: - sig.bind(*args, **kwargs) - except TypeError: - argspec = inspect.getfullargspec(f) - n_positional = len(argspec.args) - keywords = argspec.kwonlyargs[:len(args) - n_positional] - warnings.warn( - f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. " - f"Pass by keyword instead", category=FutureWarning, stacklevel=2) - converted_kwargs = dict(unsafe_zip(keywords, args[n_positional:])) - return f(*args[:n_positional], **converted_kwargs, **kwargs) - else: - return f(*args, **kwargs) - return wrapped - - class _IndexUpdateRef: """Helper object to call indexed update functions for an (advanced) index. @@ -502,7 +480,6 @@ class _IndexUpdateRef: def __repr__(self): return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})" - @allow_pass_by_position_with_warning def get(self, *, indices_are_sorted=False, unique_indices=False, mode=None, fill_value=None): """Equivalent to ``x[idx]``. @@ -519,7 +496,6 @@ class _IndexUpdateRef: unique_indices=unique_indices, mode=mode, fill_value=fill_value) - @allow_pass_by_position_with_warning def set(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] = y``. @@ -533,7 +509,6 @@ class _IndexUpdateRef: indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - @allow_pass_by_position_with_warning def apply(self, func, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. @@ -557,7 +532,6 @@ class _IndexUpdateRef: indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - @allow_pass_by_position_with_warning def add(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] += y``. @@ -572,7 +546,6 @@ class _IndexUpdateRef: indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - @allow_pass_by_position_with_warning def multiply(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] *= y``. @@ -589,7 +562,6 @@ class _IndexUpdateRef: mode=mode) mul = multiply - @allow_pass_by_position_with_warning def divide(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] /= y``. @@ -606,7 +578,6 @@ class _IndexUpdateRef: indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode)) - @allow_pass_by_position_with_warning def power(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] **= y``. @@ -623,7 +594,6 @@ class _IndexUpdateRef: indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode)) - @allow_pass_by_position_with_warning def min(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. @@ -639,7 +609,6 @@ class _IndexUpdateRef: indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - @allow_pass_by_position_with_warning def max(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 9b8e34214..8b19c0a2c 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -988,13 +988,6 @@ class IndexingTest(jtu.JaxTestCase): with self.assertRaisesRegex(TypeError, msg): jnp.zeros(2)[:, 'abc'] - def testIndexingPositionalArgumentWarning(self): - x = jnp.arange(4) - with self.assertWarnsRegex( - FutureWarning, "Passing 'indices_are_sorted' by position is deprecated"): - out = x.at[5].set(1, True, mode='drop') - self.assertArraysEqual(out, x) - def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10])