mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add batching rule for dynamic_update_slice.
This commit is contained in:
parent
da98c52e24
commit
6b4c74b182
@ -2536,12 +2536,31 @@ def _dynamic_update_slice_translation_rule(c, operand, update, start_indices,
|
||||
update_shape):
|
||||
return c.DynamicUpdateSlice(operand, update, start_indices)
|
||||
|
||||
def _dynamic_update_slice_batching_rule(batched_args, batch_dims, update_shape):
|
||||
# A dynamic update slice is a special case of scatter; we can delegate to the
|
||||
# scatter batching rule.
|
||||
# TODO(phawkins): consider removing dynamic_update_slice entirely and using
|
||||
# scatter always.
|
||||
operand, update, index = batched_args
|
||||
operand_bdims, update_bdims, index_bdims = batch_dims
|
||||
dims = tuple(range(len(update_shape)))
|
||||
dnums = ScatterDimensionNumbers(update_window_dims=dims,
|
||||
inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=dims)
|
||||
return _scatter_batching_rule(
|
||||
scatter,
|
||||
(operand, index, update), (operand_bdims, index_bdims, update_bdims),
|
||||
None, None, dnums, update_shape)
|
||||
|
||||
|
||||
dynamic_update_slice_p = standard_primitive(
|
||||
_dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
|
||||
'dynamic_update_slice', _dynamic_update_slice_translation_rule)
|
||||
ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
|
||||
ad.primitive_transposes[dynamic_update_slice_p] = \
|
||||
_dynamic_update_slice_transpose_rule
|
||||
batching.primitive_batchers[dynamic_update_slice_p] = \
|
||||
_dynamic_update_slice_batching_rule
|
||||
|
||||
|
||||
|
||||
|
@ -367,6 +367,23 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
expected = x[idx]
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testDynamicUpdateSlice(self):
|
||||
x = onp.random.randn(10, 3)
|
||||
y = onp.random.randn(10)
|
||||
ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
|
||||
in_axes=(0, 0, None))(x, y, 1)
|
||||
expected = x.copy()
|
||||
expected[:, 1] = y
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
x = onp.random.randn(3)
|
||||
idx = onp.array([0, 1, 2, 1, 0] * 2)
|
||||
ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
|
||||
in_axes=(None, 0, 0))(x, y, idx)
|
||||
expected = onp.broadcast_to(x, (10, 3)).copy()
|
||||
expected[onp.arange(10), idx] = y
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testRandom(self):
|
||||
seeds = vmap(random.PRNGKey)(onp.arange(10))
|
||||
ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
|
||||
|
Loading…
x
Reference in New Issue
Block a user