Fix scatter batching rule for scatter_apply
The issue is that the batching rule assumes each scatter variant always has the same update_jaxpr. This is not true of scatter_apply, which lowers to scatter with a custom update_jaxpr. To address this, we change the batching rule so that it reuses the incoming update_jaxpr rather than regenerating it.
This commit is contained in: parent f4eed78e90, commit 1b3da85758
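For context, the regression (https://github.com/google/jax/issues/16655) comes from batching a.at[i].apply(f). A standalone repro, taken directly from the regression test added below in this commit:

```python
# Repro of google/jax#16655: before this fix, vmapping .at[i].apply() dropped
# scatter_apply's custom update_jaxpr, so the batched result did not match the
# per-example reference.
import jax
import jax.numpy as jnp

arr = jnp.array([[1, 2, 3, 4, 5, 6]])
ind = jnp.array([3])
func = lambda a, i: a.at[i].apply(lambda x: x - 1)

expected = jnp.array(list(map(func, arr, ind)))  # unbatched reference
out = jax.vmap(func)(arr, ind)                   # batched path under test
assert (out == expected).all()
```

The first hunk rewrites the batching rule for dynamic_update_slice, which previously called _scatter_batching_rule directly: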
```diff
@@ -1383,11 +1383,12 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
                                    inserted_window_dims=(),
                                    scatter_dims_to_operand_dims=dims)
   index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd)
-  return _scatter_batching_rule(
-    scatter, (operand, index, update), (operand_bd, index_bdim, update_bd),
-    update_jaxpr=None, update_consts=None, dimension_numbers=dnums,
-    indices_are_sorted=True, unique_indices=True,
-    mode=GatherScatterMode.CLIP)
+  return jax.vmap(
+      partial(scatter, dimension_numbers=dnums,
+              indices_are_sorted=True, unique_indices=True,
+              mode=GatherScatterMode.CLIP),
+      in_axes=(operand_bd, index_bdim, update_bd),
+      out_axes=0)(operand, index, update), 0


 dynamic_update_slice_p = standard_primitive(
```
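Rather than invoking _scatter_batching_rule by hand (which would now require a real update_jaxpr instead of update_jaxpr=None), the rule defers to jax.vmap of the scatter wrapper, which routes through the fixed batching machinery. A minimal sketch of the same pattern, with shapes chosen purely for illustration:

```python
# Sketch of the jax.vmap-of-scatter pattern used above (illustrative shapes):
# 7 batched operands/updates share one unbatched start index.
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((7, 5))  # batch of 7 length-5 operands
update = jnp.ones((7, 2))    # one (2,)-window update per batch element
index = jnp.array([1])       # shared start index (no batch dim)

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(0,), inserted_window_dims=(),
    scatter_dims_to_operand_dims=(0,))

batched = jax.vmap(
    partial(lax.scatter, dimension_numbers=dnums,
            indices_are_sorted=True, unique_indices=True,
            mode=lax.GatherScatterMode.CLIP),
    in_axes=(0, None, 0), out_axes=0)

print(batched(operand, index, update).shape)  # (7, 5)
```

The next hunks make _scatter_batching_rule itself preserve the update computation: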
```diff
@@ -2067,7 +2068,6 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                            indices_are_sorted, unique_indices, mode):
   operand, indices, updates = batched_args
   operand_bdim, indices_bdim, updates_bdim = batch_dims
-  del update_jaxpr, update_consts  # Unused.

   # move the operand batch dim to the front if it is not None, otherwise create
   # it at the front (so that we can scatter into it)
@@ -2086,10 +2086,10 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
       update_window_dims=update_window_dims,
       inserted_window_dims=inserted_window_dims,
       scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-    return scatter_op(
-      operand, indices, updates, dnums,
+    return scatter_op.bind(
+      operand, indices, updates, dimension_numbers=dnums,
       indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
-      mode=mode), 0
+      mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0


   # see the third case in _gather_batching_rule for comparison and comments
@@ -2108,10 +2108,10 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
       update_window_dims=update_window_dims,
       inserted_window_dims=inserted_window_dims,
       scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-    return scatter_op(
-      operand, indices, updates, dnums,
+    return scatter_op.bind(
+      operand, indices, updates, dimension_numbers=dnums,
       indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
-      mode=mode), 0
+      mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0

 scatter_add_p = standard_primitive(
   _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
```
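These hunks are the core of the fix. Previously the rule deleted update_jaxpr and update_consts and called the scatter_op wrapper function, which re-traced a fresh default update computation; now it calls scatter_op.bind and forwards the incoming update_jaxpr and update_consts unchanged, so scatter_apply's custom update function survives batching. The custom jaxpr can be inspected with jax.make_jaxpr (a small sketch; the exact printed form varies across JAX versions):

```python
# The scatter equation produced by .at[i].apply carries a traced update_jaxpr
# for the applied function (here jnp.sin); the batching rule must forward it
# rather than regenerate a default one.
import jax
import jax.numpy as jnp

def f(x, i):
  return x.at[i].apply(jnp.sin)

print(jax.make_jaxpr(f)(jnp.arange(4.0), jnp.array(1)))
```

Since the rule now rebinds the primitive, each batcher registration below must pass the primitive object (e.g. scatter_add_p) instead of the wrapper function (scatter_add):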
```diff
@@ -2119,7 +2119,7 @@ scatter_add_p = standard_primitive(
 ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
 ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
 batching.primitive_batchers[scatter_add_p] = (
-  partial(_scatter_batching_rule, scatter_add))
+  partial(_scatter_batching_rule, scatter_add_p))

 scatter_mul_p = standard_primitive(
   _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
@@ -2141,7 +2141,7 @@ ad.defjvp(scatter_mul_p,
           _scatter_mul_jvp_rhs)
 ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
 batching.primitive_batchers[scatter_mul_p] = (
-  partial(_scatter_batching_rule, scatter_mul))
+  partial(_scatter_batching_rule, scatter_mul_p))

 def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
                           update_consts, dimension_numbers,
@@ -2248,14 +2248,14 @@ scatter_min_p = standard_primitive(
   _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
   weak_type_rule=_argnum_weak_type(0))
 batching.primitive_batchers[scatter_min_p] = (
-  partial(_scatter_batching_rule, scatter_min))
+  partial(_scatter_batching_rule, scatter_min_p))
 ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)

 scatter_max_p = standard_primitive(
   _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
   weak_type_rule=_argnum_weak_type(0))
 batching.primitive_batchers[scatter_max_p] = (
-  partial(_scatter_batching_rule, scatter_max))
+  partial(_scatter_batching_rule, scatter_max_p))
 ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)

 def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
@@ -2401,7 +2401,7 @@ scatter_p = standard_primitive(
 ad.primitive_jvps[scatter_p] = _scatter_jvp
 ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
 batching.primitive_batchers[scatter_p] = (
-  partial(_scatter_batching_rule, scatter))
+  partial(_scatter_batching_rule, scatter_p))


 def _scatter_lower_opaque(ctx, operand, indices, updates, *,
```
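The registration pattern, one generic rule specialized per primitive via functools.partial, can be seen in isolation in this self-contained toy sketch (a hypothetical toy_scale primitive, not JAX's real scatter; it assumes jax.core and jax.interpreters.batching as importable in the JAX of this commit's era):

```python
# Toy illustration (hypothetical primitive) of the pattern above: the shared
# batching rule receives the primitive itself, so it can re-bind it with all
# params forwarded intact, the analogue of forwarding update_jaxpr.
from functools import partial

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import batching

toy_scale_p = core.Primitive("toy_scale")
toy_scale_p.def_impl(lambda x, *, scale: x * scale)
toy_scale_p.def_abstract_eval(lambda x, *, scale: x)

def _generic_batching_rule(prim, batched_args, batch_dims, **params):
  (x,), (bdim,) = batched_args, batch_dims
  # Re-bind the same primitive with its params unchanged.
  return prim.bind(x, **params), bdim

batching.primitive_batchers[toy_scale_p] = partial(
    _generic_batching_rule, toy_scale_p)

out = jax.vmap(lambda x: toy_scale_p.bind(x, scale=3.0))(jnp.arange(4.0))
print(out)  # [0. 3. 6. 9.]
```

The remaining hunks add regression coverage: a vmap repro in IndexingTest and direct scatter_apply tests in LaxTest and LaxVmapTest.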
```diff
@@ -466,6 +466,15 @@ class IndexingTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker)
     self._CompileAndCheck(jnp_op_idx, args_maker)

+  def testIndexApplyBatchingBug(self):
+    # https://github.com/google/jax/issues/16655
+    arr = jnp.array([[1, 2, 3, 4, 5, 6]])
+    ind = jnp.array([3])
+    func = lambda a, i: a.at[i].apply(lambda x: x - 1)
+    expected = jnp.array(list(map(func, arr, ind)))
+    out = jax.vmap(func)(arr, ind)
+    self.assertArraysEqual(out, expected)
+
   def testIndexUpdateScalarBug(self):
     # https://github.com/google/jax/issues/14923
     a = jnp.arange(10.)
```
```diff
@@ -2302,6 +2302,29 @@ class LaxTest(jtu.JaxTestCase):
     fun = partial(lax.scatter_max, dimension_numbers=dnums)
     self._CompileAndCheck(fun, args_maker)

+  @jtu.sample_product(
+    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape, dnums=dnums)
+      for arg_shape, idxs, update_shape, dnums in [
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+            update_window_dims=(), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+      ]],
+    dtype=lax_test_util.float_dtypes,
+  )
+  def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums):
+    rng = jtu.rand_default(self.rng())
+    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
+    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
+    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs()]
+    fun = partial(lax.scatter_apply, func=jnp.sin, update_shape=update_shape, dimension_numbers=dnums)
+    self._CompileAndCheck(fun, args_maker)
+
   @jtu.sample_product(
     [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
           dnums=dnums)
```
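The LaxTest case above exercises lax.scatter_apply directly over the listed dimension-number configurations. For reference, a standalone call in the same style (values chosen for illustration):

```python
# lax.scatter_apply applies `func` to the selected elements and scatters the
# results back; here the elements at indices 0 and 2 are replaced by their sine.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5.0)
idxs = jnp.array([[0], [2]])
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(), inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))

y = lax.scatter_apply(x, idxs, func=jnp.sin, update_shape=(2,),
                      dimension_numbers=dnums)
print(y)  # [0. 1. 0.909... 3. 4.]
```

The LaxVmapTest counterpart below additionally sweeps batch dimensions via _CheckBatching: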
```diff
@@ -601,6 +601,29 @@ class LaxVmapTest(jtu.JaxTestCase):
                         [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                         rtol={np.float16: 5e-3, dtypes.bfloat16: 7e-2})

+  @jtu.sample_product(
+    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
+          dnums=dnums, bdims=bdims)
+      for arg_shape, idxs, update_shape, dnums in [
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+            update_window_dims=(), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+      ]
+      for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape)],
+    dtype=lax_test_util.float_dtypes,
+  )
+  def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
+    fun = partial(lax.scatter_apply, func=jnp.sin, update_shape=update_shape, dimension_numbers=dnums)
+    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape],
+                        [dtype, idxs.dtype], jtu.rand_default(self.rng()),
+                        rtol={np.float16: 5e-3, dtypes.bfloat16: 7e-2})
+
   def testShapeUsesBuiltinInt(self):
     x = lax.iota(np.int32, 3) + 1
     self.assertIsInstance(x.shape[0], int)  # not np.int64
```