Promote the x.at[idx].set(y) operators as the preferred way to do indexed updates.

Mark the index_update() etc. operators as deprecated in the documentation.

Add new .divide and .power operators. Fixes #2694.
Add .multiply as an alias for .mul. To be more numpy-like we should probably prefer the longer names.
This commit is contained in:
Peter Hawkins 2021-05-10 17:44:18 -04:00
parent 1509f995ed
commit d005e38f78
4 changed files with 150 additions and 103 deletions

View File

@ -6,13 +6,39 @@ jax.ops package
.. automodule:: jax.ops
.. _syntactic-sugar-for-ops:
Indexed update operators
------------------------
JAX is intended to be used with a functional style of programming, and hence
JAX is intended to be used with a functional style of programming, and
does not support NumPy-style indexed assignment directly. Instead, JAX provides
pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
alternative pure functional operators for indexed updates to arrays.
JAX array types have a property ``at``, which can be used as
follows (where ``idx`` is a NumPy index expression).
=========================  ===================================================
Alternate syntax           Equivalent in-place expression
=========================  ===================================================
``x.at[idx].set(y)``       ``x[idx] = y``
``x.at[idx].add(y)``       ``x[idx] += y``
``x.at[idx].multiply(y)``  ``x[idx] *= y``
``x.at[idx].divide(y)``    ``x[idx] /= y``
``x.at[idx].power(y)``     ``x[idx] **= y``
``x.at[idx].min(y)``       ``x[idx] = np.minimum(x[idx], y)``
``x.at[idx].max(y)``       ``x[idx] = np.maximum(x[idx], y)``
=========================  ===================================================
None of these expressions modify the original ``x``; instead they return
a modified copy of ``x``.
Indexed update functions (deprecated)
-------------------------------------
The following functions are aliases for the ``x.at[idx].set(y)``
style operators. Prefer to use the ``x.at[idx]`` operators instead.
.. autosummary::
:toctree: _autosummary
@ -24,27 +50,7 @@ pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
index_min
index_max
.. _syntactic-sugar-for-ops:
Syntactic sugar for indexed update operators
--------------------------------------------
JAX also provides an alternate syntax for these indexed update operators.
Specifically, JAX ndarray types have a property ``at``, which can be used as
follows (where ``idx`` can be an arbitrary index expression).
==================== ===================================================
Alternate syntax Equivalent expression
==================== ===================================================
``x.at[idx].set(y)`` ``jax.ops.index_update(x, jax.ops.index[idx], y)``
``x.at[idx].add(y)`` ``jax.ops.index_add(x, jax.ops.index[idx], y)``
``x.at[idx].mul(y)`` ``jax.ops.index_mul(x, jax.ops.index[idx], y)``
``x.at[idx].min(y)`` ``jax.ops.index_min(x, jax.ops.index[idx], y)``
``x.at[idx].max(y)`` ``jax.ops.index_max(x, jax.ops.index[idx], y)``
==================== ===================================================
Note that none of these expressions modify the original `x`; instead they return
a modified copy of `x`.
Other operators
---------------

View File

@ -49,6 +49,7 @@ from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
from jax import lax
from jax._src.lax.lax import _device_put_raw
from jax import ops
from jax._src.ops import scatter
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip,
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
from jax.tree_util import tree_leaves, tree_flatten, tree_map
@ -5842,14 +5843,17 @@ class _IndexUpdateHelper:
The ``at`` property is syntactic sugar for calling the indexed update functions
defined in :mod:`jax.ops`, and acts as a pure equivalent of in-place
modifications. For further information, see `Syntactic Sugar for Index Update Operators
<https://jax.readthedocs.io/en/latest/jax.ops.html#syntactic-sugar-for-indexed-update-operators>`_.
modifications. For further information, see `Indexed Update Operators
<https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-operators>`_.
In particular:
- ``x = x.at[idx].set(y)`` is a pure equivalent of ``x[idx] = y``.
- ``x = x.at[idx].add(y)`` is a pure equivalent of ``x[idx] += y``.
- ``x = x.at[idx].mul(y)`` is a pure equivalent of ``x[idx] *= y``.
- ``x = x.at[idx].multiply(y)`` (aka ``mul``) is a pure equivalent of
``x[idx] *= y``.
- ``x = x.at[idx].divide(y)`` is a pure equivalent of ``x[idx] /= y``.
- ``x = x.at[idx].power(y)`` is a pure equivalent of ``x[idx] **= y``.
- ``x = x.at[idx].min(y)`` is a pure equivalent of
``x[idx] = minimum(x[idx], y)``.
- ``x = x.at[idx].max(y)`` is a pure equivalent of
@ -5866,6 +5870,8 @@ class _IndexUpdateHelper:
def __repr__(self):
  """Return a debug string embedding the repr of the wrapped array."""
  return f"_IndexUpdateHelper({self.array!r})"
