Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Implement jax.ops.index_mul. (#2696)
* Implement jax.ops.index_mul.
* Add index_mul to documentation.
* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.
* Fix typo in docstring.
parent 04de8339a8
commit 714b276b9a
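For orientation, a minimal usage sketch of the API this commit adds (assuming a JAX checkout contemporary with this change; the `jax.ops.index_*` helpers were removed from much later JAX releases in favor of `x.at[idx]`-style updates):

import jax.numpy as jnp
from jax import ops

x = jnp.ones((3, 4))
# Pure functional update: x itself is unchanged; a new array is returned
# in which rows 1 and 2 have been multiplied by 2.
y = ops.index_mul(x, ops.index[1:3, :], 2.)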
@@ -20,6 +20,7 @@ pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
 
   index
   index_update
   index_add
+  index_mul
   index_min
   index_max
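The `index` helper in this list is plain syntactic sugar for building index tuples, so it composes directly with the new `index_mul`. A small sketch (behavior inferred from the `_Indexable` class further down in this diff):

from jax import ops

idx = ops.index[2:4, 3:]
# idx == (slice(2, 4, None), slice(3, None, None)), i.e. exactly the tuple
# that x[2:4, 3:] would pass to __getitem__.
print(idx)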
@@ -858,6 +858,33 @@ def scatter_add(operand: Array, scatter_indices: Array, updates: Array,
       operand, scatter_indices, updates, update_jaxpr=jaxpr,
       update_consts=consts, dimension_numbers=dimension_numbers)
 
+def scatter_mul(operand: Array, scatter_indices: Array, updates: Array,
+                dimension_numbers: ScatterDimensionNumbers) -> Array:
+  """Scatter-multiply operator.
+
+  Wraps `XLA's Scatter operator
+  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
+  multiplication is used to combine updates and values from `operand`.
+
+  The semantics of scatter are complicated and its API is subject to change.
+
+  Args:
+    operand: an array to which the scatter should be applied
+    scatter_indices: an array that gives the indices in `operand` to which each
+      update in `updates` should be applied.
+    updates: the updates that should be scattered onto `operand`.
+    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
+      how dimensions of `operand`, `start_indices`, `updates` and the output
+      relate.
+
+  Returns:
+    An array containing the product of `operand` and the scattered updates.
+  """
+  jaxpr, consts = _reduction_jaxpr(mul, _abstractify(_const(operand, 1)))
+  return scatter_mul_p.bind(
+      operand, scatter_indices, updates, update_jaxpr=jaxpr,
+      update_consts=consts, dimension_numbers=dimension_numbers)
+
 def scatter_min(operand: Array, scatter_indices: Array, updates: Array,
                 dimension_numbers: ScatterDimensionNumbers) -> Array:
   """Scatter-min operator.
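`jax.ops.index_mul` ultimately lowers to this `lax.scatter_mul`. A sketch of calling it directly with hand-built dimension numbers for a 1-D operand; the expected values in the comments follow from the scatter semantics and are not output copied from the commit:

import jax.numpy as jnp
from jax import lax

operand = jnp.array([1., 2., 3., 4.])
indices = jnp.array([[1], [3]])          # one index vector per update
updates = jnp.array([10., 100.])
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),               # each update is a scalar
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))
out = lax.scatter_mul(operand, indices, updates, dnums)
# expected: [1., 20., 3., 400.]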
@@ -3458,6 +3485,39 @@ def _scatter_add_transpose_rule(t, operand, scatter_indices, updates, *,
       slice_sizes=slice_sizes)
   return [operand_t, None, update_t]
 
+def _scatter_mul_transpose_rule(t, operand, scatter_indices, updates, *,
+                                update_jaxpr, update_consts, dimension_numbers):
+  assert not ad.is_undefined_primal(scatter_indices)
+  if ad.is_undefined_primal(updates):
+    updates_shape = updates.aval.shape
+  else:
+    updates_shape = updates.shape
+  if t is ad_util.zero:
+    return [ad_util.zero, None, ad_util.zero]
+
+  operand_t = update_t = None
+  if ad.is_undefined_primal(operand):
+    operand_t = scatter_mul(t, scatter_indices, updates,
+                            dimension_numbers=dimension_numbers)
+
+  if ad.is_undefined_primal(updates):
+    gather_dnums = GatherDimensionNumbers(
+      offset_dims=dimension_numbers.update_window_dims,
+      collapsed_slice_dims=dimension_numbers.inserted_window_dims,
+      start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
+    slice_sizes = []
+    pos = 0
+    for i in range(len(t.shape)):
+      if i in dimension_numbers.inserted_window_dims:
+        slice_sizes.append(1)
+      else:
+        slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]])
+        pos += 1
+    update_t = gather(mul(t, operand), scatter_indices,
+                      dimension_numbers=gather_dnums, slice_sizes=slice_sizes)
+  return [operand_t, None, update_t]
+
+
 def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                            update_jaxpr, update_consts, dimension_numbers):
   operand, scatter_indices, updates = batched_args
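This transpose rule is what reverse-mode AD uses: the cotangent with respect to `operand` is a scatter-multiply of `t` by the same updates (the map is diagonal, so it is its own transpose), and the cotangent with respect to `updates` gathers `t * operand` back out at the scattered positions. A small numeric sanity check, a sketch rather than part of the commit:

import jax
import jax.numpy as jnp
from jax import ops

def f(x, y):
  return ops.index_mul(x, jnp.array([0, 2]), y).sum()

x = jnp.array([2., 3., 4.])
y = jnp.array([5., 7.])
print(jax.grad(f, 0)(x, y))  # expected [5., 1., 7.]: updated slots scaled by y
print(jax.grad(f, 1)(x, y))  # expected [2., 4.]: operand values at those slots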
@@ -3512,6 +3572,23 @@ ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
 batching.primitive_batchers[scatter_add_p] = (
   partial(_scatter_batching_rule, scatter_add))
 
+
+scatter_mul_p = standard_primitive(
+    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
+    _scatter_translation_rule)
+
+def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, **kw):
+  return mul(x, scatter_add(zeros_like_array(x), i, g,
+                            dimension_numbers=dimension_numbers))
+
+ad.defjvp(scatter_mul_p,
+          lambda g, x, i, y, **kw: scatter_mul_p.bind(g, i, y, **kw),
+          None,
+          _scatter_mul_jvp_rhs)
+ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
+batching.primitive_batchers[scatter_mul_p] = (
+  partial(_scatter_batching_rule, scatter_mul))
+
 # TODO(jlebar): Add derivatives.
 scatter_min_p = standard_primitive(
   _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
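With these registrations `scatter_mul` becomes differentiable in both array operands: the lhs JVP simply scatter-multiplies the tangent (the operation is linear in `operand` for fixed `updates`), while the rhs JVP from `_scatter_mul_jvp_rhs` scatters the tangent additively and multiplies by `x`, so it is non-zero only at the updated slots. A sketch of the resulting behavior:

import jax
import jax.numpy as jnp
from jax import ops

f = lambda y: ops.index_mul(jnp.array([2., 3.]), 0, y)
primals, tangents = jax.jvp(f, (jnp.array(5.),), (jnp.array(1.),))
# expected: primals ~ [10., 3.], tangents ~ [2., 0.]  (d(x*y)/dy = x at slot 0)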
@@ -13,4 +13,6 @@
 # limitations under the License.
 
 
-from .scatter import index, index_add, index_update, index_min, index_max, segment_sum
+from .scatter import (
+  index, index_add, index_mul, index_update, index_min, index_max, segment_sum
+)
@@ -41,7 +41,6 @@ def _scatter_update(x, idx, y, scatter_op):
 
   x = np.asarray(x)
   y = np.asarray(y)
-
   # XLA gathers and scatters are very similar in structure; the scatter logic
   # is more or less a transpose of the gather equivalent.
   treedef, static_idx, dynamic_idx = np._split_index_for_jit(idx)
@@ -52,7 +51,8 @@ def _scatter_update(x, idx, y, scatter_op):
   # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
 # @partial(jit, static_argnums=(2, 3, 4))
 def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
-  y = lax.convert_element_type(y, lax.dtype(x))
+  dtype = lax.dtype(x)
+  x, y = np._promote_dtypes(x, y)
 
   idx = np._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
   indexer = np._index_to_gather(np.shape(x), idx)
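The point of this change is to stop casting the update to the operand's dtype before the arithmetic. A sketch of the difference, using plain NumPy for the dtype behavior (hypothetical values, not from the commit):

import numpy as onp

# Old behavior: the update was cast to the operand's dtype before combining,
# so a fractional update against an integer operand was truncated up front:
print(onp.asarray(2.6).astype(onp.int32))   # -> 2
# New behavior: operand and update are promoted to a common dtype first
# (e.g. integer operand + float update gives a floating-point computation),
# and only the final scatter result is cast back (see the next hunk).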
@@ -71,7 +71,8 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
     inserted_window_dims=indexer.dnums.collapsed_slice_dims,
     scatter_dims_to_operand_dims=indexer.dnums.start_index_map
   )
-  return scatter_op(x, indexer.gather_indices, y, dnums)
+  out = scatter_op(x, indexer.gather_indices, y, dnums)
+  return lax.convert_element_type(out, dtype)
 
 
 class _Indexable(object):
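Together with the previous hunk, the net effect is that `_scatter_impl` computes in the promoted dtype but always returns an array with the operand's original dtype. A sketch (assuming default 32-bit mode):

import jax.numpy as jnp
from jax import ops

x = jnp.arange(4)                          # int32 operand
out = ops.index_add(x, ops.index[0], 2.5)  # floating-point update
# The addition happens in the promoted float dtype; the result is then cast
# back, so out.dtype matches x.dtype.
print(out.dtype)                           # int32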
@@ -130,6 +131,46 @@ def index_add(x, idx, y):
   """
   return _scatter_update(x, idx, y, lax.scatter_add)
 
+
+def index_mul(x, idx, y):
+  """Pure equivalent of :code:`x[idx] *= y`.
+
+  Returns the value of `x` that would result from the
+  NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
+    x[idx] *= y
+
+  Note that the `index_mul` operator is pure; `x` itself is
+  not modified; instead, the new value that `x` would have taken is returned.
+
+  Unlike the NumPy code :code:`x[idx] *= y`, if multiple indices refer to the
+  same location the updates will be multiplied. (NumPy would only apply the
+  last update, rather than multiplying the updates.) The order in which
+  conflicting updates are applied is implementation-defined and may be
+  nondeterministic (e.g., due to concurrency on some hardware platforms).
+
+  Args:
+    x: an array with the values to be updated.
+    idx: a NumPy-style index, consisting of `None`, integers, `slice` objects,
+      ellipses, ndarrays with integer dtypes, or a tuple of the above. A
+      convenient syntactic sugar for forming indices is via the
+      :data:`jax.ops.index` object.
+    y: the array of updates. `y` must be broadcastable to the shape of the
+      array that would be returned by `x[idx]`.
+
+  Returns:
+    An array.
+
+  >>> x = jax.numpy.ones((5, 6))
+  >>> jax.ops.index_mul(x, jax.ops.index[2:4, 3:], 6.)
+  array([[1., 1., 1., 1., 1., 1.],
+         [1., 1., 1., 1., 1., 1.],
+         [1., 1., 1., 6., 6., 6.],
+         [1., 1., 1., 6., 6., 6.],
+         [1., 1., 1., 1., 1., 1.]], dtype=float32)
+  """
+  return _scatter_update(x, idx, y, lax.scatter_mul)
+
+
 def index_min(x, idx, y):
   """Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.
 
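The duplicate-index behavior the docstring describes can be seen directly; a sketch whose expected values follow from the semantics above:

import jax.numpy as jnp
from jax import ops

x = jnp.ones(3)
idx = jnp.array([0, 0, 2])
y = jnp.array([2., 3., 5.])
print(ops.index_mul(x, idx, y))  # expected [6., 1., 5.]: both updates to slot 0 multiply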
@@ -816,8 +816,9 @@ def _update_shape(shape, indexer):
 class UpdateOps(enum.Enum):
   UPDATE = 0
   ADD = 1
-  MIN = 2
-  MAX = 3
+  MUL = 2
+  MIN = 3
+  MAX = 4
 
   @suppress_deprecated_indexing_warnings()
   def onp_fn(op, indexer, x, y):
@@ -825,6 +826,7 @@ class UpdateOps(enum.Enum):
     x[indexer] = {
       UpdateOps.UPDATE: lambda: y,
       UpdateOps.ADD: lambda: x[indexer] + y,
+      UpdateOps.MUL: lambda: x[indexer] * y,
       UpdateOps.MIN: lambda: onp.minimum(x[indexer], y),
       UpdateOps.MAX: lambda: onp.maximum(x[indexer], y),
     }[op]()
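Note that the NumPy reference above uses buffered fancy-index assignment, which keeps only the last of any duplicate updates and therefore agrees with `index_mul` only for duplicate-free indexers; NumPy's unbuffered `ufunc.at` is the variant that matches the multiply-all semantics. A sketch for contrast, not part of the test:

import numpy as onp

x = onp.ones(3)
idx = onp.array([0, 0, 2])
y = onp.array([2., 3., 5.])
x[idx] *= y
print(x)                       # [3., 1., 5.]: last duplicate wins

x2 = onp.ones(3)
onp.multiply.at(x2, idx, y)
print(x2)                      # [6., 1., 5.]: duplicates multiply, like index_mul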
@@ -834,6 +836,7 @@ class UpdateOps(enum.Enum):
     return {
       UpdateOps.UPDATE: ops.index_update,
      UpdateOps.ADD: ops.index_add,
+      UpdateOps.MUL: ops.index_mul,
       UpdateOps.MIN: ops.index_min,
       UpdateOps.MAX: ops.index_max,
     }[op](x, indexer, y)
@@ -919,7 +922,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
       "op": op
   } for name, index_specs in STATIC_INDEXING_TESTS
     for shape, indexer in index_specs
-    for op in UpdateOps
+    for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
     for dtype in float_dtypes
     for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
     for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)
@@ -928,8 +931,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
   def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                               rng_factory, indexer, op):
     rng = rng_factory()
-    jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add
-    jax_fn = lambda x, y: jax_op(x, indexer, y)
+    jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
     x = rng(shape, dtype)
     y = rng(update_shape, update_dtype)
     check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)
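A standalone reduction of the gradient test above, as a sketch; it uses `check_grads` from `jax.test_util` with the same (f, args, order, ...) calling convention as the test:

import jax.numpy as jnp
from jax import ops
from jax.test_util import check_grads

f = lambda x, y: ops.index_mul(x, ops.index[1:3], y)
check_grads(f, (jnp.ones(4), 2. * jnp.ones(2)), 2,
            rtol=1e-3, atol=1e-3, eps=1.)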