mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jnp.ndarray.at: deprecate passing additional arguments by position
This commit is contained in:
parent
1925aa1109
commit
6dd0e0153a
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
## jax 0.4.7
|
## jax 0.4.7
|
||||||
|
|
||||||
|
* Deprecations
|
||||||
|
* Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
|
||||||
|
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||||
|
|
||||||
## jaxlib 0.4.7
|
## jaxlib 0.4.7
|
||||||
|
|
||||||
## jax 0.4.6 (Mar 9, 2023)
|
## jax 0.4.6 (Mar 9, 2023)
|
||||||
|
@ -26,7 +26,8 @@ rules for the underlying :code:`lax` primitives.
|
|||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import collections
|
import collections
|
||||||
from functools import partial
|
from functools import partial, wraps
|
||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
import operator
|
import operator
|
||||||
import types
|
import types
|
||||||
@ -5326,6 +5327,27 @@ class _IndexUpdateHelper:
|
|||||||
Array.at.__doc__ = _IndexUpdateHelper.__doc__
|
Array.at.__doc__ = _IndexUpdateHelper.__doc__
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(jakevdp): remove these deprecation warnings after June 2023
|
||||||
|
def allow_pass_by_position_with_warning(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
sig = inspect.signature(f)
|
||||||
|
try:
|
||||||
|
sig.bind(*args, **kwargs)
|
||||||
|
except TypeError:
|
||||||
|
argspec = inspect.getfullargspec(f)
|
||||||
|
n_positional = len(argspec.args)
|
||||||
|
keywords = argspec.kwonlyargs[:len(args) - n_positional]
|
||||||
|
warnings.warn(
|
||||||
|
f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. "
|
||||||
|
f"Pass by keyword instead", category=FutureWarning, stacklevel=2)
|
||||||
|
converted_kwargs = dict(zip(keywords, args[n_positional:]))
|
||||||
|
return f(*args[:n_positional], **converted_kwargs, **kwargs)
|
||||||
|
else:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class _IndexUpdateRef:
|
class _IndexUpdateRef:
|
||||||
"""Helper object to call indexed update functions for an (advanced) index.
|
"""Helper object to call indexed update functions for an (advanced) index.
|
||||||
|
|
||||||
@ -5342,7 +5364,8 @@ class _IndexUpdateRef:
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
|
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
|
||||||
|
|
||||||
def get(self, indices_are_sorted=False, unique_indices=False,
|
@allow_pass_by_position_with_warning
|
||||||
|
def get(self, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None, fill_value=None):
|
mode=None, fill_value=None):
|
||||||
"""Equivalent to ``x[idx]``.
|
"""Equivalent to ``x[idx]``.
|
||||||
|
|
||||||
@ -5358,7 +5381,8 @@ class _IndexUpdateRef:
|
|||||||
unique_indices=unique_indices, mode=mode,
|
unique_indices=unique_indices, mode=mode,
|
||||||
fill_value=fill_value)
|
fill_value=fill_value)
|
||||||
|
|
||||||
def set(self, values, indices_are_sorted=False, unique_indices=False,
|
@allow_pass_by_position_with_warning
|
||||||
|
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``x[idx] = y``.
|
"""Pure equivalent of ``x[idx] = y``.
|
||||||
|
|
||||||
@ -5371,7 +5395,8 @@ class _IndexUpdateRef:
|
|||||||
indices_are_sorted=indices_are_sorted,
|
indices_are_sorted=indices_are_sorted,
|
||||||
unique_indices=unique_indices, mode=mode)
|
unique_indices=unique_indices, mode=mode)
|
||||||
|
|
||||||
def apply(self, func, indices_are_sorted=False, unique_indices=False,
|
@allow_pass_by_position_with_warning
|
||||||
|
def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
|
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
|
||||||
|
|
||||||
@ -5394,7 +5419,8 @@ class _IndexUpdateRef:
|
|||||||
indices_are_sorted=indices_are_sorted,
|
indices_are_sorted=indices_are_sorted,
|
||||||
unique_indices=unique_indices, mode=mode)
|
unique_indices=unique_indices, mode=mode)
|
||||||
|
|
||||||
def add(self, values, indices_are_sorted=False, unique_indices=False,
|
@allow_pass_by_position_with_warning
|
||||||
|
def add(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``x[idx] += y``.
|
"""Pure equivalent of ``x[idx] += y``.
|
||||||
|
|
||||||
@ -5408,7 +5434,8 @@ class _IndexUpdateRef:
|
|||||||
indices_are_sorted=indices_are_sorted,
|
indices_are_sorted=indices_are_sorted,
|
||||||
unique_indices=unique_indices, mode=mode)
|
unique_indices=unique_indices, mode=mode)
|
||||||
|
|
||||||
def multiply(self, values, indices_are_sorted=False, unique_indices=False,
|
@allow_pass_by_position_with_warning
|
||||||
|
def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``x[idx] *= y``.
|
"""Pure equivalent of ``x[idx] *= y``.
|
||||||
|
|
||||||
@ -5424,7 +5451,8 @@ class _IndexUpdateRef:
|
|||||||
mode=mode)
|
mode=mode)
|
||||||
mul = multiply
|
mul = multiply
|
||||||
|
|
||||||
def divide(self, values, indices_are_sorted=False, unique_indices=False,
|
@allow_pass_by_position_with_warning
|
||||||
|
def divide(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``x[idx] /= y``.
|
"""Pure equivalent of ``x[idx] /= y``.
|
||||||
|
|
||||||
@ -5440,7 +5468,8 @@ class _IndexUpdateRef:
|
|||||||
indices_are_sorted=indices_are_sorted,
|
indices_are_sorted=indices_are_sorted,
|
||||||
unique_indices=unique_indices, mode=mode))
|
unique_indices=unique_indices, mode=mode))
|
||||||
|
|
||||||
def power(self, values, indices_are_sorted=False, unique_indices=False,
|
@allow_pass_by_position_with_warning
|
||||||
|
def power(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``x[idx] **= y``.
|
"""Pure equivalent of ``x[idx] **= y``.
|
||||||
|
|
||||||
@ -5456,7 +5485,8 @@ class _IndexUpdateRef:
|
|||||||
indices_are_sorted=indices_are_sorted,
|
indices_are_sorted=indices_are_sorted,
|
||||||
unique_indices=unique_indices, mode=mode))
|
unique_indices=unique_indices, mode=mode))
|
||||||
|
|
||||||
def min(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
|
@allow_pass_by_position_with_warning
|
||||||
|
def min(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
|
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
|
||||||
|
|
||||||
@ -5471,7 +5501,8 @@ class _IndexUpdateRef:
|
|||||||
indices_are_sorted=indices_are_sorted,
|
indices_are_sorted=indices_are_sorted,
|
||||||
unique_indices=unique_indices, mode=mode)
|
unique_indices=unique_indices, mode=mode)
|
||||||
|
|
||||||
def max(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
|
@allow_pass_by_position_with_warning
|
||||||
|
def max(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||||
mode=None):
|
mode=None):
|
||||||
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
|
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
|
||||||
|
|
||||||
|
@ -929,6 +929,13 @@ class IndexingTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||||
jnp.zeros(2).at[0.].set(1.)
|
jnp.zeros(2).at[0.].set(1.)
|
||||||
|
|
||||||
|
def testIndexingPositionalArgumentWarning(self):
|
||||||
|
x = jnp.arange(4)
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
FutureWarning, "Passing 'indices_are_sorted' by position is deprecated"):
|
||||||
|
out = x.at[5].set(1, True, mode='drop')
|
||||||
|
self.assertArraysEqual(out, x)
|
||||||
|
|
||||||
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
|
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
|
||||||
x = jnp.arange(5, dtype=jnp.int32) + 1
|
x = jnp.arange(5, dtype=jnp.int32) + 1
|
||||||
self.assertAllClose(x, x[:10])
|
self.assertAllClose(x, x[:10])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user