mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Added scatter_sub_p
The new primitive is used for in-place subtract-and-update, i.e. ``x.at[idx].subtract(y)``.
Closes #23933
PiperOrigin-RevId: 681754037
This commit is contained in:
parent
b768b659e3
commit
4cf33c0239
@ -429,6 +429,67 @@ def scatter_add(
|
||||
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
|
||||
mode=GatherScatterMode.from_any(mode))
|
||||
|
||||
|
||||
def scatter_sub(
|
||||
operand: ArrayLike,
|
||||
scatter_indices: ArrayLike,
|
||||
updates: ArrayLike,
|
||||
dimension_numbers: ScatterDimensionNumbers,
|
||||
*,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False,
|
||||
mode: str | GatherScatterMode | None = None,
|
||||
) -> Array:
|
||||
"""Scatter-sub operator.
|
||||
|
||||
Wraps `XLA's Scatter operator
|
||||
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
|
||||
subtraction is used to combine updates and values from `operand`.
|
||||
|
||||
The semantics of scatter are complicated, and its API might change in the
|
||||
future. For most use cases, you should prefer the
|
||||
:attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
|
||||
the familiar NumPy indexing syntax.
|
||||
|
||||
Args:
|
||||
operand: an array to which the scatter should be applied
|
||||
scatter_indices: an array that gives the indices in `operand` to which each
|
||||
update in `updates` should be applied.
|
||||
updates: the updates that should be scattered onto `operand`.
|
||||
dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how
|
||||
dimensions of `operand`, `scatter_indices`, `updates` and the output relate.
|
||||
indices_are_sorted: whether `scatter_indices` is known to be sorted. If
|
||||
true, may improve performance on some backends.
|
||||
unique_indices: whether the elements to be updated in ``operand`` are
|
||||
guaranteed to not overlap with each other. If true, may improve
|
||||
performance on some backends. JAX does not check this promise: if the
|
||||
updated elements overlap when ``unique_indices`` is ``True`` the behavior
|
||||
is undefined.
|
||||
mode: how to handle indices that are out of bounds: when set to 'clip',
|
||||
indices are clamped so that the slice is within bounds, and when set to
|
||||
'fill' or 'drop' out-of-bounds updates are dropped. The behavior for
|
||||
out-of-bounds indices when set to 'promise_in_bounds' is
|
||||
implementation-defined.
|
||||
|
||||
Returns:
|
||||
An array containing the result of subtracting the scattered updates from `operand`.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(
|
||||
lax.sub, lax._abstractify(lax._const(operand, 0))
|
||||
)
|
||||
return scatter_sub_p.bind(
|
||||
operand,
|
||||
scatter_indices,
|
||||
updates,
|
||||
update_jaxpr=jaxpr,
|
||||
update_consts=consts,
|
||||
dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices,
|
||||
mode=GatherScatterMode.from_any(mode),
|
||||
)
|
||||
|
||||
|
||||
def scatter_mul(
|
||||
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
|
||||
dimension_numbers: ScatterDimensionNumbers, *,
|
||||
@ -1991,32 +2052,66 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
|
||||
return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),
|
||||
upper_bound)
|
||||
|
||||
def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
|
||||
dimension_numbers, indices_are_sorted, unique_indices,
|
||||
mode):
|
||||
|
||||
def _scatter_addsub_jvp(
|
||||
prim,
|
||||
primals,
|
||||
tangents,
|
||||
*,
|
||||
update_jaxpr,
|
||||
update_consts,
|
||||
dimension_numbers,
|
||||
indices_are_sorted,
|
||||
unique_indices,
|
||||
mode,
|
||||
):
|
||||
operand, indices, updates = primals
|
||||
g_operand, g_indices, g_updates = tangents
|
||||
del g_indices # ignored
|
||||
val_out = scatter_add_p.bind(
|
||||
operand, indices, updates, update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts, dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
|
||||
mode=mode)
|
||||
val_out = prim.bind(
|
||||
operand,
|
||||
indices,
|
||||
updates,
|
||||
update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts,
|
||||
dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices,
|
||||
mode=mode,
|
||||
)
|
||||
if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
|
||||
tangent_out = ad_util.Zero.from_primal_value(val_out)
|
||||
else:
|
||||
g_operand = ad.instantiate_zeros(g_operand)
|
||||
g_updates = ad.instantiate_zeros(g_updates)
|
||||
tangent_out = scatter_add_p.bind(
|
||||
g_operand, indices, g_updates, update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts, dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
|
||||
mode=mode)
|
||||
tangent_out = prim.bind(
|
||||
g_operand,
|
||||
indices,
|
||||
g_updates,
|
||||
update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts,
|
||||
dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices,
|
||||
mode=mode,
|
||||
)
|
||||
return val_out, tangent_out
|
||||
|
||||
def _scatter_add_transpose_rule(t, operand, indices, updates, *,
|
||||
update_jaxpr, update_consts, dimension_numbers,
|
||||
indices_are_sorted, unique_indices, mode):
|
||||
|
||||
def _scatter_addsub_transpose_rule(
|
||||
prim,
|
||||
t,
|
||||
operand,
|
||||
indices,
|
||||
updates,
|
||||
*,
|
||||
update_jaxpr,
|
||||
update_consts,
|
||||
dimension_numbers,
|
||||
indices_are_sorted,
|
||||
unique_indices,
|
||||
mode,
|
||||
):
|
||||
assert not ad.is_undefined_primal(indices)
|
||||
if ad.is_undefined_primal(updates):
|
||||
updates_shape = updates.aval.shape
|
||||
@ -2045,6 +2140,8 @@ def _scatter_add_transpose_rule(t, operand, indices, updates, *,
|
||||
pos += 1
|
||||
update_t = gather(t, indices, dimension_numbers=gather_dnums,
|
||||
slice_sizes=slice_sizes, mode=mode, fill_value=0)
|
||||
if prim is scatter_sub_p:
|
||||
update_t = lax.neg(update_t)
|
||||
return [operand_t, None, update_t]
|
||||
|
||||
def _scatter_mul_transpose_rule(t, operand, indices, updates, *,
|
||||
@ -2140,11 +2237,23 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
|
||||
scatter_add_p = standard_primitive(
|
||||
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
|
||||
weak_type_rule=_argnum_weak_type(0))
|
||||
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
|
||||
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
|
||||
ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p)
|
||||
ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p)
|
||||
batching.primitive_batchers[scatter_add_p] = (
|
||||
partial(_scatter_batching_rule, scatter_add_p))
|
||||
|
||||
scatter_sub_p = standard_primitive(
|
||||
_scatter_shape_rule,
|
||||
_scatter_dtype_rule,
|
||||
"scatter-sub",
|
||||
weak_type_rule=_argnum_weak_type(0),
|
||||
)
|
||||
ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p)
|
||||
ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p)
|
||||
batching.primitive_batchers[scatter_sub_p] = partial(
|
||||
_scatter_batching_rule, scatter_sub_p
|
||||
)
|
||||
|
||||
scatter_mul_p = standard_primitive(
|
||||
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
|
||||
weak_type_rule=_argnum_weak_type(0))
|
||||
@ -2513,6 +2622,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
|
||||
mlir.register_lowering(scatter_p, _scatter_lower)
|
||||
mlir.register_lowering(scatter_add_p, _scatter_lower)
|
||||
mlir.register_lowering(scatter_sub_p, _scatter_lower)
|
||||
mlir.register_lowering(scatter_mul_p, _scatter_lower)
|
||||
mlir.register_lowering(scatter_min_p, _scatter_lower)
|
||||
mlir.register_lowering(scatter_max_p, _scatter_lower)
|
||||
@ -2520,9 +2630,21 @@ mlir.register_lowering(scatter_max_p, _scatter_lower)
|
||||
|
||||
def _real_dtype(dtype): return np.finfo(dtype).dtype
|
||||
|
||||
def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
*, update_jaxpr, update_consts, dimension_numbers,
|
||||
indices_are_sorted, unique_indices, mode):
|
||||
|
||||
def _scatter_addsub_lower_gpu(
|
||||
ctx,
|
||||
operand,
|
||||
indices,
|
||||
updates,
|
||||
*,
|
||||
update_jaxpr,
|
||||
update_consts,
|
||||
dimension_numbers,
|
||||
indices_are_sorted,
|
||||
unique_indices,
|
||||
mode,
|
||||
reduce_op,
|
||||
):
|
||||
operand_aval_in, _, updates_aval_in = ctx.avals_in
|
||||
if operand_aval_in.dtype != np.complex128:
|
||||
return _scatter_lower(ctx, operand, indices, updates,
|
||||
@ -2566,15 +2688,24 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
|
||||
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer):
|
||||
add = hlo.AddOp(*reducer.arguments).result
|
||||
hlo.return_([add])
|
||||
hlo.return_([reduce_op(*reducer.arguments).result])
|
||||
return scatter.result
|
||||
|
||||
real = _scatter(hlo.real(operand), hlo.real(updates))
|
||||
imag = _scatter(hlo.imag(operand), hlo.imag(updates))
|
||||
return [hlo.complex(real, imag)]
|
||||
|
||||
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
|
||||
|
||||
mlir.register_lowering(
|
||||
scatter_add_p,
|
||||
partial(_scatter_addsub_lower_gpu, reduce_op=hlo.AddOp),
|
||||
platform="gpu",
|
||||
)
|
||||
mlir.register_lowering(
|
||||
scatter_sub_p,
|
||||
partial(_scatter_addsub_lower_gpu, reduce_op=hlo.SubtractOp),
|
||||
platform="gpu",
|
||||
)
|
||||
|
||||
|
||||
def _dynamic_slice_indices(
|
||||
|
@ -661,6 +661,7 @@ class _IndexUpdateHelper:
|
||||
============================== ================================
|
||||
``x = x.at[idx].set(y)`` ``x[idx] = y``
|
||||
``x = x.at[idx].add(y)`` ``x[idx] += y``
|
||||
``x = x.at[idx].subtract(y)`` ``x[idx] -= y``
|
||||
``x = x.at[idx].multiply(y)`` ``x[idx] *= y``
|
||||
``x = x.at[idx].divide(y)`` ``x[idx] /= y``
|
||||
``x = x.at[idx].power(y)`` ``x[idx] **= y``
|
||||
@ -826,6 +827,20 @@ class _IndexUpdateRef:
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
def subtract(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] -= y``.
|
||||
|
||||
Returns the value of ``x`` that would result from the NumPy-style
|
||||
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] -= y``.
|
||||
|
||||
See :mod:`jax.ops` for details.
|
||||
"""
|
||||
return scatter._scatter_update(self.array, self.index, values,
|
||||
lax.scatter_sub,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
|
||||
def multiply(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None):
|
||||
"""Pure equivalent of ``x[idx] *= y``.
|
||||
|
@ -3043,6 +3043,7 @@ tf_impl_with_avals[lax.scatter_min_p] = _scatter
|
||||
tf_impl_with_avals[lax.scatter_max_p] = _scatter
|
||||
tf_impl_with_avals[lax.scatter_mul_p] = _scatter
|
||||
tf_impl_with_avals[lax.scatter_add_p] = _scatter
|
||||
tf_impl_with_avals[lax.scatter_sub_p] = _scatter
|
||||
|
||||
|
||||
def _cond(
|
||||
|
@ -280,6 +280,8 @@ from jax._src.lax.slicing import (
|
||||
scatter_mul as scatter_mul,
|
||||
scatter_mul_p as scatter_mul_p,
|
||||
scatter_p as scatter_p,
|
||||
scatter_sub as scatter_sub,
|
||||
scatter_sub_p as scatter_sub_p,
|
||||
slice as slice,
|
||||
slice_in_dim as slice_in_dim,
|
||||
slice_p as slice_p,
|
||||
|
@ -1252,7 +1252,7 @@ def _can_cast(from_, to):
|
||||
|
||||
|
||||
def _compatible_dtypes(op, dtype, inexact=False):
|
||||
if op == UpdateOps.ADD:
|
||||
if op == UpdateOps.ADD or op == UpdateOps.SUB:
|
||||
return [dtype]
|
||||
elif inexact:
|
||||
return [dt for dt in float_dtypes if _can_cast(dt, dtype)]
|
||||
@ -1263,17 +1263,19 @@ def _compatible_dtypes(op, dtype, inexact=False):
|
||||
class UpdateOps(enum.Enum):
|
||||
UPDATE = 0
|
||||
ADD = 1
|
||||
MUL = 2
|
||||
DIV = 3
|
||||
POW = 4
|
||||
MIN = 5
|
||||
MAX = 6
|
||||
SUB = 2
|
||||
MUL = 3
|
||||
DIV = 4
|
||||
POW = 5
|
||||
MIN = 6
|
||||
MAX = 7
|
||||
|
||||
def np_fn(op, indexer, x, y):
|
||||
x = x.copy()
|
||||
x[indexer] = {
|
||||
UpdateOps.UPDATE: lambda: y,
|
||||
UpdateOps.ADD: lambda: x[indexer] + y,
|
||||
UpdateOps.SUB: lambda: x[indexer] - y,
|
||||
UpdateOps.MUL: lambda: x[indexer] * y,
|
||||
UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)(
|
||||
lambda: x[indexer] / y.astype(x.dtype)),
|
||||
@ -1290,6 +1292,7 @@ class UpdateOps(enum.Enum):
|
||||
return {
|
||||
UpdateOps.UPDATE: x.at[indexer].set,
|
||||
UpdateOps.ADD: x.at[indexer].add,
|
||||
UpdateOps.SUB: x.at[indexer].subtract,
|
||||
UpdateOps.MUL: x.at[indexer].multiply,
|
||||
UpdateOps.DIV: x.at[indexer].divide,
|
||||
UpdateOps.POW: x.at[indexer].power,
|
||||
@ -1420,7 +1423,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
for update_shape in _broadcastable_shapes(index_shape)
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
|
||||
for op in [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE]
|
||||
for dtype in float_dtypes
|
||||
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
|
||||
],
|
||||
@ -1447,8 +1450,9 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
],
|
||||
[dict(op=op, dtype=dtype, update_dtype=update_dtype)
|
||||
for op in (
|
||||
[UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices
|
||||
else [UpdateOps.ADD])
|
||||
[UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE]
|
||||
if unique_indices
|
||||
else [UpdateOps.ADD, UpdateOps.SUB])
|
||||
for dtype in float_dtypes
|
||||
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)
|
||||
],
|
||||
|
@ -2782,14 +2782,15 @@ class LaxTest(jtu.JaxTestCase):
|
||||
]],
|
||||
dtype=lax_test_util.inexact_dtypes,
|
||||
mode=["clip", "fill", None],
|
||||
op=[lax.scatter_add, lax.scatter_sub],
|
||||
)
|
||||
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode):
|
||||
def testScatterAddSub(self, arg_shape, dtype, idxs, update_shape, dnums, mode, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
|
||||
rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
|
||||
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
|
||||
rng(update_shape, dtype)]
|
||||
fun = partial(lax.scatter_add, dimension_numbers=dnums, mode=mode)
|
||||
fun = partial(op, dimension_numbers=dnums, mode=mode)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
|
Loading…
x
Reference in New Issue
Block a user