Add batching rule for dynamic_update_slice.

This commit is contained in:
Peter Hawkins 2019-04-30 11:48:53 -04:00
parent da98c52e24
commit 6b4c74b182
2 changed files with 36 additions and 0 deletions

View File

@ -2536,12 +2536,31 @@ def _dynamic_update_slice_translation_rule(c, operand, update, start_indices,
update_shape):
return c.DynamicUpdateSlice(operand, update, start_indices)
def _dynamic_update_slice_batching_rule(batched_args, batch_dims, update_shape):
# A dynamic update slice is a special case of scatter; we can delegate to the
# scatter batching rule.
# TODO(phawkins): consider removing dynamic_update_slice entirely and using
# scatter always.
operand, update, index = batched_args
operand_bdims, update_bdims, index_bdims = batch_dims
dims = tuple(range(len(update_shape)))
dnums = ScatterDimensionNumbers(update_window_dims=dims,
inserted_window_dims=(),
scatter_dims_to_operand_dims=dims)
return _scatter_batching_rule(
scatter,
(operand, index, update), (operand_bdims, index_bdims, update_bdims),
None, None, dnums, update_shape)
dynamic_update_slice_p = standard_primitive(
_dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
'dynamic_update_slice', _dynamic_update_slice_translation_rule)
ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
ad.primitive_transposes[dynamic_update_slice_p] = \
_dynamic_update_slice_transpose_rule
batching.primitive_batchers[dynamic_update_slice_p] = \
_dynamic_update_slice_batching_rule

View File

@ -367,6 +367,23 @@ class BatchingTest(jtu.JaxTestCase):
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)
def testDynamicUpdateSlice(self):
x = onp.random.randn(10, 3)
y = onp.random.randn(10)
ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
in_axes=(0, 0, None))(x, y, 1)
expected = x.copy()
expected[:, 1] = y
self.assertAllClose(ans, expected, check_dtypes=False)
x = onp.random.randn(3)
idx = onp.array([0, 1, 2, 1, 0] * 2)
ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
in_axes=(None, 0, 0))(x, y, idx)
expected = onp.broadcast_to(x, (10, 3)).copy()
expected[onp.arange(10), idx] = y
self.assertAllClose(ans, expected, check_dtypes=False)
def testRandom(self):
seeds = vmap(random.PRNGKey)(onp.arange(10))
ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)