diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index e5006608d..99d748407 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -286,14 +286,32 @@ def _pallas_call_batching_rule(args, dims, *,
     interpret: bool,
     which_linear: tuple[bool, ...],
     **compiler_params: Any):
+
+  def _maybe_squeeze_out_bdim(
+      x: jax.Array, bdim: int | batching.NotMapped
+  ) -> jax.Array:
+    if bdim is batching.not_mapped:
+      return x
+    return jnp.squeeze(x, axis=bdim)
+
   dynamic_grid_args, args = split_list(
       args, [grid_mapping.num_dynamic_grid_bounds]
   )
   dynamic_grid_dims, dims = split_list(
       dims, [grid_mapping.num_dynamic_grid_bounds]
   )
-  if any(dim is not batching.not_mapped for dim in dynamic_grid_dims):
-    raise NotImplementedError("Batched dynamic grid bounds unsupported")
+  if all(
+      bdim is batching.not_mapped or arg.shape[bdim] == 1
+      for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims)
+  ):
+    dynamic_grid_args = safe_map(
+        _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims)
+  elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
+    raise NotImplementedError(
+        f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}"
+    )
+  else:
+    pass  # No dynamic grid dimensions
   del dynamic_grid_dims
   if grid_mapping.num_index_operands:
     scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
@@ -306,17 +324,15 @@ def _pallas_call_batching_rule(args, dims, *,
         bdim is batching.not_mapped or arg.shape[bdim] == 1
         for arg, bdim in zip(scalar_args, scalar_bdims)
     ):
-      def _squeeze_out_bdim(x: jax.Array, bdim: int | batching.NotMapped):
-        if bdim is batching.not_mapped:
-          return x
-        return jnp.squeeze(x, axis=bdim)
-      scalar_args = safe_map(_squeeze_out_bdim, scalar_args, scalar_bdims)
+      scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
       scalar_bdims = [None] * len(scalar_args)
       args = (*scalar_args, *args)
       dims = (*scalar_bdims, *bdims)
     else:
       # TODO(sharadmv,apaszke): enable batching over prefetched scalar args
       raise NotImplementedError
+  if not dims:
+    raise NotImplementedError("vmapping pallas_call with no arguments.")
   axis_size, = {x.shape[d] for x, d in zip(args, dims)
                 if d is not batching.not_mapped}
   block_mappings = grid_mapping.block_mappings
diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index 11ecb5708..55707ff4d 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -1113,6 +1113,52 @@ class PallasCallTest(PallasTPUTest):
         dynamic_kernel(4), np.full(shape, 8.0, np.float32)
     )
 
+  def test_vmap_trivial_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps, x):
+      return pl.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(8 * 128.).reshape((1, *shape))
+    np.testing.assert_array_equal(dynamic_kernel(jnp.array([4]), x), x + 8.0)
+
+  def test_vmap_nontrivial_dynamic_grid(self):
+    # Dynamic grid doesn't support vmapping over multiple distinct grid values
+    # at the moment.
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps):
+      return pl.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+    with self.assertRaises(NotImplementedError):
+      dynamic_kernel(jnp.array([4, 8]))
+
   def test_vmap_dynamic_grid(self):
     shape = (8, 128)
     result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)