diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index bcc3f1db1..6e019853c 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1383,11 +1383,12 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
                                   inserted_window_dims=(),
                                   scatter_dims_to_operand_dims=dims)
   index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd)
-  return _scatter_batching_rule(
-    scatter, (operand, index, update), (operand_bd, index_bdim, update_bd),
-    update_jaxpr=None, update_consts=None, dimension_numbers=dnums,
-    indices_are_sorted=True, unique_indices=True,
-    mode=GatherScatterMode.CLIP)
+  return jax.vmap(
+    partial(scatter, dimension_numbers=dnums,
+            indices_are_sorted=True, unique_indices=True,
+            mode=GatherScatterMode.CLIP),
+    in_axes=(operand_bd, index_bdim, update_bd),
+    out_axes=0)(operand, index, update), 0


 dynamic_update_slice_p = standard_primitive(
@@ -2067,7 +2068,6 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                            indices_are_sorted, unique_indices, mode):
   operand, indices, updates = batched_args
   operand_bdim, indices_bdim, updates_bdim = batch_dims
-  del update_jaxpr, update_consts  # Unused.

   # move the operand batch dim to the front if it is not None, otherwise create
   # it at the front (so that we can scatter into it)
@@ -2086,10 +2086,10 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
         update_window_dims=update_window_dims,
         inserted_window_dims=inserted_window_dims,
         scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-    return scatter_op(
-      operand, indices, updates, dnums,
+    return scatter_op.bind(
+      operand, indices, updates, dimension_numbers=dnums,
       indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
-      mode=mode), 0
+      mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0


   # see the third case in _gather_batching_rule for comparison and comments
@@ -2108,10 +2108,10 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
       update_window_dims=update_window_dims,
       inserted_window_dims=inserted_window_dims,
       scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
-  return scatter_op(
-    operand, indices, updates, dnums,
+  return scatter_op.bind(
+    operand, indices, updates, dimension_numbers=dnums,
     indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
-    mode=mode), 0
+    mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0

 scatter_add_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
@@ -2119,7 +2119,7 @@ scatter_add_p = standard_primitive(
 ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
 ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
 batching.primitive_batchers[scatter_add_p] = (
-  partial(_scatter_batching_rule, scatter_add))
+  partial(_scatter_batching_rule, scatter_add_p))

 scatter_mul_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
@@ -2141,7 +2141,7 @@ ad.defjvp(scatter_mul_p,
           _scatter_mul_jvp_rhs)
 ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
 batching.primitive_batchers[scatter_mul_p] = (
-  partial(_scatter_batching_rule, scatter_mul))
+  partial(_scatter_batching_rule, scatter_mul_p))

 def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
                           update_consts, dimension_numbers,
@@ -2248,14 +2248,14 @@ scatter_min_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
     weak_type_rule=_argnum_weak_type(0))
 batching.primitive_batchers[scatter_min_p] = (
-  partial(_scatter_batching_rule, scatter_min))
+  partial(_scatter_batching_rule, scatter_min_p))
 ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)

 scatter_max_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
     weak_type_rule=_argnum_weak_type(0))
 batching.primitive_batchers[scatter_max_p] = (
-  partial(_scatter_batching_rule, scatter_max))
+  partial(_scatter_batching_rule, scatter_max_p))
 ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)

 def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
@@ -2401,7 +2401,7 @@ scatter_p = standard_primitive(
 ad.primitive_jvps[scatter_p] = _scatter_jvp
 ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
 batching.primitive_batchers[scatter_p] = (
-  partial(_scatter_batching_rule, scatter))
+  partial(_scatter_batching_rule, scatter_p))


 def _scatter_lower_opaque(ctx, operand, indices, updates, *,
diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py
index 8b19c0a2c..e630586e6 100644
--- a/tests/lax_numpy_indexing_test.py
+++ b/tests/lax_numpy_indexing_test.py
@@ -466,6 +466,15 @@ class IndexingTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker)
     self._CompileAndCheck(jnp_op_idx, args_maker)

+  def testIndexApplyBatchingBug(self):
+    # https://github.com/google/jax/issues/16655
+    arr = jnp.array([[1, 2, 3, 4, 5, 6]])
+    ind = jnp.array([3])
+    func = lambda a, i: a.at[i].apply(lambda x: x - 1)
+    expected = jnp.array(list(map(func, arr, ind)))
+    out = jax.vmap(func)(arr, ind)
+    self.assertArraysEqual(out, expected)
+
   def testIndexUpdateScalarBug(self):
     # https://github.com/google/jax/issues/14923
     a = jnp.arange(10.)
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 97d9ce4cf..523c02c62 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2302,6 +2302,29 @@ class LaxTest(jtu.JaxTestCase):
     fun = partial(lax.scatter_max, dimension_numbers=dnums)
     self._CompileAndCheck(fun, args_maker)

+  @jtu.sample_product(
+    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape, dnums=dnums)
+     for arg_shape, idxs, update_shape, dnums in [
+         ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+           update_window_dims=(), inserted_window_dims=(0,),
+           scatter_dims_to_operand_dims=(0,))),
+         ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+           update_window_dims=(1,), inserted_window_dims=(),
+           scatter_dims_to_operand_dims=(0,))),
+         ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+           update_window_dims=(1,), inserted_window_dims=(0,),
+           scatter_dims_to_operand_dims=(0,))),
+     ]],
+    dtype=lax_test_util.float_dtypes,
+  )
+  def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums):
+    rng = jtu.rand_default(self.rng())
+    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
+    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
+    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs()]
+    fun = partial(lax.scatter_apply, func=jnp.sin, update_shape=update_shape, dimension_numbers=dnums)
+    self._CompileAndCheck(fun, args_maker)
+
   @jtu.sample_product(
     [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
           dnums=dnums)
diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py
index e77b16747..a626d31a1 100644
--- a/tests/lax_vmap_test.py
+++ b/tests/lax_vmap_test.py
@@ -601,6 +601,29 @@ class LaxVmapTest(jtu.JaxTestCase):
                         [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                         rtol={np.float16: 5e-3, dtypes.bfloat16: 7e-2})

+  @jtu.sample_product(
+    [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
+          dnums=dnums, bdims=bdims)
+     for arg_shape, idxs, update_shape, dnums in [
+         ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+           update_window_dims=(), inserted_window_dims=(0,),
+           scatter_dims_to_operand_dims=(0,))),
+         ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+           update_window_dims=(1,), inserted_window_dims=(),
+           scatter_dims_to_operand_dims=(0,))),
+         ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+           update_window_dims=(1,), inserted_window_dims=(0,),
+           scatter_dims_to_operand_dims=(0,))),
+     ]
+     for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape)],
+    dtype=lax_test_util.float_dtypes,
+  )
+  def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
+    fun = partial(lax.scatter_apply, func=jnp.sin, update_shape=update_shape, dimension_numbers=dnums)
+    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape],
+                        [dtype, idxs.dtype], jtu.rand_default(self.rng()),
+                        rtol={np.float16: 5e-3, dtypes.bfloat16: 7e-2})
+
   def testShapeUsesBuiltinInt(self):
     x = lax.iota(np.int32, 3) + 1
     self.assertIsInstance(x.shape[0], int)  # not np.int64
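
For context, the regression test added in `lax_numpy_indexing_test.py` above boils down to the following standalone repro. This is a minimal sketch (the array values come from the test; the reference computation and variable layout are my own), and the final assertion only holds on a JAX build that includes this patch:

```python
# Minimal repro of https://github.com/google/jax/issues/16655:
# before this patch, vmap of `.at[i].apply(...)` went through a scatter
# batching rule that dropped the update computation (update_jaxpr).
import jax
import jax.numpy as jnp

arr = jnp.array([[1, 2, 3, 4, 5, 6]])  # one batch element
ind = jnp.array([3])                   # one index per batch element

func = lambda a, i: a.at[i].apply(lambda x: x - 1)

# Reference result: map over the leading axis in plain Python.
expected = jnp.stack([func(a, i) for a, i in zip(arr, ind)])

# With the fix, the vmapped version agrees with the reference.
out = jax.vmap(func)(arr, ind)
assert jnp.array_equal(out, expected)  # [[1, 2, 3, 3, 5, 6]]
```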
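The new `testScatterApply` cases exercise `lax.scatter_apply` directly. As a hedged sketch of what the first test case computes, using its shapes and dimension numbers (the concrete operand values here are my own, not from the test):

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(5.0)     # arg_shape=(5,)
idxs = jnp.array([[0], [2]])  # two scatter indices into axis 0
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(), inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))

# Applies jnp.sin at positions 0 and 2; other elements pass through.
out = lax.scatter_apply(operand, idxs, jnp.sin, (2,),
                        dimension_numbers=dnums)
# out == [sin(0.), 1., sin(2.), 3., 4.]
```

The `lax_vmap_test.py` variant then checks exactly this operation under `vmap` across all batch-dimension placements, which is the path the `_scatter_batching_rule` change above fixes: the rule now calls `scatter_op.bind(...)` with `update_jaxpr`/`update_consts` preserved, instead of rebuilding the scatter through a wrapper that discarded the update function.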