jnp.ndarray.at: deprecate passing additional arguments by position

This commit is contained in:
Jake VanderPlas 2023-03-13 10:04:39 -07:00
parent 1925aa1109
commit 6dd0e0153a
3 changed files with 52 additions and 10 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.7
* Deprecations
* Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
## jaxlib 0.4.7
## jax 0.4.6 (Mar 9, 2023)

View File

@ -26,7 +26,8 @@ rules for the underlying :code:`lax` primitives.
import builtins
import collections
from functools import partial
from functools import partial, wraps
import inspect
import math
import operator
import types
@ -5326,6 +5327,27 @@ class _IndexUpdateHelper:
Array.at.__doc__ = _IndexUpdateHelper.__doc__
# TODO(jakevdp): remove these deprecation warnings after June 2023
def allow_pass_by_position_with_warning(f):
@wraps(f)
def wrapped(*args, **kwargs):
sig = inspect.signature(f)
try:
sig.bind(*args, **kwargs)
except TypeError:
argspec = inspect.getfullargspec(f)
n_positional = len(argspec.args)
keywords = argspec.kwonlyargs[:len(args) - n_positional]
warnings.warn(
f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. "
f"Pass by keyword instead", category=FutureWarning, stacklevel=2)
converted_kwargs = dict(zip(keywords, args[n_positional:]))
return f(*args[:n_positional], **converted_kwargs, **kwargs)
else:
return f(*args, **kwargs)
return wrapped
class _IndexUpdateRef:
"""Helper object to call indexed update functions for an (advanced) index.
@ -5342,7 +5364,8 @@ class _IndexUpdateRef:
def __repr__(self):
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
def get(self, indices_are_sorted=False, unique_indices=False,
@allow_pass_by_position_with_warning
def get(self, *, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
"""Equivalent to ``x[idx]``.
@ -5358,7 +5381,8 @@ class _IndexUpdateRef:
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
def set(self, values, indices_are_sorted=False, unique_indices=False,
@allow_pass_by_position_with_warning
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] = y``.
@ -5371,7 +5395,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def apply(self, func, indices_are_sorted=False, unique_indices=False,
@allow_pass_by_position_with_warning
def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
@ -5394,7 +5419,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def add(self, values, indices_are_sorted=False, unique_indices=False,
@allow_pass_by_position_with_warning
def add(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] += y``.
@ -5408,7 +5434,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def multiply(self, values, indices_are_sorted=False, unique_indices=False,
@allow_pass_by_position_with_warning
def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] *= y``.
@ -5424,7 +5451,8 @@ class _IndexUpdateRef:
mode=mode)
mul = multiply
def divide(self, values, indices_are_sorted=False, unique_indices=False,
@allow_pass_by_position_with_warning
def divide(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] /= y``.
@ -5440,7 +5468,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode))
def power(self, values, indices_are_sorted=False, unique_indices=False,
@allow_pass_by_position_with_warning
def power(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] **= y``.
@ -5456,7 +5485,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode))
def min(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
@allow_pass_by_position_with_warning
def min(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
@ -5471,7 +5501,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
def max(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
@allow_pass_by_position_with_warning
def max(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.

View File

@ -929,6 +929,13 @@ class IndexingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
jnp.zeros(2).at[0.].set(1.)
def testIndexingPositionalArgumentWarning(self):
x = jnp.arange(4)
with self.assertWarnsRegex(
FutureWarning, "Passing 'indices_are_sorted' by position is deprecated"):
out = x.at[5].set(1, True, mode='drop')
self.assertArraysEqual(out, x)
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
x = jnp.arange(5, dtype=jnp.int32) + 1
self.assertAllClose(x, x[:10])