_power = power
_divide = divide
class _IndexUpdateRef:
"""Helper object to call indexed update functions for an (advanced) index.
@ -5886,74 +5892,100 @@ class _IndexUpdateRef:
def set(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] = y``.
``x.at[idx].set(y)`` is syntactic sugar for
``jax.ops.index_update(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
See :mod:`jax.ops` for details.
"""
return ops.index_update(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def add(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] += y``.
``x.at[idx].add(y)`` is syntactic sugar for
``jax.ops.index_add(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] += y``.
See :mod:`jax.ops` for details.
"""
return ops.index_add(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_add,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def mul(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] += y``.
def multiply(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] *= y``.
``x.at[idx].mul(y)`` is syntactic sugar for
``jax.ops.index_mul(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] *= y``.
See :mod:`jax.ops` for details.
"""
return ops.index_mul(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_mul,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
mul = multiply
def divide(self, values, indices_are_sorted=False, unique_indices=False):
  """Pure equivalent of ``x[idx] /= y``.

  Returns the value of ``x`` that would result from the NumPy-style
  :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] /= y``.

  See :mod:`jax.ops` for details.
  """
  # Build the divisor first: start from ``ones_like(x)`` and multiply
  # ``values`` in at ``idx`` with the same scatter used by ``multiply``.
  # Positions outside ``idx`` keep the starting value from ``ones_like``.
  divisor = scatter._scatter_update(
      ones_like(self.array), self.index, values, lax.scatter_mul,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
  return _divide(self.array, divisor)
def power(self, values, indices_are_sorted=False, unique_indices=False):
  """Pure equivalent of ``x[idx] **= y``.

  Returns the value of ``x`` that would result from the NumPy-style
  :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] **= y``.

  See :mod:`jax.ops` for details.
  """
  # Build the exponent first: start from ``ones_like(x)`` and multiply
  # ``values`` in at ``idx`` with the same scatter used by ``multiply``.
  # Positions outside ``idx`` keep the starting value from ``ones_like``.
  exponent = scatter._scatter_update(
      ones_like(self.array), self.index, values, lax.scatter_mul,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
  return _power(self.array, exponent)
def min(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
``x.at[idx].min(y)`` is syntactic sugar for
``jax.ops.index_min(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>`
``x[idx] = minimum(x[idx], y)``.
See :mod:`jax.ops` for details.
"""
return ops.index_min(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_min,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def max(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
``x.at[idx].max(y)`` is syntactic sugar for
``jax.ops.index_max(x, jax.ops.index[idx], y)``, and
returns the value of ``x`` that would result from the NumPy-style
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>`
``x[idx] = maximum(x[idx], y)``.
See :mod:`jax.ops` for details.
"""
return ops.index_max(self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
return scatter._scatter_update(self.array, self.index, values,
lax.scatter_max,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
setattr(_DeviceArray, "at", property(_IndexUpdateHelper))
setattr(_CppDeviceArray, "at", property(_IndexUpdateHelper))

View File

@ -138,7 +138,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
kw = {}
if atol: kw["atol"] = atol
if rtol: kw["rtol"] = rtol
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
with np.errstate(invalid='ignore'):
# TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
# value errors. It should not do that.
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
def tolerance(dtype, tol=None):
tol = {} if tol is None else tol

View File

@ -849,8 +849,10 @@ class UpdateOps(enum.Enum):
UPDATE = 0
ADD = 1
MUL = 2
MIN = 3
MAX = 4
DIV = 3
POW = 4
MIN = 5
MAX = 6
def np_fn(op, indexer, x, y):
x = x.copy()
@ -858,6 +860,10 @@ class UpdateOps(enum.Enum):
UpdateOps.UPDATE: lambda: y,
UpdateOps.ADD: lambda: x[indexer] + y,
UpdateOps.MUL: lambda: x[indexer] * y,
UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)(
lambda: x[indexer] / y.astype(x.dtype)),
UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)(
lambda: x[indexer] ** y.astype(x.dtype)),
UpdateOps.MIN: lambda: np.minimum(x[indexer], y),
UpdateOps.MAX: lambda: np.maximum(x[indexer], y),
}[op]()
@ -880,12 +886,21 @@ class UpdateOps(enum.Enum):
return {
UpdateOps.UPDATE: x.at[indexer].set,
UpdateOps.ADD: x.at[indexer].add,
UpdateOps.MUL: x.at[indexer].mul,
UpdateOps.MUL: x.at[indexer].multiply,
UpdateOps.DIV: x.at[indexer].divide,
UpdateOps.POW: x.at[indexer].power,
UpdateOps.MIN: x.at[indexer].min,
UpdateOps.MAX: x.at[indexer].max,
}[op](y, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def dtypes(op):
  """Return the array dtypes that the indexed-update tests sample for ``op``.

  ``UPDATE`` uses ``all_dtypes``, ``DIV`` and ``POW`` use only
  ``jtu.dtypes.inexact``, and every other op uses ``default_dtypes``.
  """
  if op == UpdateOps.UPDATE:
    return all_dtypes
  needs_inexact = op == UpdateOps.DIV or op == UpdateOps.POW
  return jtu.dtypes.inexact if needs_inexact else default_dtypes
class IndexedUpdateTest(jtu.JaxTestCase):
@ -899,10 +914,10 @@ class IndexedUpdateTest(jtu.JaxTestCase):
} for name, index_specs in s(STATIC_INDEXING_TESTS)
for shape, indexer in s(index_specs)
for op in s(UpdateOps)
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
for sugared in s([True, False]))))
for sugared in (s([True, False]) if op not in [UpdateOps.DIV, UpdateOps.POW] else [True]))))
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, sugared, op):
rng = jtu.rand_default(self.rng())
@ -912,87 +927,78 @@ class IndexedUpdateTest(jtu.JaxTestCase):
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
else:
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
tol={np.complex128: 1e-14})
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), sugared, op.name),
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op, "sugared": sugared
"op": op
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS)
for shape, indexer in s(index_specs)
for op in s(UpdateOps)
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
for sugared in s([True, False]))))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, sugared, op):
indexer, op):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
if sugared:
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, unique_indices=True)
else:
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, unique_indices=True)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, unique_indices=True)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
tol={np.complex128: 1e-14})
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), sugared, op.name),
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op, "sugared": sugared
"op": op
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED)
for shape, indexer in s(index_specs)
for op in s(UpdateOps)
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
for sugared in s([True, False]))))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
indexer, sugared, op):
indexer, op):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
if sugared:
jax_fn = lambda x, y: UpdateOps.sugar_fn(
op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
else:
jax_fn = lambda x, y: UpdateOps.jax_fn(
op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True)
jax_fn = lambda x, y: UpdateOps.sugar_fn(
op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True,
tol={np.complex128: 1e-14})
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}_sugared={}".format(
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name, sugared),
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op, "sugared": sugared
"op": op
} for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS)
for shape, indexer in s(index_specs)
for op in s(UpdateOps)
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
for sugared in s([True, False]))))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, sugared, op):
indexer, op):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
if sugared:
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
else:
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
tol={np.complex128: 1e-14})
self._CompileAndCheck(jax_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list({
@ -1012,7 +1018,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
x = rng(shape, dtype)
y = rng(update_shape, update_dtype)
check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)