Fix scatter batching rule for scatter_apply

The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule so that it reuses the update_jaxpr it is given rather
than regenerating it.
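For context, here is a minimal sketch of the failure mode, adapted from the regression test (testIndexApplyBatchingBug) added below; the array values are purely illustrative:

import jax
import jax.numpy as jnp

# Per-row update: apply `x - 1` to the element at index `i` of row `a`.
# .at[i].apply(...) goes through scatter_apply, i.e. scatter with a custom update_jaxpr.
func = lambda a, i: a.at[i].apply(lambda x: x - 1)

arr = jnp.array([[1, 2, 3, 4, 5, 6]])
ind = jnp.array([3])

# Per the commit message, batching previously regenerated a default update_jaxpr,
# so the custom update function was not applied correctly under vmap.
out = jax.vmap(func)(arr, ind)
print(out)  # with the fix: [[1 2 3 3 5 6]]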
Jake VanderPlas, 2023-07-10 16:42:45 -07:00
commit 1b3da85758 (parent f4eed78e90)
4 changed files with 72 additions and 17 deletions


@@ -1383,11 +1383,12 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
       inserted_window_dims=(),
       scatter_dims_to_operand_dims=dims)
   index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd)
-  return _scatter_batching_rule(
-    scatter, (operand, index, update), (operand_bd, index_bdim, update_bd),
-    update_jaxpr=None, update_consts=None, dimension_numbers=dnums,
-    indices_are_sorted=True, unique_indices=True,
-    mode=GatherScatterMode.CLIP)
+  return jax.vmap(
+    partial(scatter, dimension_numbers=dnums,
+            indices_are_sorted=True, unique_indices=True,
+            mode=GatherScatterMode.CLIP),
+    in_axes=(operand_bd, index_bdim, update_bd),
+    out_axes=0)(operand, index, update), 0

 dynamic_update_slice_p = standard_primitive(

@@ -2067,7 +2068,6 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                            indices_are_sorted, unique_indices, mode):
   operand, indices, updates = batched_args
   operand_bdim, indices_bdim, updates_bdim = batch_dims
-  del update_jaxpr, update_consts  # Unused.

   # move the operand batch dim to the front if it is not None, otherwise create
   # it at the front (so that we can scatter into it)

@@ -2086,10 +2086,10 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
         update_window_dims=update_window_dims,
         inserted_window_dims=inserted_window_dims,
         scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-    return scatter_op(
-      operand, indices, updates, dnums,
+    return scatter_op.bind(
+      operand, indices, updates, dimension_numbers=dnums,
       indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
-      mode=mode), 0
+      mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0

   # see the third case in _gather_batching_rule for comparison and comments

@@ -2108,10 +2108,10 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
       update_window_dims=update_window_dims,
       inserted_window_dims=inserted_window_dims,
       scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-  return scatter_op(
-    operand, indices, updates, dnums,
+  return scatter_op.bind(
+    operand, indices, updates, dimension_numbers=dnums,
     indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
-    mode=mode), 0
+    mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0

 scatter_add_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',

@@ -2119,7 +2119,7 @@ scatter_add_p = standard_primitive(
 ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
 ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
 batching.primitive_batchers[scatter_add_p] = (
-    partial(_scatter_batching_rule, scatter_add))
+    partial(_scatter_batching_rule, scatter_add_p))

 scatter_mul_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',

@@ -2141,7 +2141,7 @@ ad.defjvp(scatter_mul_p,
           _scatter_mul_jvp_rhs)
 ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
 batching.primitive_batchers[scatter_mul_p] = (
-    partial(_scatter_batching_rule, scatter_mul))
+    partial(_scatter_batching_rule, scatter_mul_p))

 def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
                           update_consts, dimension_numbers,

@@ -2248,14 +2248,14 @@ scatter_min_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
     weak_type_rule=_argnum_weak_type(0))
 batching.primitive_batchers[scatter_min_p] = (
-    partial(_scatter_batching_rule, scatter_min))
+    partial(_scatter_batching_rule, scatter_min_p))
 ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)

 scatter_max_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
     weak_type_rule=_argnum_weak_type(0))
 batching.primitive_batchers[scatter_max_p] = (
-    partial(_scatter_batching_rule, scatter_max))
+    partial(_scatter_batching_rule, scatter_max_p))
 ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)

 def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,

@@ -2401,7 +2401,7 @@ scatter_p = standard_primitive(
 ad.primitive_jvps[scatter_p] = _scatter_jvp
 ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
 batching.primitive_batchers[scatter_p] = (
-    partial(_scatter_batching_rule, scatter))
+    partial(_scatter_batching_rule, scatter_p))

 def _scatter_lower_opaque(ctx, operand, indices, updates, *,


@@ -466,6 +466,15 @@ class IndexingTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker)
     self._CompileAndCheck(jnp_op_idx, args_maker)

+  def testIndexApplyBatchingBug(self):
+    # https://github.com/google/jax/issues/16655
+    arr = jnp.array([[1, 2, 3, 4, 5, 6]])
+    ind = jnp.array([3])
+    func = lambda a, i: a.at[i].apply(lambda x: x - 1)
+    expected = jnp.array(list(map(func, arr, ind)))
+    out = jax.vmap(func)(arr, ind)
+    self.assertArraysEqual(out, expected)
+
   def testIndexUpdateScalarBug(self):
     # https://github.com/google/jax/issues/14923
     a = jnp.arange(10.)


@@ -2302,6 +2302,29 @@ class LaxTest(jtu.JaxTestCase):
     fun = partial(lax.scatter_max, dimension_numbers=dnums)
     self._CompileAndCheck(fun, args_maker)

+  @jtu.sample_product(
+    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape, dnums=dnums)
+      for arg_shape, idxs, update_shape, dnums in [
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+            update_window_dims=(), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+      ]],
+    dtype=lax_test_util.float_dtypes,
+  )
+  def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums):
+    rng = jtu.rand_default(self.rng())
+    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
+    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
+    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs()]
+    fun = partial(lax.scatter_apply, func=jnp.sin, update_shape=update_shape, dimension_numbers=dnums)
+    self._CompileAndCheck(fun, args_maker)
+
   @jtu.sample_product(
     [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
           dnums=dnums)


@@ -601,6 +601,29 @@ class LaxVmapTest(jtu.JaxTestCase):
                         [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                         rtol={np.float16: 5e-3, dtypes.bfloat16: 7e-2})

+  @jtu.sample_product(
+    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
+          dnums=dnums, bdims=bdims)
+      for arg_shape, idxs, update_shape, dnums in [
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+            update_window_dims=(), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(),
+            scatter_dims_to_operand_dims=(0,))),
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+            update_window_dims=(1,), inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,))),
+      ]
+      for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape)],
+    dtype=lax_test_util.float_dtypes,
+  )
+  def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
+    fun = partial(lax.scatter_apply, func=jnp.sin, update_shape=update_shape, dimension_numbers=dnums)
+    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape],
+                        [dtype, idxs.dtype], jtu.rand_default(self.rng()),
+                        rtol={np.float16: 5e-3, dtypes.bfloat16: 7e-2})
+
   def testShapeUsesBuiltinInt(self):
     x = lax.iota(np.int32, 3) + 1
     self.assertIsInstance(x.shape[0], int)  # not np.int64