mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023) PiperOrigin-RevId: 544620172
This commit is contained in:
parent
3f9da19c63
commit
d0e75ca117
@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was
|
||||
deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays`
|
||||
instead.
|
||||
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
|
||||
no longer supported, after being deprecated in JAX version 0.4.7.
|
||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||
|
||||
* Breaking changes
|
||||
* To fix a corner case, calls to {func}`jax.lax.cond` with five
|
||||
|
@ -23,7 +23,6 @@ __all__ = ['register_jax_array_methods']
|
||||
|
||||
import abc
|
||||
from functools import partial, wraps
|
||||
import inspect
|
||||
from typing import Any, Optional, Union
|
||||
import warnings
|
||||
|
||||
@ -465,27 +464,6 @@ class _IndexUpdateHelper:
|
||||
return f"_IndexUpdateHelper({repr(self.array)})"
|
||||
|
||||
|
||||
# TODO(jakevdp): remove these deprecation warnings after June 2023
|
||||
def allow_pass_by_position_with_warning(f):
|
||||
@wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
sig = inspect.signature(f)
|
||||
try:
|
||||
sig.bind(*args, **kwargs)
|
||||
except TypeError:
|
||||
argspec = inspect.getfullargspec(f)
|
||||
n_positional = len(argspec.args)
|
||||
keywords = argspec.kwonlyargs[:len(args) - n_positional]
|
||||
warnings.warn(
|
||||
f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. "
|
||||
f"Pass by keyword instead", category=FutureWarning, stacklevel=2)
|
||||
converted_kwargs = dict(unsafe_zip(keywords, args[n_positional:]))
|
||||
return f(*args[:n_positional], **converted_kwargs, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
|
||||
class _IndexUpdateRef:
|
||||
"""Helper object to call indexed update functions for an (advanced) index.
|
||||
|
||||
@ -502,7 +480,6 @@ class _IndexUpdateRef:
|
||||
def __repr__(self):
|
||||
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def get(self, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None, fill_value=None):
|
||||
"""Equivalent to ``x[idx]``.
|
||||
@ -519,7 +496,6 @@ class _IndexUpdateRef:
|
||||
unique_indices=unique_indices, mode=mode,
|
||||
fill_value=fill_value)
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] = y``.
|
||||
@ -533,7 +509,6 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
|
||||
@ -557,7 +532,6 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def add(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] += y``.
|
||||
@ -572,7 +546,6 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] *= y``.
|
||||
@ -589,7 +562,6 @@ class _IndexUpdateRef:
|
||||
mode=mode)
|
||||
mul = multiply
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def divide(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] /= y``.
|
||||
@ -606,7 +578,6 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode))
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def power(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] **= y``.
|
||||
@ -623,7 +594,6 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode))
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def min(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
|
||||
@ -639,7 +609,6 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
@allow_pass_by_position_with_warning
|
||||
def max(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
|
||||
|
@ -988,13 +988,6 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jnp.zeros(2)[:, 'abc']
|
||||
|
||||
def testIndexingPositionalArgumentWarning(self):
|
||||
x = jnp.arange(4)
|
||||
with self.assertWarnsRegex(
|
||||
FutureWarning, "Passing 'indices_are_sorted' by position is deprecated"):
|
||||
out = x.at[5].set(1, True, mode='drop')
|
||||
self.assertArraysEqual(out, x)
|
||||
|
||||
def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
|
||||
x = jnp.arange(5, dtype=jnp.int32) + 1
|
||||
self.assertAllClose(x, x[:10])
|
||||
|
Loading…
x
Reference in New Issue
Block a user