[JAX] Update users of jax.ops.index... functions, which are deprecated.

* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 404395250
This commit is contained in:
Peter Hawkins 2021-10-19 16:38:42 -07:00 committed by jax authors
parent b09501f80e
commit 9fee130d6b
2 changed files with 8 additions and 7 deletions

View File

@ -115,7 +115,8 @@ class _Indexable(object):
"""Helper object for building indexes for indexed update functions.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Prefer the use of :attr:`jax.numpy.ndarray.at`. If an explicit index
is needed, use :func:`jax.numpy.index_exp`.
This is a singleton object that overrides the :code:`__getitem__` method
to return the index it is passed.
@ -171,7 +172,7 @@ def index_add(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.)
>>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 7., 7., 7.],
@ -223,7 +224,7 @@ def index_mul(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_mul(x, jax.ops.index[2:4, 3:], 6.)
>>> jax.ops.index_mul(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
@ -273,7 +274,7 @@ def index_min(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_min(x, jax.ops.index[2:4, 3:], 0.)
>>> jax.ops.index_min(x, jnp.index_exp[2:4, 3:], 0.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 0., 0., 0.],
@ -322,7 +323,7 @@ def index_max(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_max(x, jax.ops.index[2:4, 3:], 6.)
>>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
@ -372,7 +373,7 @@ def index_update(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.)
>>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.)
DeviceArray([[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],

View File

@ -2917,7 +2917,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for dtype in default_dtypes
))
def testDigitize(self, xshape, binshape, right, reverse, dtype):
order = jax.ops.index[::-1] if reverse else jax.ops.index[:]
order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:]
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
np_fun = lambda x, bins: np.digitize(x, bins, right=right)