Implement a scatter-update operator.

Peter Hawkins 2019-03-01 15:41:49 -05:00
parent 1497c31590
commit 148073edf7
2 changed files with 213 additions and 19 deletions


@@ -604,10 +604,11 @@ def gather(operand, start_indices, dimension_numbers, slice_sizes):
slice_sizes=tuple(slice_sizes), operand_shape=operand.shape)
def scatter_add(operand, scatter_indices, updates, dimension_numbers):
"""Scatter operator.
"""Scatter-add operator.
Wraps `XLA's Scatter operator
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_.
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
addition is used to combine updates and values from `operand`.
The semantics of scatter are complicated and its API is subject to change.
@@ -624,6 +625,36 @@ def scatter_add(operand, scatter_indices, updates, dimension_numbers):
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = _reduction_jaxpr(add, _const(operand, 0))
return scatter_add_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
updates_shape=updates.shape)
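# A minimal usage sketch (assumed imports, not part of this commit): with the
# dimension numbers below, each row of the index array addresses one scalar of
# a 1D operand, and duplicate indices accumulate by addition.
#
#   import numpy as onp
#   from jax import lax, numpy as np
#
#   dnums = lax.ScatterDimensionNumbers(
#       update_window_dims=(), inserted_window_dims=(0,),
#       scatter_dims_to_operand_dims=(0,))
#   lax.scatter_add(np.zeros(5), onp.array([[1], [1], [3]]),
#                   np.array([1., 2., 4.]), dnums)
#   # -> [0., 3., 0., 4., 0.]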
def scatter(operand, scatter_indices, updates, dimension_numbers):
"""Scatter-update operator.
Wraps `XLA's Scatter operator
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where updates
replace values from `operand`.
If multiple updates are performed to the same index of `operand`, they may be
applied in any order.
The semantics of scatter are complicated and its API is subject to change.
Args:
operand: an array to which the scatter should be applied
scatter_indices: an array that gives the indices in `operand` to which each
update in `updates` should be applied.
updates: the updates that should be scattered onto `operand`.
dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
how dimensions of `operand`, `scatter_indices`, `updates` and the output
relate.
Returns:
An array containing the values of `operand` with the scattered updates
applied.
"""
jaxpr, consts = _reduction_jaxpr(lambda x, y: y, _const(operand, 0))
return scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
updates_shape=updates.shape)
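# A minimal usage sketch (same assumed imports and `dnums` as the scatter-add
# example above): updates overwrite rather than accumulate, and if two updates
# hit the same index the winner is unspecified.
#
#   lax.scatter(np.arange(5.), onp.array([[1], [3]]),
#               np.array([9., 8.]), dnums)
#   # -> [0., 9., 2., 8., 4.]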
@@ -2712,12 +2743,12 @@ def _scatter_translation_rule(c, operand, scatter_indices, updates,
return c.Scatter(operand, scatter_indices, updates, update_computation,
_scatter_dimensions_proto(indices_shape, dimension_numbers))
def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
dimension_numbers, updates_shape):
def _scatter_add_jvp(primals, tangents, update_jaxpr, update_consts,
dimension_numbers, updates_shape):
operand, scatter_indices, updates = primals
g_operand, g_scatter_indices, g_updates = tangents
assert g_scatter_indices is ad_util.zero
val_out = scatter_p.bind(
val_out = scatter_add_p.bind(
operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
updates_shape=updates_shape)
@@ -2726,15 +2757,15 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
else:
g_operand = ad.instantiate_zeros(operand, g_operand)
g_updates = ad.instantiate_zeros(updates, g_updates)
tangent_out = scatter_p.bind(
tangent_out = scatter_add_p.bind(
g_operand, scatter_indices, g_updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
updates_shape=updates_shape)
return val_out, tangent_out
def _scatter_transpose_rule(t, operand, scatter_indices, updates,
update_jaxpr, update_consts, dimension_numbers,
updates_shape):
def _scatter_add_transpose_rule(t, operand, scatter_indices, updates,
update_jaxpr, update_consts, dimension_numbers,
updates_shape):
assert scatter_indices is not None
operand_t = update_t = None
if operand is None:
@@ -2757,8 +2788,9 @@ def _scatter_transpose_rule(t, operand, scatter_indices, updates,
slice_sizes=slice_sizes)
return [operand_t, None, update_t]
def _scatter_batching_rule(batched_args, batch_dims, update_jaxpr,
update_consts, dimension_numbers, updates_shape):
def _scatter_batching_rule(
scatter_op, batched_args, batch_dims, update_jaxpr, update_consts,
dimension_numbers, updates_shape):
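# `scatter_op` is the wrapper to reapply to the batch-adjusted arguments,
# e.g. `scatter` or `scatter_add`; it is bound in with functools.partial when
# the rule is registered for each primitive below.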
operand, scatter_indices, updates = batched_args
operand_bdim, scatter_indices_bdim, updates_bdim = batch_dims
del update_jaxpr, update_consts, updates_shape # Unused.
@@ -2782,7 +2814,7 @@ def _scatter_batching_rule(batched_args, batch_dims, update_jaxpr,
update_window_dims=update_window_dims,
inserted_window_dims=inserted_window_dims,
scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
return scatter_add(operand, scatter_indices, updates, dnums), 0
return scatter_op(operand, scatter_indices, updates, dnums), 0
else:
# see the third case in _gather_batching_rule for comparison and comments
scatter_indices = batching.move_dim_to_front(scatter_indices,
@@ -2803,14 +2835,117 @@ def _scatter_batching_rule(batched_args, batch_dims, update_jaxpr,
update_window_dims=update_window_dims,
inserted_window_dims=inserted_window_dims,
scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
return scatter_add(operand, scatter_indices, updates, dnums), 0
return scatter_op(operand, scatter_indices, updates, dnums), 0
scatter_p = standard_primitive(
scatter_add_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
_scatter_translation_rule)
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
batching.primitive_batchers[scatter_add_p] = (
partial(_scatter_batching_rule, scatter_add))
def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
dimension_numbers, updates_shape):
operand, scatter_indices, updates = primals
g_operand, g_scatter_indices, g_updates = tangents
dnums = dimension_numbers
assert g_scatter_indices is ad_util.zero
if g_operand is ad_util.zero and g_updates is ad_util.zero:
val_out = scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dnums,
updates_shape=updates_shape)
tangent_out = ad_util.zero
return val_out, tangent_out
# If there are overlapping indices in the scatter, it is unspecified which
# update "wins". So we use the following perhaps surprising scheme:
# a) attach a positive ID to each update in updates, forming (value, id) pairs
# (using a new array dimension because scatter doesn't actually support
# pairs).
# b) perform the scatter, yielding (value, id) updates, which we split apart.
# c) perform the inverse gather on the ids (similar to
# _scatter_add_transpose), and use it to build a mask for the tangent of
# `updates`.
# d) perform a scatter-add on the masked JVP values. A benefit of using
# scatter-add here is that we don't need a `scatter` transpose rule.
# a) add unique positive IDs (iotas) to the updates, and zeros to the operand.
operand_shape = operand.shape
updates_shape = updates.shape
updates_dtype = _dtype(updates)
new_operand = reshape(operand, (1,) + operand_shape)
new_operand = pad(new_operand, _zero(operand),
((0, 1, 0),) + tuple((0, 0, 0) for _ in operand_shape))
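# The leading dimension of `new_operand` now has extent 2: slot 0 holds the
# operand values and slot 1 holds zeros that the scattered IDs overwrite, so
# a zero ID afterwards marks an element no update reached.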
ids_shape = onp.array(updates_shape)
ids_shape[dnums.update_window_dims,] = 1
num_ids = onp.prod(ids_shape)
update_ids = add(reshape(iota(updates_dtype, num_ids), ids_shape),
_ones(updates))
# TODO(phawkins): there is a potential bug here if the number of updates
# is large enough to overflow the number of mantissa bits in a float so IDs
# end up colliding. We could also utilize the exponent and sign bits, with a
# little more work.
assert num_ids < (2 ** onp.finfo(updates_dtype).nmant)
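# For example, float32 has 23 mantissa bits, so up to 2**23 (~8.4 million)
# IDs can be represented exactly before collisions become possible.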
updates = reshape(updates, (1,) + updates_shape)
reshaped_update_ids = reshape(update_ids, (1,) + updates_shape)
updates_and_ids = concatenate((updates, reshaped_update_ids), 0)
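# Every dimension number shifts by one to account for the new leading
# (value, id) dimension prepended to both the operand and the updates.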
new_dnums = ScatterDimensionNumbers(
update_window_dims=(0,) + tuple(d + 1 for d in dnums.update_window_dims),
inserted_window_dims=tuple(d + 1 for d in dnums.inserted_window_dims),
scatter_dims_to_operand_dims=tuple(d + 1 for d in dnums.scatter_dims_to_operand_dims))
outputs = scatter_p.bind(
new_operand, scatter_indices, updates_and_ids, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=new_dnums,
updates_shape=updates_shape)
val_out = index_in_dim(outputs, 0, keepdims=False)
scattered_ids = index_in_dim(outputs, 1, keepdims=False)
# b) compute the inverse gather that "undoes" the scatter on the id values.
gather_dnums = GatherDimensionNumbers(
offset_dims=dnums.update_window_dims,
collapsed_slice_dims=dnums.inserted_window_dims,
start_index_map=dnums.scatter_dims_to_operand_dims)
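# The inverse gather reads back one window per scatter index: inserted window
# dims get slice size 1, and each remaining dim takes the extent of the
# corresponding update window dim.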
slice_sizes = []
pos = 0
for i in xrange(len(scattered_ids.shape)):
if i in dnums.inserted_window_dims:
slice_sizes.append(1)
else:
slice_sizes.append(updates_shape[dnums.update_window_dims[pos]])
pos += 1
gathered_update_ids = gather(scattered_ids, scatter_indices,
dimension_numbers=gather_dnums,
slice_sizes=slice_sizes)
# c) mask off input JVP elements that do not correspond to a primal output.
g_operand = ad.instantiate_zeros(operand, g_operand)
g_updates = ad.instantiate_zeros(updates, g_updates)
masked_g_operand = select(eq(scattered_ids, _zeros(scattered_ids)),
g_operand, _zeros(g_operand))
masked_g_updates = select(eq(update_ids, gathered_update_ids),
g_updates, _zeros(g_updates))
# d) perform a scatter-add to compute the tangent output.
tangent_out = scatter_add(masked_g_operand, scatter_indices, masked_g_updates,
dimension_numbers=dnums)
return val_out, tangent_out
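# A small worked sketch of the masking scheme (assumed values, not part of
# this commit): suppose updates u0 and u1 both target operand index 1. They
# are tagged with IDs 1 and 2, and the scatter happens to keep ID 2 at
# index 1. The inverse gather then yields [2, 2] for the two updates;
# comparing against their own IDs [1, 2] masks out u0's tangent, while
# scattered_ids == 0 at every untouched position keeps the operand tangent
# there. The final scatter-add thus propagates exactly one tangent per
# output element.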
scatter_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter',
_scatter_translation_rule)
ad.primitive_jvps[scatter_p] = _scatter_jvp
ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
batching.primitive_batchers[scatter_p] = _scatter_batching_rule
batching.primitive_batchers[scatter_p] = (
partial(_scatter_batching_rule, scatter))
def _reduce_shape_rule(operand, init_value, computation, jaxpr, consts, dimensions):


@@ -1437,6 +1437,35 @@ class LaxTest(jtu.JaxTestCase):
fun = partial(lax.scatter_add, dimension_numbers=dnums)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
idxs, update_shape, dnums),
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
"update_shape": update_shape, "dnums": dnums, "rng": rng,
"rng_idx": rng_idx}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
for rng_idx in [jtu.rand_int(max(arg_shape))]
for rng in [jtu.rand_default()]))
def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums, rng,
rng_idx):
rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
rng(update_shape, dtype)]
fun = partial(lax.scatter, dimension_numbers=dnums)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
class DeviceConstantTest(jtu.JaxTestCase):
def _CheckDeviceConstant(self, make_const, expected):
@@ -2121,7 +2150,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
idxs = tuple(rng(e.shape, e.dtype) for e in idxs)
src = rng(shape, dtype)
index_take = lambda src: lax.index_take(src, idxs, axes)
check_grads(index_take, (src,), 2, 1e-2, 1e-2, 1e-2)
check_grads(index_take, (src,), 2, 1e-2, 1e-2, 1)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
@@ -2148,7 +2177,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
slice_sizes=slice_sizes)
x = rng(shape, dtype)
check_grads(gather, (x,), 2, 1e-2, 1e-2, 1e-2)
check_grads(gather, (x,), 2, 1e-2, 1e-2, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
@@ -2178,7 +2207,37 @@ class LaxAutodiffTest(jtu.JaxTestCase):
dimension_numbers=dnums)
x = rng(arg_shape, dtype)
y = rng(update_shape, dtype)
check_grads(scatter_add, (x, y), 2, 1e-2, 1e-2, 1e-2)
check_grads(scatter_add, (x, y), 2, 1e-2, 1e-2, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
idxs, update_shape, dnums),
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
"update_shape": update_shape, "dnums": dnums, "rng": rng,
"rng_idx": rng_idx}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
for rng_idx in [jtu.rand_int(max(arg_shape))]
for rng in [jtu.rand_default()]))
def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, rng,
rng_idx):
idxs = rng_idx(idxs.shape, idxs.dtype)
scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
x = rng(arg_shape, dtype)
y = rng(update_shape, dtype)
check_grads(scatter, (x, y), 2, 1e-2, 1e-2, 1.)
def testStopGradient(self):
def f(x):