mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove jax.ops.index... functions.
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
This commit is contained in:
parent
3948fde842
commit
f51a05a889
14
CHANGELOG.md
14
CHANGELOG.md
@ -11,6 +11,20 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.3.2 (Unreleased)
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.3.1...main).
|
||||
* Changes:
|
||||
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
|
||||
default. To recover the previous behavior, use the `jax.test_util.with_config`
|
||||
decorator:
|
||||
```python
|
||||
@jtu.with_config(jax_numpy_rank_promotion='allow')
|
||||
class MyTestCase(jtu.JaxTestCase):
|
||||
...
|
||||
```
|
||||
* The functions `jax.ops.index_update`, `jax.ops.index_add`, which were
|
||||
deprecated in 0.2.22, have been removed. Please use
|
||||
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
|
||||
instead, e.g., `x.at[idx].set(y)`.
|
||||
|
||||
|
||||
## jaxlib 0.3.1 (Unreleased)
|
||||
* Changes
|
||||
|
@ -1,6 +1,6 @@
|
||||
|
||||
jax.ops package
|
||||
=================
|
||||
===============
|
||||
|
||||
.. currentmodule:: jax.ops
|
||||
|
||||
@ -8,68 +8,12 @@ jax.ops package
|
||||
|
||||
.. _syntactic-sugar-for-ops:
|
||||
|
||||
Indexed update operators
|
||||
------------------------
|
||||
The functions ``jax.ops.index_update``, ``jax.ops.index_add``, etc., which were
|
||||
deprecated in JAX 0.2.22, have been removed. Please use the
|
||||
:attr:`jax.numpy.ndarray.at` property on JAX arrays instead.
|
||||
|
||||
JAX is intended to be used with a functional style of programming, and
|
||||
does not support NumPy-style indexed assignment directly. Instead, JAX provides
|
||||
alternative pure functional operators for indexed updates to arrays.
|
||||
|
||||
JAX array types have a property ``at``, which can be used as
|
||||
follows (where ``idx`` is a NumPy index expression).
|
||||
|
||||
========================= ===================================================
|
||||
Alternate syntax Equivalent in-place expression
|
||||
========================= ===================================================
|
||||
``x.at[idx].get()`` ``x[idx]``
|
||||
``x.at[idx].set(y)`` ``x[idx] = y``
|
||||
``x.at[idx].add(y)`` ``x[idx] += y``
|
||||
``x.at[idx].multiply(y)`` ``x[idx] *= y``
|
||||
``x.at[idx].divide(y)`` ``x[idx] /= y``
|
||||
``x.at[idx].power(y)`` ``x[idx] **= y``
|
||||
``x.at[idx].min(y)`` ``x[idx] = np.minimum(x[idx], y)``
|
||||
``x.at[idx].max(y)`` ``x[idx] = np.maximum(x[idx], y)``
|
||||
========================= ===================================================
|
||||
|
||||
None of these expressions modify the original `x`; instead they return
|
||||
a modified copy of `x`. However, inside a :py:func:`jit` compiled function,
|
||||
expressions like ``x = x.at[idx].set(y)`` are guaranteed to be applied in-place.
|
||||
|
||||
Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple
|
||||
indices refer to the same location, all updates will be applied (NumPy would
|
||||
only apply the last update, rather than applying all updates.) The order
|
||||
in which conflicting updates are applied is implementation-defined and may be
|
||||
nondeterministic (e.g., due to concurrency on some hardware platforms).
|
||||
|
||||
By default, JAX assumes that all indices are in-bounds. There is experimental
|
||||
support for giving more precise semantics to out-of-bounds indexed accesses,
|
||||
via the ``mode`` parameter to functions such as ``get`` and ``set``. Valid
|
||||
values for ``mode`` include ``"clip"``, which means that out-of-bounds indices
|
||||
will be clamped into range, and ``"fill"``/``"drop"``, which are aliases and
|
||||
mean that out-of-bounds reads will be filled with a scalar ``fill_value``,
|
||||
and out-of-bounds writes will be discarded.
|
||||
|
||||
|
||||
Indexed update functions (deprecated)
|
||||
-------------------------------------
|
||||
|
||||
The following functions are aliases for the ``x.at[idx].set(y)``
|
||||
style operators. Use the ``x.at[idx]`` operators instead.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
index
|
||||
index_update
|
||||
index_add
|
||||
index_mul
|
||||
index_min
|
||||
index_max
|
||||
|
||||
|
||||
|
||||
Other operators
|
||||
---------------
|
||||
Segment reduction operators
|
||||
---------------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
# Helpers for indexed updates.
|
||||
|
||||
import warnings
|
||||
import sys
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
@ -113,279 +112,6 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
return lax._convert_element_type(out, dtype, weak_type)
|
||||
|
||||
|
||||
class _Indexable(object):
|
||||
"""Helper object for building indexes for indexed update functions.
|
||||
|
||||
.. deprecated:: 0.2.22
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`. If an explicit index
|
||||
is needed, use :func:`jax.numpy.index_exp`.
|
||||
|
||||
This is a singleton object that overrides the :code:`__getitem__` method
|
||||
to return the index it is passed.
|
||||
|
||||
>>> jax.ops.index[1:2, 3, None, ..., ::2]
|
||||
(slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))
|
||||
"""
|
||||
__slots__ = ()
|
||||
|
||||
def __getitem__(self, index):
|
||||
return index
|
||||
|
||||
#: Index object singleton
|
||||
index = _Indexable()
|
||||
|
||||
|
||||
def index_add(x: Array,
|
||||
idx: Index,
|
||||
y: Numeric,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False) -> Array:
|
||||
"""Pure equivalent of :code:`x[idx] += y`.
|
||||
|
||||
.. deprecated:: 0.2.22
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`.
|
||||
|
||||
Returns the value of `x` that would result from the
|
||||
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
|
||||
|
||||
x[idx] += y
|
||||
|
||||
Note the `index_add` operator is pure; `x` itself is
|
||||
not modified, instead the new value that `x` would have taken is returned.
|
||||
|
||||
Unlike the NumPy code :code:`x[idx] += y`, if multiple indices refer to the
|
||||
same location the updates will be summed. (NumPy would only apply the last
|
||||
update, rather than summing the updates.) The order in which conflicting
|
||||
updates are applied is implementation-defined and may be nondeterministic
|
||||
(e.g., due to concurrency on some hardware platforms).
|
||||
|
||||
Args:
|
||||
x: an array with the values to be updated.
|
||||
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
|
||||
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
|
||||
convenient syntactic sugar for forming indices is via the
|
||||
:data:`jax.ops.index` object.
|
||||
y: the array of updates. `y` must be broadcastable to the shape of the
|
||||
array that would be returned by `x[idx]`.
|
||||
indices_are_sorted: whether `idx` is known to be sorted
|
||||
unique_indices: whether `idx` is known to be free of duplicates
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 7., 7., 7.],
|
||||
[1., 1., 1., 7., 7., 7.],
|
||||
[1., 1., 1., 1., 1., 1.]], dtype=float32)
|
||||
"""
|
||||
warnings.warn("index_add is deprecated. Use x.at[idx].add(y) instead.",
|
||||
DeprecationWarning)
|
||||
return _scatter_update(
|
||||
x, idx, y, lax.scatter_add, indices_are_sorted, unique_indices)
|
||||
|
||||
|
||||
def index_mul(x: Array,
|
||||
idx: Index,
|
||||
y: Numeric,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False) -> Array:
|
||||
"""Pure equivalent of :code:`x[idx] *= y`.
|
||||
|
||||
.. deprecated:: 0.2.22
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`.
|
||||
|
||||
Returns the value of `x` that would result from the
|
||||
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
|
||||
|
||||
x[idx] *= y
|
||||
|
||||
Note the `index_mul` operator is pure; `x` itself is
|
||||
not modified, instead the new value that `x` would have taken is returned.
|
||||
|
||||
Unlike the NumPy code :code:`x[idx] *= y`, if multiple indices refer to the
|
||||
same location the updates will be multiplied. (NumPy would only apply the last
|
||||
update, rather than multiplying the updates.) The order in which conflicting
|
||||
updates are applied is implementation-defined and may be nondeterministic
|
||||
(e.g., due to concurrency on some hardware platforms).
|
||||
|
||||
Args:
|
||||
x: an array with the values to be updated.
|
||||
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
|
||||
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
|
||||
convenient syntactic sugar for forming indices is via the
|
||||
:data:`jax.ops.index` object.
|
||||
y: the array of updates. `y` must be broadcastable to the shape of the
|
||||
array that would be returned by `x[idx]`.
|
||||
indices_are_sorted: whether `idx` is known to be sorted
|
||||
unique_indices: whether `idx` is known to be free of duplicates
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_mul(x, jnp.index_exp[2:4, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
[1., 1., 1., 1., 1., 1.]], dtype=float32)
|
||||
"""
|
||||
warnings.warn("index_mul is deprecated. Use x.at[idx].mul(y) instead.",
|
||||
DeprecationWarning)
|
||||
return _scatter_update(x, idx, y, lax.scatter_mul,
|
||||
indices_are_sorted, unique_indices)
|
||||
|
||||
|
||||
def index_min(x: Array,
|
||||
idx: Index,
|
||||
y: Numeric,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False) -> Array:
|
||||
"""Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.
|
||||
|
||||
.. deprecated:: 0.2.22
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`.
|
||||
|
||||
Returns the value of `x` that would result from the
|
||||
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
|
||||
|
||||
x[idx] = minimum(x[idx], y)
|
||||
|
||||
Note the `index_min` operator is pure; `x` itself is
|
||||
not modified, instead the new value that `x` would have taken is returned.
|
||||
|
||||
Unlike the NumPy code :code:`x[idx] = minimum(x[idx], y)`, if multiple indices
|
||||
refer to the same location the final value will be the overall min. (NumPy
|
||||
would only look at the last update, rather than all of the updates.)
|
||||
|
||||
Args:
|
||||
x: an array with the values to be updated.
|
||||
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
|
||||
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
|
||||
convenient syntactic sugar for forming indices is via the
|
||||
:data:`jax.ops.index` object.
|
||||
y: the array of updates. `y` must be broadcastable to the shape of the
|
||||
array that would be returned by `x[idx]`.
|
||||
indices_are_sorted: whether `idx` is known to be sorted
|
||||
unique_indices: whether `idx` is known to be free of duplicates
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_min(x, jnp.index_exp[2:4, 3:], 0.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 0., 0., 0.],
|
||||
[1., 1., 1., 0., 0., 0.],
|
||||
[1., 1., 1., 1., 1., 1.]], dtype=float32)
|
||||
"""
|
||||
warnings.warn("index_min is deprecated. Use x.at[idx].min(y) instead.",
|
||||
DeprecationWarning)
|
||||
return _scatter_update(
|
||||
x, idx, y, lax.scatter_min, indices_are_sorted, unique_indices)
|
||||
|
||||
def index_max(x: Array,
|
||||
idx: Index,
|
||||
y: Numeric,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False) -> Array:
|
||||
"""Pure equivalent of :code:`x[idx] = maximum(x[idx], y)`.
|
||||
|
||||
.. deprecated:: 0.2.22
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`.
|
||||
|
||||
Returns the value of `x` that would result from the
|
||||
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
|
||||
|
||||
x[idx] = maximum(x[idx], y)
|
||||
|
||||
Note the `index_max` operator is pure; `x` itself is
|
||||
not modified, instead the new value that `x` would have taken is returned.
|
||||
|
||||
Unlike the NumPy code :code:`x[idx] = maximum(x[idx], y)`, if multiple indices
|
||||
refer to the same location the final value will be the overall max. (NumPy
|
||||
would only look at the last update, rather than all of the updates.)
|
||||
|
||||
Args:
|
||||
x: an array with the values to be updated.
|
||||
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
|
||||
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
|
||||
convenient syntactic sugar for forming indices is via the
|
||||
:data:`jax.ops.index` object.
|
||||
y: the array of updates. `y` must be broadcastable to the shape of the
|
||||
array that would be returned by `x[idx]`.
|
||||
indices_are_sorted: whether `idx` is known to be sorted
|
||||
unique_indices: whether `idx` is known to be free of duplicates
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
[1., 1., 1., 1., 1., 1.]], dtype=float32)
|
||||
"""
|
||||
warnings.warn("index_max is deprecated. Use x.at[idx].max(y) instead.",
|
||||
DeprecationWarning)
|
||||
return _scatter_update(
|
||||
x, idx, y, lax.scatter_max, indices_are_sorted, unique_indices)
|
||||
|
||||
def index_update(x: Array,
|
||||
idx: Index,
|
||||
y: Numeric,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False) -> Array:
|
||||
"""Pure equivalent of :code:`x[idx] = y`.
|
||||
|
||||
.. deprecated:: 0.2.22
|
||||
Prefer the use of :attr:`jax.numpy.ndarray.at`.
|
||||
|
||||
Returns the value of `x` that would result from the
|
||||
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
|
||||
|
||||
x[idx] = y
|
||||
|
||||
Note the `index_update` operator is pure; `x` itself is
|
||||
not modified, instead the new value that `x` would have taken is returned.
|
||||
|
||||
Unlike NumPy's :code:`x[idx] = y`, if multiple indices refer to the same
|
||||
location it is undefined which update is chosen; JAX may choose the order of
|
||||
updates arbitrarily and nondeterministically (e.g., due to concurrent
|
||||
updates on some hardware platforms).
|
||||
|
||||
Args:
|
||||
x: an array with the values to be updated.
|
||||
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
|
||||
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
|
||||
convenient syntactic sugar for forming indices is via the
|
||||
:data:`jax.ops.index` object.
|
||||
y: the array of updates. `y` must be broadcastable to the shape of the
|
||||
array that would be returned by `x[idx]`.
|
||||
indices_are_sorted: whether `idx` is known to be sorted
|
||||
unique_indices: whether `idx` is known to be free of duplicates
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
|
||||
>>> x = jax.numpy.ones((5, 6))
|
||||
>>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.)
|
||||
DeviceArray([[1., 1., 1., 6., 6., 6.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 6., 6., 6.],
|
||||
[1., 1., 1., 1., 1., 1.],
|
||||
[1., 1., 1., 6., 6., 6.]], dtype=float32)
|
||||
"""
|
||||
warnings.warn("index_update is deprecated. Use x.at[idx].set(y) instead.",
|
||||
DeprecationWarning)
|
||||
return _scatter_update(
|
||||
x, idx, y, lax.scatter, indices_are_sorted, unique_indices)
|
||||
|
||||
|
||||
def _get_identity(op, dtype):
|
||||
|
@ -14,12 +14,6 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.ops.scatter import (
|
||||
index as index,
|
||||
index_add as index_add,
|
||||
index_mul as index_mul,
|
||||
index_update as index_update,
|
||||
index_min as index_min,
|
||||
index_max as index_max,
|
||||
segment_sum as segment_sum,
|
||||
segment_prod as segment_prod,
|
||||
segment_min as segment_min,
|
||||
|
Loading…
x
Reference in New Issue
Block a user