Require index update optional arguments to be passed by keyword.

Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
This commit is contained in:
Jake VanderPlas 2023-06-30 04:29:43 -07:00 committed by jax authors
parent 3f9da19c63
commit d0e75ca117
3 changed files with 3 additions and 38 deletions

View File

@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was
deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays`
instead.
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* Breaking changes
* To fix a corner case, calls to {func}`jax.lax.cond` with five

View File

@ -23,7 +23,6 @@ __all__ = ['register_jax_array_methods']
import abc
from functools import partial, wraps
import inspect
from typing import Any, Optional, Union
import warnings
@ -465,27 +464,6 @@ class _IndexUpdateHelper:
return f"_IndexUpdateHelper({repr(self.array)})"
# TODO(jakevdp): remove these deprecation warnings after June 2023
def allow_pass_by_position_with_warning(f):
@wraps(f)
def wrapped(*args, **kwargs):
sig = inspect.signature(f)
try:
sig.bind(*args, **kwargs)
except TypeError:
argspec = inspect.getfullargspec(f)
n_positional = len(argspec.args)
keywords = argspec.kwonlyargs[:len(args) - n_positional]
warnings.warn(
f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. "
f"Pass by keyword instead", category=FutureWarning, stacklevel=2)
converted_kwargs = dict(unsafe_zip(keywords, args[n_positional:]))
return f(*args[:n_positional], **converted_kwargs, **kwargs)
else:
return f(*args, **kwargs)
return wrapped
class _IndexUpdateRef:
"""Helper object to call indexed update functions for an (advanced) index.
@ -502,7 +480,6 @@ class _IndexUpdateRef:
def __repr__(self):
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
@allow_pass_by_position_with_warning
def get(self, *, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
"""Equivalent to ``x[idx]``.
@ -519,7 +496,6 @@ class _IndexUpdateRef:
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
@allow_pass_by_position_with_warning
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] = y``.
@ -533,7 +509,6 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
@allow_pass_by_position_with_warning
def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
@ -557,7 +532,6 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
@allow_pass_by_position_with_warning
def add(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] += y``.
@ -572,7 +546,6 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
@allow_pass_by_position_with_warning
def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] *= y``.
@ -589,7 +562,6 @@ class _IndexUpdateRef:
mode=mode)
mul = multiply
@allow_pass_by_position_with_warning
def divide(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] /= y``.
@ -606,7 +578,6 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode))
@allow_pass_by_position_with_warning
def power(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] **= y``.
@ -623,7 +594,6 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode))
@allow_pass_by_position_with_warning
def min(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
@ -639,7 +609,6 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
@allow_pass_by_position_with_warning
def max(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.

View File

@ -988,13 +988,6 @@ class IndexingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, msg):
jnp.zeros(2)[:, 'abc']
def testIndexingPositionalArgumentWarning(self):
x = jnp.arange(4)
with self.assertWarnsRegex(
FutureWarning, "Passing 'indices_are_sorted' by position is deprecated"):
out = x.at[5].set(1, True, mode='drop')
self.assertArraysEqual(out, x)
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
x = jnp.arange(5, dtype=jnp.int32) + 1
self.assertAllClose(x, x[:10])