jnp.ndarray.at: deprecate passing additional arguments by position

This commit is contained in:
Jake VanderPlas 2023-03-13 10:04:39 -07:00
parent 1925aa1109
commit 6dd0e0153a
3 changed files with 52 additions and 10 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.7 ## jax 0.4.7
* Deprecations
* Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
## jaxlib 0.4.7 ## jaxlib 0.4.7
## jax 0.4.6 (Mar 9, 2023) ## jax 0.4.6 (Mar 9, 2023)

View File

@ -26,7 +26,8 @@ rules for the underlying :code:`lax` primitives.
import builtins import builtins
import collections import collections
from functools import partial from functools import partial, wraps
import inspect
import math import math
import operator import operator
import types import types
@ -5326,6 +5327,27 @@ class _IndexUpdateHelper:
Array.at.__doc__ = _IndexUpdateHelper.__doc__ Array.at.__doc__ = _IndexUpdateHelper.__doc__
# TODO(jakevdp): remove these deprecation warnings after June 2023
def allow_pass_by_position_with_warning(f):
@wraps(f)
def wrapped(*args, **kwargs):
sig = inspect.signature(f)
try:
sig.bind(*args, **kwargs)
except TypeError:
argspec = inspect.getfullargspec(f)
n_positional = len(argspec.args)
keywords = argspec.kwonlyargs[:len(args) - n_positional]
warnings.warn(
f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. "
f"Pass by keyword instead", category=FutureWarning, stacklevel=2)
converted_kwargs = dict(zip(keywords, args[n_positional:]))
return f(*args[:n_positional], **converted_kwargs, **kwargs)
else:
return f(*args, **kwargs)
return wrapped
class _IndexUpdateRef: class _IndexUpdateRef:
"""Helper object to call indexed update functions for an (advanced) index. """Helper object to call indexed update functions for an (advanced) index.
@ -5342,7 +5364,8 @@ class _IndexUpdateRef:
def __repr__(self): def __repr__(self):
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})" return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
def get(self, indices_are_sorted=False, unique_indices=False, @allow_pass_by_position_with_warning
def get(self, *, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None): mode=None, fill_value=None):
"""Equivalent to ``x[idx]``. """Equivalent to ``x[idx]``.
@ -5358,7 +5381,8 @@ class _IndexUpdateRef:
unique_indices=unique_indices, mode=mode, unique_indices=unique_indices, mode=mode,
fill_value=fill_value) fill_value=fill_value)
def set(self, values, indices_are_sorted=False, unique_indices=False, @allow_pass_by_position_with_warning
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``x[idx] = y``. """Pure equivalent of ``x[idx] = y``.
@ -5371,7 +5395,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode) unique_indices=unique_indices, mode=mode)
def apply(self, func, indices_are_sorted=False, unique_indices=False, @allow_pass_by_position_with_warning
def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
@ -5394,7 +5419,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode) unique_indices=unique_indices, mode=mode)
def add(self, values, indices_are_sorted=False, unique_indices=False, @allow_pass_by_position_with_warning
def add(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``x[idx] += y``. """Pure equivalent of ``x[idx] += y``.
@ -5408,7 +5434,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode) unique_indices=unique_indices, mode=mode)
def multiply(self, values, indices_are_sorted=False, unique_indices=False, @allow_pass_by_position_with_warning
def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``x[idx] *= y``. """Pure equivalent of ``x[idx] *= y``.
@ -5424,7 +5451,8 @@ class _IndexUpdateRef:
mode=mode) mode=mode)
mul = multiply mul = multiply
def divide(self, values, indices_are_sorted=False, unique_indices=False, @allow_pass_by_position_with_warning
def divide(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``x[idx] /= y``. """Pure equivalent of ``x[idx] /= y``.
@ -5440,7 +5468,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)) unique_indices=unique_indices, mode=mode))
def power(self, values, indices_are_sorted=False, unique_indices=False, @allow_pass_by_position_with_warning
def power(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``x[idx] **= y``. """Pure equivalent of ``x[idx] **= y``.
@ -5456,7 +5485,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)) unique_indices=unique_indices, mode=mode))
def min(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811 @allow_pass_by_position_with_warning
def min(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``. """Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
@ -5471,7 +5501,8 @@ class _IndexUpdateRef:
indices_are_sorted=indices_are_sorted, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode) unique_indices=unique_indices, mode=mode)
def max(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811 @allow_pass_by_position_with_warning
def max(self, values, *, indices_are_sorted=False, unique_indices=False,
mode=None): mode=None):
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``. """Pure equivalent of ``x[idx] = maximum(x[idx], y)``.

View File

@ -929,6 +929,13 @@ class IndexingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
jnp.zeros(2).at[0.].set(1.) jnp.zeros(2).at[0.].set(1.)
def testIndexingPositionalArgumentWarning(self):
x = jnp.arange(4)
with self.assertWarnsRegex(
FutureWarning, "Passing 'indices_are_sorted' by position is deprecated"):
out = x.at[5].set(1, True, mode='drop')
self.assertArraysEqual(out, x)
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
x = jnp.arange(5, dtype=jnp.int32) + 1 x = jnp.arange(5, dtype=jnp.int32) + 1
self.assertAllClose(x, x[:10]) self.assertAllClose(x, x[:10])