diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 423aafe18..e5006608d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -286,8 +286,15 @@ def _pallas_call_batching_rule(args, dims, *, interpret: bool, which_linear: tuple[bool, ...], **compiler_params: Any): - if grid_mapping.num_dynamic_grid_bounds: - raise NotImplementedError("interpret with dynamic grid bounds unsupported") + dynamic_grid_args, args = split_list( + args, [grid_mapping.num_dynamic_grid_bounds] + ) + dynamic_grid_dims, dims = split_list( + dims, [grid_mapping.num_dynamic_grid_bounds] + ) + if any(dim is not batching.not_mapped for dim in dynamic_grid_dims): + raise NotImplementedError("Batched dynamic grid bounds unsupported") + del dynamic_grid_dims if grid_mapping.num_index_operands: scalar_args, args = split_list(args, [grid_mapping.num_index_operands]) scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands]) @@ -363,15 +370,20 @@ def _pallas_call_batching_rule(args, dims, *, grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims)) - out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}", - in_shapes=batched_in_shapes, - out_shapes=batched_out_shapes, - which_linear=which_linear, - grid_mapping=batched_grid_mapping, - input_output_aliases=input_output_aliases, - debug=debug, - interpret=interpret, - **compiler_params) + out = pallas_call_p.bind( + *dynamic_grid_args, + *args, + jaxpr=jaxpr, + name=f"batched_{name}", + in_shapes=batched_in_shapes, + out_shapes=batched_out_shapes, + which_linear=which_linear, + grid_mapping=batched_grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + **compiler_params, + ) return out, (0,) * len(out) batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index 5200a498d..11ecb5708 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -1113,6 +1113,30 @@ class PallasCallTest(PallasTPUTest): dynamic_kernel(4), np.full(shape, 8.0, np.float32) ) + def test_vmap_dynamic_grid(self): + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) + + def kernel(x_ref, y_ref): + @pl.when(pl.program_id(0) == 0) + def _init(): + y_ref[...] = x_ref[...] + y_ref[...] += 1. + + @jax.jit + def dynamic_kernel(x, steps): + return pl.pallas_call( + kernel, + grid=(steps * 2,), + out_specs=pl.BlockSpec(lambda i: (0, 0), shape), + out_shape=result_ty, + )(x) + x = jnp.arange(4 * 8 * 128.).reshape((4, *shape)) + np.testing.assert_array_equal( + jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, 4), + x + 8, + ) + class PallasUXTest(PallasTPUTest):