Remove jax.ops.index... functions.

These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
This commit is contained in:
Peter Hawkins 2022-02-17 15:23:44 -05:00
parent 3948fde842
commit f51a05a889
4 changed files with 20 additions and 342 deletions

View File

@ -11,6 +11,20 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.2 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.1...main).
* Changes:
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
default. To recover the previous behavior, use the `jax.test_util.with_config`
decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
```
* The functions `jax.ops.index_update`, `jax.ops.index_add`, which were
deprecated in 0.2.22, have been removed. Please use
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
instead, e.g., `x.at[idx].set(y)`.
## jaxlib 0.3.1 (Unreleased)
* Changes

View File

@ -1,6 +1,6 @@
jax.ops package
=================
===============
.. currentmodule:: jax.ops
@ -8,68 +8,12 @@ jax.ops package
.. _syntactic-sugar-for-ops:
Indexed update operators
------------------------
The functions ``jax.ops.index_update``, ``jax.ops.index_add``, etc., which were
deprecated in JAX 0.2.22, have been removed. Please use the
:attr:`jax.numpy.ndarray.at` property on JAX arrays instead.
JAX is intended to be used with a functional style of programming, and
does not support NumPy-style indexed assignment directly. Instead, JAX provides
alternative pure functional operators for indexed updates to arrays.
JAX array types have a property ``at``, which can be used as
follows (where ``idx`` is a NumPy index expression).
========================= ===================================================
Alternate syntax Equivalent in-place expression
========================= ===================================================
``x.at[idx].get()`` ``x[idx]``
``x.at[idx].set(y)`` ``x[idx] = y``
``x.at[idx].add(y)`` ``x[idx] += y``
``x.at[idx].multiply(y)`` ``x[idx] *= y``
``x.at[idx].divide(y)`` ``x[idx] /= y``
``x.at[idx].power(y)`` ``x[idx] **= y``
``x.at[idx].min(y)`` ``x[idx] = np.minimum(x[idx], y)``
``x.at[idx].max(y)`` ``x[idx] = np.maximum(x[idx], y)``
========================= ===================================================
None of these expressions modify the original `x`; instead they return
a modified copy of `x`. However, inside a :py:func:`jit` compiled function,
expressions like ``x = x.at[idx].set(y)`` are guaranteed to be applied in-place.
Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple
indices refer to the same location, all updates will be applied (NumPy would
only apply the last update, rather than applying all updates.) The order
in which conflicting updates are applied is implementation-defined and may be
nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. There is experimental
support for giving more precise semantics to out-of-bounds indexed accesses,
via the ``mode`` parameter to functions such as ``get`` and ``set``. Valid
values for ``mode`` include ``"clip"``, which means that out-of-bounds indices
will be clamped into range, and ``"fill"``/``"drop"``, which are aliases and
mean that out-of-bounds reads will be filled with a scalar ``fill_value``,
and out-of-bounds writes will be discarded.
Indexed update functions (deprecated)
-------------------------------------
The following functions are aliases for the ``x.at[idx].set(y)``
style operators. Use the ``x.at[idx]`` operators instead.
.. autosummary::
:toctree: _autosummary
index
index_update
index_add
index_mul
index_min
index_max
Other operators
---------------
Segment reduction operators
---------------------------
.. autosummary::
:toctree: _autosummary

View File

@ -14,7 +14,6 @@
# Helpers for indexed updates.
import warnings
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union
@ -113,279 +112,6 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
return lax._convert_element_type(out, dtype, weak_type)
class _Indexable(object):
"""Helper object for building indexes for indexed update functions.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`. If an explicit index
is needed, use :func:`jax.numpy.index_exp`.
This is a singleton object that overrides the :code:`__getitem__` method
to return the index it is passed.
>>> jax.ops.index[1:2, 3, None, ..., ::2]
(slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))
"""
__slots__ = ()
def __getitem__(self, index):
return index
#: Index object singleton
index = _Indexable()
def index_add(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] += y`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] += y
Note the `index_add` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] += y`, if multiple indices refer to the
same location the updates will be summed. (NumPy would only apply the last
update, rather than summing the updates.) The order in which conflicting
updates are applied is implementation-defined and may be nondeterministic
(e.g., due to concurrency on some hardware platforms).
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 7., 7., 7.],
[1., 1., 1., 7., 7., 7.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_add is deprecated. Use x.at[idx].add(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_add, indices_are_sorted, unique_indices)
def index_mul(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] *= y`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] *= y
Note the `index_mul` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] *= y`, if multiple indices refer to the
same location the updates will be multiplied. (NumPy would only apply the last
update, rather than multiplying the updates.) The order in which conflicting
updates are applied is implementation-defined and may be nondeterministic
(e.g., due to concurrency on some hardware platforms).
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_mul(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_mul is deprecated. Use x.at[idx].mul(y) instead.",
DeprecationWarning)
return _scatter_update(x, idx, y, lax.scatter_mul,
indices_are_sorted, unique_indices)
def index_min(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] = minimum(x[idx], y)
Note the `index_min` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] = minimum(x[idx], y)`, if multiple indices
refer to the same location the final value will be the overall min. (NumPy
would only look at the last update, rather than all of the updates.)
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_min(x, jnp.index_exp[2:4, 3:], 0.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_min is deprecated. Use x.at[idx].min(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_min, indices_are_sorted, unique_indices)
def index_max(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = maximum(x[idx], y)`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] = maximum(x[idx], y)
Note the `index_max` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] = maximum(x[idx], y)`, if multiple indices
refer to the same location the final value will be the overall max. (NumPy
would only look at the last update, rather than all of the updates.)
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_max is deprecated. Use x.at[idx].max(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_max, indices_are_sorted, unique_indices)
def index_update(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = y`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] = y
Note the `index_update` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike NumPy's :code:`x[idx] = y`, if multiple indices refer to the same
location it is undefined which update is chosen; JAX may choose the order of
updates arbitrarily and nondeterministically (e.g., due to concurrent
updates on some hardware platforms).
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.)
DeviceArray([[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.]], dtype=float32)
"""
warnings.warn("index_update is deprecated. Use x.at[idx].set(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter, indices_are_sorted, unique_indices)
def _get_identity(op, dtype):

View File

@ -14,12 +14,6 @@
# flake8: noqa: F401
from jax._src.ops.scatter import (
index as index,
index_add as index_add,
index_mul as index_mul,
index_update as index_update,
index_min as index_min,
index_max as index_max,
segment_sum as segment_sum,
segment_prod as segment_prod,
segment_min as segment_min,