[Pallas/TPU] Add vmap support for dynamic grid

PiperOrigin-RevId: 603502393
This commit is contained in:
Sharad Vikram 2024-02-01 15:38:46 -08:00 committed by jax authors
parent b6cda05218
commit d76705da94
2 changed files with 47 additions and 11 deletions

View File

@ -286,8 +286,15 @@ def _pallas_call_batching_rule(args, dims, *,
interpret: bool,
which_linear: tuple[bool, ...],
**compiler_params: Any):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
dynamic_grid_args, args = split_list(
args, [grid_mapping.num_dynamic_grid_bounds]
)
dynamic_grid_dims, dims = split_list(
dims, [grid_mapping.num_dynamic_grid_bounds]
)
if any(dim is not batching.not_mapped for dim in dynamic_grid_dims):
raise NotImplementedError("Batched dynamic grid bounds unsupported")
del dynamic_grid_dims
if grid_mapping.num_index_operands:
scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
@ -363,15 +370,20 @@ def _pallas_call_batching_rule(args, dims, *,
grid=(axis_size, *grid_mapping.grid),
block_mappings=tuple(batched_block_mappings),
mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}",
in_shapes=batched_in_shapes,
out_shapes=batched_out_shapes,
which_linear=which_linear,
grid_mapping=batched_grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
**compiler_params)
out = pallas_call_p.bind(
*dynamic_grid_args,
*args,
jaxpr=jaxpr,
name=f"batched_{name}",
in_shapes=batched_in_shapes,
out_shapes=batched_out_shapes,
which_linear=which_linear,
grid_mapping=batched_grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
**compiler_params,
)
return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule

View File

@ -1113,6 +1113,30 @@ class PallasCallTest(PallasTPUTest):
dynamic_kernel(4), np.full(shape, 8.0, np.float32)
)
def test_vmap_dynamic_grid(self):
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
def kernel(x_ref, y_ref):
@pl.when(pl.program_id(0) == 0)
def _init():
y_ref[...] = x_ref[...]
y_ref[...] += 1.
@jax.jit
def dynamic_kernel(x, steps):
return pl.pallas_call(
kernel,
grid=(steps * 2,),
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
out_shape=result_ty,
)(x)
x = jnp.arange(4 * 8 * 128.).reshape((4, *shape))
np.testing.assert_array_equal(
jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, 4),
x + 8,
)
class PallasUXTest(PallasTPUTest):