mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[JAX] Update users of jax.ops.index...
functions, which are deprecated.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing. * remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`. * update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs. PiperOrigin-RevId: 404395250
This commit is contained in:
parent
b09501f80e
commit
9fee130d6b
@ -115,7 +115,8 @@ class _Indexable(object):
|
||||
"""Helper object for building indexes for indexed update functions.
|
||||
|
||||
.. deprecated:: 0.2.22
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`.
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`. If an explicit index
|
||||
is needed, use :func:`jax.numpy.index_exp`.
|
||||
|
||||
This is a singleton object that overrides the :code:`__getitem__` method
|
||||
to return the index it is passed.
|
||||
@ -171,7 +172,7 @@ def index_add(x: Array,
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.)
|
||||
>>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 7., 7., 7.],
|
||||
@ -223,7 +224,7 @@ def index_mul(x: Array,
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_mul(x, jax.ops.index[2:4, 3:], 6.)
|
||||
>>> jax.ops.index_mul(x, jnp.index_exp[2:4, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
@ -273,7 +274,7 @@ def index_min(x: Array,
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_min(x, jax.ops.index[2:4, 3:], 0.)
|
||||
>>> jax.ops.index_min(x, jnp.index_exp[2:4, 3:], 0.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 0., 0., 0.],
|
||||
@ -322,7 +323,7 @@ def index_max(x: Array,
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_max(x, jax.ops.index[2:4, 3:], 6.)
|
||||
>>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
@ -372,7 +373,7 @@ def index_update(x: Array,
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.)
|
||||
>>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 6., 6., 6.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
|
@ -2917,7 +2917,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for dtype in default_dtypes
|
||||
))
|
||||
def testDigitize(self, xshape, binshape, right, reverse, dtype):
|
||||
order = jax.ops.index[::-1] if reverse else jax.ops.index[:]
|
||||
order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:]
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
|
||||
np_fun = lambda x, bins: np.digitize(x, bins, right=right)
|
||||
|
Loading…
x
Reference in New Issue
Block a user