mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jnp.ndarray.at: deprecate passing additional arguments by position
This commit is contained in:
parent
1925aa1109
commit
6dd0e0153a
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.7
|
||||
|
||||
* Deprecations
|
||||
* Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
|
||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||
|
||||
## jaxlib 0.4.7
|
||||
|
||||
## jax 0.4.6 (Mar 9, 2023)
|
||||
|
@ -26,7 +26,8 @@ rules for the underlying :code:`lax` primitives.
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
import inspect
|
||||
import math
|
||||
import operator
|
||||
import types
|
||||
@ -5326,6 +5327,27 @@ class _IndexUpdateHelper:
|
||||
Array.at.__doc__ = _IndexUpdateHelper.__doc__
|
||||
|
||||
|
||||
# TODO(jakevdp): remove these deprecation warnings after June 2023
|
||||
def allow_pass_by_position_with_warning(f):
|
||||
@wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
sig = inspect.signature(f)
|
||||
try:
|
||||
sig.bind(*args, **kwargs)
|
||||
except TypeError:
|
||||
argspec = inspect.getfullargspec(f)
|
||||
n_positional = len(argspec.args)
|
||||
keywords = argspec.kwonlyargs[:len(args) - n_positional]
|
||||
warnings.warn(
|
||||
f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. "
|
||||
f"Pass by keyword instead", category=FutureWarning, stacklevel=2)
|
||||
converted_kwargs = dict(zip(keywords, args[n_positional:]))
|
||||
return f(*args[:n_positional], **converted_kwargs, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
|
||||
class _IndexUpdateRef:
|
||||
"""Helper object to call indexed update functions for an (advanced) index.
|
||||
|
||||
@ -5342,7 +5364,8 @@ class _IndexUpdateRef:
|
||||
def __repr__(self):
|
||||
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
|
||||
|
||||
def get(self, indices_are_sorted=False, unique_indices=False,
|
||||
@allow_pass_by_position_with_warning
|
||||
def get(self, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None, fill_value=None):
|
||||
"""Equivalent to ``x[idx]``.
|
||||
|
||||
@ -5358,7 +5381,8 @@ class _IndexUpdateRef:
|
||||
unique_indices=unique_indices, mode=mode,
|
||||
fill_value=fill_value)
|
||||
|
||||
def set(self, values, indices_are_sorted=False, unique_indices=False,
|
||||
@allow_pass_by_position_with_warning
|
||||
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] = y``.
|
||||
|
||||
@ -5371,7 +5395,8 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
def apply(self, func, indices_are_sorted=False, unique_indices=False,
|
||||
@allow_pass_by_position_with_warning
|
||||
def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
|
||||
|
||||
@ -5394,7 +5419,8 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
def add(self, values, indices_are_sorted=False, unique_indices=False,
|
||||
@allow_pass_by_position_with_warning
|
||||
def add(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] += y``.
|
||||
|
||||
@ -5408,7 +5434,8 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
def multiply(self, values, indices_are_sorted=False, unique_indices=False,
|
||||
@allow_pass_by_position_with_warning
|
||||
def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] *= y``.
|
||||
|
||||
@ -5424,7 +5451,8 @@ class _IndexUpdateRef:
|
||||
mode=mode)
|
||||
mul = multiply
|
||||
|
||||
def divide(self, values, indices_are_sorted=False, unique_indices=False,
|
||||
@allow_pass_by_position_with_warning
|
||||
def divide(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] /= y``.
|
||||
|
||||
@ -5440,7 +5468,8 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode))
|
||||
|
||||
def power(self, values, indices_are_sorted=False, unique_indices=False,
|
||||
@allow_pass_by_position_with_warning
|
||||
def power(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] **= y``.
|
||||
|
||||
@ -5456,7 +5485,8 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode))
|
||||
|
||||
def min(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
|
||||
@allow_pass_by_position_with_warning
|
||||
def min(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
|
||||
|
||||
@ -5471,7 +5501,8 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
def max(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
|
||||
@allow_pass_by_position_with_warning
|
||||
def max(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
|
||||
|
||||
|
@ -929,6 +929,13 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||
jnp.zeros(2).at[0.].set(1.)
|
||||
|
||||
def testIndexingPositionalArgumentWarning(self):
|
||||
x = jnp.arange(4)
|
||||
with self.assertWarnsRegex(
|
||||
FutureWarning, "Passing 'indices_are_sorted' by position is deprecated"):
|
||||
out = x.at[5].set(1, True, mode='drop')
|
||||
self.assertArraysEqual(out, x)
|
||||
|
||||
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
|
||||
x = jnp.arange(5, dtype=jnp.int32) + 1
|
||||
self.assertAllClose(x, x[:10])
|
||||
|
Loading…
x
Reference in New Issue
Block a user