mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Promote the x.at[idx].set(y) operators as the preferred way to do indexed updates.
Mark the index_update() etc. operators as deprecated in the documentation. Add new .divide and .power operators. Fixes #2694. Add .multiply as an alias for .mul. To be more numpy-like we should probably prefer the longer names.
This commit is contained in:
parent
1509f995ed
commit
d005e38f78
@ -6,13 +6,39 @@ jax.ops package
|
||||
|
||||
.. automodule:: jax.ops
|
||||
|
||||
.. _syntactic-sugar-for-ops:
|
||||
|
||||
Indexed update operators
|
||||
------------------------
|
||||
|
||||
JAX is intended to be used with a functional style of programming, and hence
|
||||
JAX is intended to be used with a functional style of programming, and
|
||||
does not support NumPy-style indexed assignment directly. Instead, JAX provides
|
||||
pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
|
||||
alternative pure functional operators for indexed updates to arrays.
|
||||
|
||||
JAX array types have a property ``at``, which can be used as
|
||||
follows (where ``idx`` is a NumPy index expression).
|
||||
|
||||
========================= ===================================================
|
||||
Alternate syntax Equivalent in-place expression
|
||||
========================= ===================================================
|
||||
``x.at[idx].set(y)`` ``x[idx] = y``
|
||||
``x.at[idx].add(y)`` ``x[idx] += y``
|
||||
``x.at[idx].multiply(y)`` ``x[idx] *= y``
|
||||
``x.at[idx].divide(y)`` ``x[idx] /= y``
|
||||
``x.at[idx].power(y)`` ``x[idx] **= y``
|
||||
``x.at[idx].min(y)`` ``x[idx] = np.minimum(x[idx], y)``
|
||||
``x.at[idx].max(y)`` ``x[idx] = np.maximum(x[idx], y)``
|
||||
========================= ===================================================
|
||||
|
||||
None of these expressions modify the original `x`; instead they return
|
||||
a modified copy of `x`.
|
||||
|
||||
|
||||
Indexed update functions (deprecated)
|
||||
-------------------------------------
|
||||
|
||||
The following functions are aliases for the ``x.at[idx].set(y)``
|
||||
style operators. Prefer to use the ``x.at[idx]`` operators instead.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
@ -24,27 +50,7 @@ pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
|
||||
index_min
|
||||
index_max
|
||||
|
||||
.. _syntactic-sugar-for-ops:
|
||||
|
||||
Syntactic sugar for indexed update operators
|
||||
--------------------------------------------
|
||||
|
||||
JAX also provides an alternate syntax for these indexed update operators.
|
||||
Specifically, JAX ndarray types have a property ``at``, which can be used as
|
||||
follows (where ``idx`` can be an arbitrary index expression).
|
||||
|
||||
==================== ===================================================
|
||||
Alternate syntax Equivalent expression
|
||||
==================== ===================================================
|
||||
``x.at[idx].set(y)`` ``jax.ops.index_update(x, jax.ops.index[idx], y)``
|
||||
``x.at[idx].add(y)`` ``jax.ops.index_add(x, jax.ops.index[idx], y)``
|
||||
``x.at[idx].mul(y)`` ``jax.ops.index_mul(x, jax.ops.index[idx], y)``
|
||||
``x.at[idx].min(y)`` ``jax.ops.index_min(x, jax.ops.index[idx], y)``
|
||||
``x.at[idx].max(y)`` ``jax.ops.index_max(x, jax.ops.index[idx], y)``
|
||||
==================== ===================================================
|
||||
|
||||
Note that none of these expressions modify the original `x`; instead they return
|
||||
a modified copy of `x`.
|
||||
|
||||
Other operators
|
||||
---------------
|
||||
|
@ -49,6 +49,7 @@ from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _device_put_raw
|
||||
from jax import ops
|
||||
from jax._src.ops import scatter
|
||||
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip,
|
||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
|
||||
from jax.tree_util import tree_leaves, tree_flatten, tree_map
|
||||
@ -5842,14 +5843,17 @@ class _IndexUpdateHelper:
|
||||
|
||||
The ``at`` property is syntactic sugar for calling the indexed update functions
|
||||
defined in :mod:`jax.ops`, and acts as a pure equivalent of in-place
|
||||
modificatons. For further information, see `Syntactic Sugar for Index Update Operators
|
||||
<https://jax.readthedocs.io/en/latest/jax.ops.html#syntactic-sugar-for-indexed-update-operators>`_.
|
||||
modificatons. For further information, see `Indexed Update Operators
|
||||
<https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-operators>`_.
|
||||
|
||||
In particular:
|
||||
|
||||
- ``x = x.at[idx].set(y)`` is a pure equivalent of ``x[idx] = y``.
|
||||
- ``x = x.at[idx].add(y)`` is a pure equivalent of ``x[idx] += y``.
|
||||
- ``x = x.at[idx].mul(y)`` is a pure equivalent of ``x[idx] *= y``.
|
||||
- ``x = x.at[idx].multiply(y)`` (aka ``mul``) is a pure equivalent of
|
||||
``x[idx] *= y``.
|
||||
- ``x = x.at[idx].divide(y)`` is a pure equivalent of ``x[idx] /= y``.
|
||||
- ``x = x.at[idx].power(y)`` is a pure equivalent of ``x[idx] **= y``.
|
||||
- ``x = x.at[idx].min(y)`` is a pure equivalent of
|
||||
``x[idx] = minimum(x[idx], y)``.
|
||||
- ``x = x.at[idx].max(y)`` is a pure equivalent of
|
||||
@ -5866,6 +5870,8 @@ class _IndexUpdateHelper:
|
||||
def __repr__(self):
|
||||
return f"_IndexUpdateHelper({repr(self.array)})"
|
||||
|
||||
_power = power
|
||||
_divide = divide
|
||||
|
||||
class _IndexUpdateRef:
|
||||
"""Helper object to call indexed update functions for an (advanced) index.
|
||||
@ -5886,74 +5892,100 @@ class _IndexUpdateRef:
|
||||
def set(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] = y``.
|
||||
|
||||
``x.at[idx].set(y)`` is syntactic sugar for
|
||||
``jax.ops.index_update(x, jax.ops.index[idx], y)``, and
|
||||
returns the value of ``x`` that would result from the NumPy-style
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return ops.index_update(self.array, self.index, values,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
|
||||
def add(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] += y``.
|
||||
|
||||
``x.at[idx].add(y)`` is syntactic sugar for
|
||||
``jax.ops.index_add(x, jax.ops.index[idx], y)``, and
|
||||
returns the value of ``x`` that would result from the NumPy-style
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] += y``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return ops.index_add(self.array, self.index, values,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
return scatter._scatter_update(self.array, self.index, values,
|
||||
lax.scatter_add,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
|
||||
def mul(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] += y``.
|
||||
def multiply(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] *= y``.
|
||||
|
||||
``x.at[idx].mul(y)`` is syntactic sugar for
|
||||
``jax.ops.index_mul(x, jax.ops.index[idx], y)``, and
|
||||
returns the value of ``x`` that would result from the NumPy-style
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] *= y``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return ops.index_mul(self.array, self.index, values,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
return scatter._scatter_update(self.array, self.index, values,
|
||||
lax.scatter_mul,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
mul = multiply
|
||||
|
||||
def divide(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] /= y``.
|
||||
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] /= y``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return _divide(
|
||||
self.array,
|
||||
scatter._scatter_update(ones_like(self.array), self.index, values,
|
||||
lax.scatter_mul,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices))
|
||||
|
||||
def power(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] **= y``.
|
||||
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] **= y``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return _power(
|
||||
self.array,
|
||||
scatter._scatter_update(ones_like(self.array), self.index, values,
|
||||
lax.scatter_mul,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices))
|
||||
|
||||
def min(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
|
||||
|
||||
``x.at[idx].min(y)`` is syntactic sugar for
|
||||
``jax.ops.index_min(x, jax.ops.index[idx], y)``, and
|
||||
returns the value of ``x`` that would result from the NumPy-style
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:indexed assignment <numpy.doc.indexing>`
|
||||
``x[idx] = minimum(x[idx], y)``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return ops.index_min(self.array, self.index, values,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
return scatter._scatter_update(self.array, self.index, values,
|
||||
lax.scatter_min,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
|
||||
def max(self, values, indices_are_sorted=False, unique_indices=False):
|
||||
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
|
||||
|
||||
``x.at[idx].max(y)`` is syntactic sugar for
|
||||
``jax.ops.index_max(x, jax.ops.index[idx], y)``, and
|
||||
returns the value of ``x`` that would result from the NumPy-style
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:indexed assignment <numpy.doc.indexing>`
|
||||
``x[idx] = maximum(x[idx], y)``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return ops.index_max(self.array, self.index, values,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
return scatter._scatter_update(self.array, self.index, values,
|
||||
lax.scatter_max,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
|
||||
|
||||
setattr(_DeviceArray, "at", property(_IndexUpdateHelper))
|
||||
setattr(_CppDeviceArray, "at", property(_IndexUpdateHelper))
|
||||
|
@ -138,7 +138,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
|
||||
kw = {}
|
||||
if atol: kw["atol"] = atol
|
||||
if rtol: kw["rtol"] = rtol
|
||||
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
|
||||
with np.errstate(invalid='ignore'):
|
||||
# TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
|
||||
# value errors. It should not do that.
|
||||
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
|
||||
|
||||
def tolerance(dtype, tol=None):
|
||||
tol = {} if tol is None else tol
|
||||
|
@ -849,8 +849,10 @@ class UpdateOps(enum.Enum):
|
||||
UPDATE = 0
|
||||
ADD = 1
|
||||
MUL = 2
|
||||
MIN = 3
|
||||
MAX = 4
|
||||
DIV = 3
|
||||
POW = 4
|
||||
MIN = 5
|
||||
MAX = 6
|
||||
|
||||
def np_fn(op, indexer, x, y):
|
||||
x = x.copy()
|
||||
@ -858,6 +860,10 @@ class UpdateOps(enum.Enum):
|
||||
UpdateOps.UPDATE: lambda: y,
|
||||
UpdateOps.ADD: lambda: x[indexer] + y,
|
||||
UpdateOps.MUL: lambda: x[indexer] * y,
|
||||
UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)(
|
||||
lambda: x[indexer] / y.astype(x.dtype)),
|
||||
UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)(
|
||||
lambda: x[indexer] ** y.astype(x.dtype)),
|
||||
UpdateOps.MIN: lambda: np.minimum(x[indexer], y),
|
||||
UpdateOps.MAX: lambda: np.maximum(x[indexer], y),
|
||||
}[op]()
|
||||
@ -880,12 +886,21 @@ class UpdateOps(enum.Enum):
|
||||
return {
|
||||
UpdateOps.UPDATE: x.at[indexer].set,
|
||||
UpdateOps.ADD: x.at[indexer].add,
|
||||
UpdateOps.MUL: x.at[indexer].mul,
|
||||
UpdateOps.MUL: x.at[indexer].multiply,
|
||||
UpdateOps.DIV: x.at[indexer].divide,
|
||||
UpdateOps.POW: x.at[indexer].power,
|
||||
UpdateOps.MIN: x.at[indexer].min,
|
||||
UpdateOps.MAX: x.at[indexer].max,
|
||||
}[op](y, indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
|
||||
def dtypes(op):
|
||||
if op == UpdateOps.UPDATE:
|
||||
return all_dtypes
|
||||
elif op == UpdateOps.DIV or op == UpdateOps.POW:
|
||||
return jtu.dtypes.inexact
|
||||
else:
|
||||
return default_dtypes
|
||||
|
||||
class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
|
||||
@ -899,10 +914,10 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
} for name, index_specs in s(STATIC_INDEXING_TESTS)
|
||||
for shape, indexer in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
|
||||
for sugared in s([True, False]))))
|
||||
for sugared in (s([True, False]) if op not in [UpdateOps.DIV, UpdateOps.POW] else [True]))))
|
||||
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, sugared, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -912,87 +927,78 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
|
||||
else:
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
|
||||
tol={np.complex128: 1e-14})
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), sugared, op.name),
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op, "sugared": sugared
|
||||
"op": op
|
||||
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS)
|
||||
for shape, indexer in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
|
||||
for sugared in s([True, False]))))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
|
||||
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, sugared, op):
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
|
||||
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
|
||||
if sugared:
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, unique_indices=True)
|
||||
else:
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, unique_indices=True)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, unique_indices=True)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
|
||||
tol={np.complex128: 1e-14})
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), sugared, op.name),
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op, "sugared": sugared
|
||||
"op": op
|
||||
} for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED)
|
||||
for shape, indexer in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
|
||||
for sugared in s([True, False]))))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
|
||||
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, sugared, op):
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
|
||||
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
|
||||
if sugared:
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(
|
||||
op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
|
||||
else:
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(
|
||||
op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True)
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(
|
||||
op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True,
|
||||
tol={np.complex128: 1e-14})
|
||||
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}_sugared={}".format(
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name, sugared),
|
||||
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
|
||||
"shape": shape, "dtype": dtype, "indexer": indexer,
|
||||
"update_shape": update_shape, "update_dtype": update_dtype,
|
||||
"op": op, "sugared": sugared
|
||||
"op": op
|
||||
} for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS)
|
||||
for shape, indexer in s(index_specs)
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
|
||||
for sugared in s([True, False]))))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
|
||||
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, sugared, op):
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
|
||||
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
|
||||
if sugared:
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
|
||||
else:
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
|
||||
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
|
||||
tol={np.complex128: 1e-14})
|
||||
self._CompileAndCheck(jax_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
@ -1012,7 +1018,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
|
||||
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
|
||||
x = rng(shape, dtype)
|
||||
y = rng(update_shape, update_dtype)
|
||||
check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)
|
||||
|
Loading…
x
Reference in New Issue
Block a user