Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Implement a scatter-update operator.

commit 148073edf7 (parent 1497c31590)
jax/lax.py | 167 lines changed
@@ -604,10 +604,11 @@ def gather(operand, start_indices, dimension_numbers, slice_sizes):
     slice_sizes=tuple(slice_sizes), operand_shape=operand.shape)
 
 def scatter_add(operand, scatter_indices, updates, dimension_numbers):
-  """Scatter operator.
+  """Scatter-add operator.
 
   Wraps `XLA's Scatter operator
-  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_.
+  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
+  addition is used to combine updates and values from `operand`.
 
   The semantics of scatter are complicated and its API is subject to change.
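As a quick orientation for this hunk, here is a minimal sketch of how the scatter-add wrapper is invoked, reusing the dimension-number pattern from the tests further down; the concrete values are illustrative assumptions, not part of the commit:

    import numpy as onp
    from jax import lax

    operand = onp.zeros(5, onp.float32)
    indices = onp.array([[1], [3], [3]], onp.int32)    # index 3 appears twice
    updates = onp.array([10., 20., 30.], onp.float32)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(), inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    # Colliding updates are summed: [0., 10., 0., 50., 0.]
    print(lax.scatter_add(operand, indices, updates, dimension_numbers=dnums))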
@@ -624,6 +625,36 @@ def scatter_add(operand, scatter_indices, updates, dimension_numbers):
     An array containing the sum of `operand` and the scattered updates.
   """
   jaxpr, consts = _reduction_jaxpr(add, _const(operand, 0))
   return scatter_add_p.bind(
       operand, scatter_indices, updates, update_jaxpr=jaxpr,
       update_consts=consts, dimension_numbers=dimension_numbers,
       updates_shape=updates.shape)
 
+def scatter(operand, scatter_indices, updates, dimension_numbers):
+  """Scatter-update operator.
+
+  Wraps `XLA's Scatter operator
+  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where updates
+  replace values from `operand`.
+
+  If multiple updates are performed to the same index of `operand`, they may be
+  applied in any order.
+
+  The semantics of scatter are complicated and its API is subject to change.
+
+  Args:
+    operand: an array to which the scatter should be applied
+    scatter_indices: an array that gives the indices in `operand` to which each
+      update in `updates` should be applied.
+    updates: the updates that should be scattered onto `operand`.
+    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
+      how dimensions of `operand`, `start_indices`, `updates` and the output
+      relate.
+
+  Returns:
+    An array containing `operand` with the scattered updates applied.
+  """
+  jaxpr, consts = _reduction_jaxpr(lambda x, y: y, _const(operand, 0))
+  return scatter_p.bind(
+      operand, scatter_indices, updates, update_jaxpr=jaxpr,
+      update_consts=consts, dimension_numbers=dimension_numbers,
+      updates_shape=updates.shape)
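A matching sketch for the new scatter-update wrapper, with the same caveat that the values are illustrative. Unlike scatter_add, colliding updates do not combine; one of them wins, in an unspecified order:

    import numpy as onp
    from jax import lax

    operand = onp.arange(5, dtype=onp.float32)         # [0., 1., 2., 3., 4.]
    indices = onp.array([[0], [2]], onp.int32)
    updates = onp.array([10., 20.], onp.float32)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(), inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    # Updates replace the addressed elements: [10., 1., 20., 3., 4.]
    print(lax.scatter(operand, indices, updates, dimension_numbers=dnums))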
@@ -2712,12 +2743,12 @@ def _scatter_translation_rule(c, operand, scatter_indices, updates,
   return c.Scatter(operand, scatter_indices, updates, update_computation,
                    _scatter_dimensions_proto(indices_shape, dimension_numbers))
 
-def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
-                 dimension_numbers, updates_shape):
+def _scatter_add_jvp(primals, tangents, update_jaxpr, update_consts,
+                     dimension_numbers, updates_shape):
   operand, scatter_indices, updates = primals
   g_operand, g_scatter_indices, g_updates = tangents
   assert g_scatter_indices is ad_util.zero
-  val_out = scatter_p.bind(
+  val_out = scatter_add_p.bind(
       operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
       update_consts=update_consts, dimension_numbers=dimension_numbers,
       updates_shape=updates_shape)

@@ -2726,15 +2757,15 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
   else:
     g_operand = ad.instantiate_zeros(operand, g_operand)
     g_updates = ad.instantiate_zeros(updates, g_updates)
-    tangent_out = scatter_p.bind(
+    tangent_out = scatter_add_p.bind(
         g_operand, scatter_indices, g_updates, update_jaxpr=update_jaxpr,
         update_consts=update_consts, dimension_numbers=dimension_numbers,
         updates_shape=updates_shape)
   return val_out, tangent_out
 
-def _scatter_transpose_rule(t, operand, scatter_indices, updates,
-                            update_jaxpr, update_consts, dimension_numbers,
-                            updates_shape):
+def _scatter_add_transpose_rule(t, operand, scatter_indices, updates,
+                                update_jaxpr, update_consts, dimension_numbers,
+                                updates_shape):
   assert scatter_indices is not None
   operand_t = update_t = None
   if operand is None:

@@ -2757,8 +2788,9 @@ def _scatter_transpose_rule(t, operand, scatter_indices, updates,
       slice_sizes=slice_sizes)
   return [operand_t, None, update_t]
 
-def _scatter_batching_rule(batched_args, batch_dims, update_jaxpr,
-                           update_consts, dimension_numbers, updates_shape):
+def _scatter_batching_rule(
+    scatter_op, batched_args, batch_dims, update_jaxpr, update_consts,
+    dimension_numbers, updates_shape):
   operand, scatter_indices, updates = batched_args
   operand_bdim, scatter_indices_bdim, updates_bdim = batch_dims
   del update_jaxpr, update_consts, updates_shape  # Unused.
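The renames above leave the transpose logic itself untouched; it leans on scatter-add being linear: the cotangent with respect to `operand` passes through, while the cotangent with respect to `updates` is a gather of `t` at the scatter indices. A plain-numpy check of that adjoint identity for scalar updates (illustrative, not from the commit):

    import numpy as onp

    t = onp.array([1., 2., 3., 4., 5.])      # cotangent of the output
    idxs = onp.array([0, 2, 2])
    u = onp.array([10., 20., 30.])
    out = onp.zeros(5)
    onp.add.at(out, idxs, u)                 # scatter-add of u into zeros
    # <t, scatter_add(0, idxs, u)> == <gather(t, idxs), u>
    assert onp.isclose(onp.dot(t, out), onp.dot(t[idxs], u))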
@@ -2782,7 +2814,7 @@ def _scatter_batching_rule(batched_args, batch_dims, update_jaxpr,
         update_window_dims=update_window_dims,
         inserted_window_dims=inserted_window_dims,
         scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-    return scatter_add(operand, scatter_indices, updates, dnums), 0
+    return scatter_op(operand, scatter_indices, updates, dnums), 0
   else:
     # see the third case in _gather_batching_rule for comparison and comments
     scatter_indices = batching.move_dim_to_front(scatter_indices,
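The effect of threading `scatter_op` through the batching rule is that one rule serves both primitives under vmap. A hedged usage sketch (shapes and values are assumptions for illustration):

    import numpy as onp
    from jax import vmap
    from jax import lax

    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(), inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    operands = onp.zeros((4, 5), onp.float32)              # batch of 4
    indices = onp.broadcast_to(onp.array([[1]], onp.int32), (4, 1, 1))
    updates = onp.ones((4, 1), onp.float32)
    batched = vmap(lambda o, i, u: lax.scatter(o, i, u, dimension_numbers=dnums))
    print(batched(operands, indices, updates).shape)       # (4, 5)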
@@ -2803,14 +2835,117 @@ def _scatter_batching_rule(batched_args, batch_dims, update_jaxpr,
         update_window_dims=update_window_dims,
         inserted_window_dims=inserted_window_dims,
         scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-    return scatter_add(operand, scatter_indices, updates, dnums), 0
+    return scatter_op(operand, scatter_indices, updates, dnums), 0
 
-scatter_p = standard_primitive(
-    _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
-    _scatter_translation_rule)
-ad.primitive_jvps[scatter_p] = _scatter_jvp
-ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
-batching.primitive_batchers[scatter_p] = _scatter_batching_rule
+scatter_add_p = standard_primitive(
+    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
+    _scatter_translation_rule)
+ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
+ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
+batching.primitive_batchers[scatter_add_p] = (
+    partial(_scatter_batching_rule, scatter_add))
+
+
+def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
+                 dimension_numbers, updates_shape):
+  operand, scatter_indices, updates = primals
+  g_operand, g_scatter_indices, g_updates = tangents
+  dnums = dimension_numbers
+  assert g_scatter_indices is ad_util.zero
+
+  if g_operand is ad_util.zero and g_updates is ad_util.zero:
+    val_out = scatter_p.bind(
+        operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
+        update_consts=update_consts, dimension_numbers=dnums,
+        updates_shape=updates_shape)
+    tangent_out = ad_util.zero
+    return val_out, tangent_out
+
+  # If there are overlapping indices in the scatter, it is unspecified which
+  # update "wins". So we use the following perhaps surprising scheme:
+  # a) attach a positive ID to each update in `updates`, forming (value, id)
+  #    pairs (using a new array dimension because scatter doesn't actually
+  #    support pairs).
+  # b) perform the scatter, yielding (value, id) updates, which we split apart.
+  # c) perform the inverse gather on the ids (similar to
+  #    _scatter_add_transpose), and use it to build a mask for the tangent of
+  #    `updates`.
+  # d) perform a scatter-add on the masked JVP values. A benefit of using
+  #    scatter-add here is that we don't need a `scatter` transpose rule.
+
+  # a) add unique positive IDs (iotas) to the updates, and zeros to the operand.
+  operand_shape = operand.shape
+  updates_shape = updates.shape
+  updates_dtype = _dtype(updates)
+
+  new_operand = reshape(operand, (1,) + operand_shape)
+  new_operand = pad(new_operand, _zero(operand),
+                    ((0, 1, 0),) + tuple((0, 0, 0) for _ in operand_shape))
+
+  ids_shape = onp.array(updates_shape)
+  ids_shape[dnums.update_window_dims,] = 1
+  num_ids = onp.prod(ids_shape)
+  update_ids = add(reshape(iota(updates_dtype, num_ids), ids_shape),
+                   _ones(updates))
+
+  # TODO(phawkins): there is a potential bug here if the number of updates
+  # is large enough to overflow the number of mantissa bits in a float so IDs
+  # end up colliding. We could also utilize the exponent and sign bits, with a
+  # little more work.
+  assert num_ids < (2 ** onp.finfo(updates_dtype).nmant)
+
+  updates = reshape(updates, (1,) + updates_shape)
+  reshaped_update_ids = reshape(update_ids, (1,) + updates_shape)
+  updates_and_ids = concatenate((updates, reshaped_update_ids), 0)
+
+  new_dnums = ScatterDimensionNumbers(
+      update_window_dims=(0,) + tuple(d + 1 for d in dnums.update_window_dims),
+      inserted_window_dims=tuple(d + 1 for d in dnums.inserted_window_dims),
+      scatter_dims_to_operand_dims=tuple(
+          d + 1 for d in dnums.scatter_dims_to_operand_dims))
+  outputs = scatter_p.bind(
+      new_operand, scatter_indices, updates_and_ids, update_jaxpr=update_jaxpr,
+      update_consts=update_consts, dimension_numbers=new_dnums,
+      updates_shape=updates_shape)
+  val_out = index_in_dim(outputs, 0, keepdims=False)
+  scattered_ids = index_in_dim(outputs, 1, keepdims=False)
+
+  # b) compute the inverse gather that "undoes" the scatter on the id values.
+  gather_dnums = GatherDimensionNumbers(
+      offset_dims=dnums.update_window_dims,
+      collapsed_slice_dims=dnums.inserted_window_dims,
+      start_index_map=dnums.scatter_dims_to_operand_dims)
+  slice_sizes = []
+  pos = 0
+  for i in xrange(len(scattered_ids.shape)):
+    if i in dnums.inserted_window_dims:
+      slice_sizes.append(1)
+    else:
+      slice_sizes.append(updates_shape[dnums.update_window_dims[pos]])
+      pos += 1
+  gathered_update_ids = gather(scattered_ids, scatter_indices,
+                               dimension_numbers=gather_dnums,
+                               slice_sizes=slice_sizes)
+
+  # c) mask off input JVP elements that do not correspond to a primal output.
+  g_operand = ad.instantiate_zeros(operand, g_operand)
+  g_updates = ad.instantiate_zeros(updates, g_updates)
+  masked_g_operand = select(eq(scattered_ids, _zeros(scattered_ids)),
+                            g_operand, _zeros(g_operand))
+  masked_g_updates = select(eq(update_ids, gathered_update_ids),
+                            g_updates, _zeros(g_updates))
+
+  # d) perform a scatter-add to compute the tangent output.
+  tangent_out = scatter_add(masked_g_operand, scatter_indices, masked_g_updates,
+                            dimension_numbers=dnums)
+  return val_out, tangent_out
+
+
+scatter_p = standard_primitive(
+    _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
+    _scatter_translation_rule)
+ad.primitive_jvps[scatter_p] = _scatter_jvp
+batching.primitive_batchers[scatter_p] = (
+    partial(_scatter_batching_rule, scatter))
 
 
 def _reduce_shape_rule(operand, init_value, computation, jaxpr, consts, dimensions):
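To make the (value, id) trick in `_scatter_jvp` concrete, here is a plain-numpy walk-through of the a)-d) steps on a 1-D example with one colliding index; all names, shapes, and values are illustrative only:

    import numpy as onp

    operand = onp.array([5., 6., 7.])
    idxs = onp.array([1, 1])                 # two updates collide at index 1
    updates = onp.array([10., 20.])

    # a) give each update a unique positive ID; operand cells carry ID 0.
    ids = onp.array([1., 2.])

    # scatter the (value, id) pairs; at a collision one update "wins"
    # (unspecified in XLA; the loop below happens to make the last one win).
    val = operand.copy()
    sid = onp.zeros_like(operand)
    for k in range(len(updates)):
        val[idxs[k]] = updates[k]
        sid[idxs[k]] = ids[k]

    # b)/c) gather the scattered IDs back and mask the tangents: an update's
    # tangent survives only where that update actually won its output cell,
    # and the operand's tangent survives only where no update landed.
    g_operand = onp.ones_like(operand)       # example input tangents
    g_updates = onp.ones_like(updates)
    masked_g_operand = onp.where(sid == 0, g_operand, 0.)
    masked_g_updates = onp.where(sid[idxs] == ids, g_updates, 0.)

    # d) combine with a scatter-add; each output cell has one contributor.
    tangent = masked_g_operand.copy()
    onp.add.at(tangent, idxs, masked_g_updates)
    print(val, tangent)                      # [ 5. 20.  7.] [1. 1. 1.]

The IDs live in the update dtype itself, which is why the code asserts num_ids < 2**nmant: for float32 that bound is 2**23 = 8388608 exactly-representable IDs.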
@@ -1437,6 +1437,35 @@ class LaxTest(jtu.JaxTestCase):
     fun = partial(lax.scatter_add, dimension_numbers=dnums)
     self._CompileAndCheck(fun, args_maker, check_dtypes=True)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
+          jtu.format_shape_dtype_string(arg_shape, dtype),
+          idxs, update_shape, dnums),
+       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
+       "update_shape": update_shape, "dnums": dnums, "rng": rng,
+       "rng_idx": rng_idx}
+      for dtype in float_dtypes
+      for arg_shape, idxs, update_shape, dnums in [
+          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+              update_window_dims=(), inserted_window_dims=(0,),
+              scatter_dims_to_operand_dims=(0,))),
+          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+              update_window_dims=(1,), inserted_window_dims=(),
+              scatter_dims_to_operand_dims=(0,))),
+          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+              update_window_dims=(1,), inserted_window_dims=(0,),
+              scatter_dims_to_operand_dims=(0,))),
+      ]
+      for rng_idx in [jtu.rand_int(max(arg_shape))]
+      for rng in [jtu.rand_default()]))
+  def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums, rng,
+                  rng_idx):
+    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
+    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
+                          rng(update_shape, dtype)]
+    fun = partial(lax.scatter, dimension_numbers=dnums)
+    self._CompileAndCheck(fun, args_maker, check_dtypes=True)
+
 
 class DeviceConstantTest(jtu.JaxTestCase):
   def _CheckDeviceConstant(self, make_const, expected):
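The second `testScatter` case above is the interesting one: three (2,)-shaped rows of `updates` all target operand index 0, so the windows collide and any one of them may win. A numpy sketch of the semantics being tested (illustrative, not from the test suite):

    import numpy as onp

    operand = onp.zeros(10)
    idxs = onp.array([[0], [0], [0]])
    updates = onp.arange(6.).reshape(3, 2)   # three windows of shape (2,)

    # update_window_dims=(1,): dim 1 of `updates` is the window written into
    # `operand`; all three windows start at index 0, so one of them wins.
    out = operand.copy()
    for k in range(idxs.shape[0]):
        start = idxs[k, 0]
        out[start:start + updates.shape[1]] = updates[k]
    print(out)   # here the last window wins; XLA may pick any of the three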
@@ -2121,7 +2150,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     idxs = tuple(rng(e.shape, e.dtype) for e in idxs)
     src = rng(shape, dtype)
     index_take = lambda src: lax.index_take(src, idxs, axes)
-    check_grads(index_take, (src,), 2, 1e-2, 1e-2, 1e-2)
+    check_grads(index_take, (src,), 2, 1e-2, 1e-2, 1)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(

@@ -2148,7 +2177,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                   slice_sizes=slice_sizes)
     x = rng(shape, dtype)
-    check_grads(gather, (x,), 2, 1e-2, 1e-2, 1e-2)
+    check_grads(gather, (x,), 2, 1e-2, 1e-2, 1.)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
@@ -2178,7 +2207,37 @@ class LaxAutodiffTest(jtu.JaxTestCase):
                                  dimension_numbers=dnums)
     x = rng(arg_shape, dtype)
     y = rng(update_shape, dtype)
-    check_grads(scatter_add, (x, y), 2, 1e-2, 1e-2, 1e-2)
+    check_grads(scatter_add, (x, y), 2, 1e-2, 1e-2, 1.)
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
+          jtu.format_shape_dtype_string(arg_shape, dtype),
+          idxs, update_shape, dnums),
+       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
+       "update_shape": update_shape, "dnums": dnums, "rng": rng,
+       "rng_idx": rng_idx}
+      for dtype in float_dtypes
+      for arg_shape, idxs, update_shape, dnums in [
+          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+              update_window_dims=(), inserted_window_dims=(0,),
+              scatter_dims_to_operand_dims=(0,))),
+          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+              update_window_dims=(1,), inserted_window_dims=(),
+              scatter_dims_to_operand_dims=(0,))),
+          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+              update_window_dims=(1,), inserted_window_dims=(0,),
+              scatter_dims_to_operand_dims=(0,))),
+      ]
+      for rng_idx in [jtu.rand_int(max(arg_shape))]
+      for rng in [jtu.rand_default()]))
+  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, rng,
+                      rng_idx):
+    idxs = rng_idx(idxs.shape, idxs.dtype)
+    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
+    x = rng(arg_shape, dtype)
+    y = rng(update_shape, dtype)
+    check_grads(scatter, (x, y), 2, 1e-2, 1e-2, 1.)
 
   def testStopGradient(self):
     def f(x):
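A note on the tolerance changes in this file: the last positional argument to `check_grads` (presumably the finite-difference step) is raised from 1e-2 to 1. Gather and scatter are piecewise linear in their floating-point inputs, so a large step differentiates them exactly, while a tiny step mostly measures rounding noise. A quick plain-numpy illustration of that exactness, under the step-size reading above:

    import numpy as onp

    x = onp.array([3., 7., 1.])
    f = lambda x: x[[0, 2]].sum()            # a gather then a sum: linear in x
    e0 = onp.array([1., 0., 0.])
    eps = 1.0
    fd = (f(x + eps * e0) - f(x - eps * e0)) / (2 * eps)
    print(fd)                                # exactly 1.0, the true gradient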