Mirror of https://github.com/ROCm/jax.git
[Pallas/TPU] Add vmap support for dynamic grid
PiperOrigin-RevId: 603502393
Commit: d76705da94
Parent: b6cda05218
@@ -286,8 +286,15 @@ def _pallas_call_batching_rule(args, dims, *,
                                interpret: bool,
                                which_linear: tuple[bool, ...],
                                **compiler_params: Any):
-  if grid_mapping.num_dynamic_grid_bounds:
-    raise NotImplementedError("interpret with dynamic grid bounds unsupported")
+  dynamic_grid_args, args = split_list(
+      args, [grid_mapping.num_dynamic_grid_bounds]
+  )
+  dynamic_grid_dims, dims = split_list(
+      dims, [grid_mapping.num_dynamic_grid_bounds]
+  )
+  if any(dim is not batching.not_mapped for dim in dynamic_grid_dims):
+    raise NotImplementedError("Batched dynamic grid bounds unsupported")
+  del dynamic_grid_dims
   if grid_mapping.num_index_operands:
     scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
     scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
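
A minimal sketch of the two helpers used in the hunk above, outside any Pallas machinery. It assumes split_list is imported from the private module jax._src.util and that batching.not_mapped is vmap's sentinel for "this operand has no batch dimension"; the "steps"/"x" operand names are placeholders, not taken from the commit.

# Sketch only: illustrates the prefix split and the not_mapped check above.
from jax._src.util import split_list      # private helper; location is an assumption
from jax.interpreters import batching

args = ["steps", "x"]            # stand-in operands: 1 dynamic grid bound + 1 array
dims = [batching.not_mapped, 0]  # e.g. vmap(..., in_axes=(None, 0))

dynamic_grid_args, args = split_list(args, [1])   # peel the grid-bound prefix off the operands
dynamic_grid_dims, dims = split_list(dims, [1])   # and off the batch-dimension list
assert dynamic_grid_args == ["steps"] and args == ["x"]
# The rule only proceeds when no dynamic grid bound carries a batch dimension:
assert not any(d is not batching.not_mapped for d in dynamic_grid_dims)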
@@ -363,15 +370,20 @@ def _pallas_call_batching_rule(args, dims, *,
       grid=(axis_size, *grid_mapping.grid),
       block_mappings=tuple(batched_block_mappings),
       mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
-  out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}",
-                           in_shapes=batched_in_shapes,
-                           out_shapes=batched_out_shapes,
-                           which_linear=which_linear,
-                           grid_mapping=batched_grid_mapping,
-                           input_output_aliases=input_output_aliases,
-                           debug=debug,
-                           interpret=interpret,
-                           **compiler_params)
+  out = pallas_call_p.bind(
+      *dynamic_grid_args,
+      *args,
+      jaxpr=jaxpr,
+      name=f"batched_{name}",
+      in_shapes=batched_in_shapes,
+      out_shapes=batched_out_shapes,
+      which_linear=which_linear,
+      grid_mapping=batched_grid_mapping,
+      input_output_aliases=input_output_aliases,
+      debug=debug,
+      interpret=interpret,
+      **compiler_params,
+  )
   return out, (0,) * len(out)
 batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
 
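
A small sketch of the grid metadata arithmetic in the hunk above, using stand-in values (axis_size, grid, and mapped_dims here are illustrative, not from the commit): the vmapped batch size becomes a new leading grid dimension and existing mapped dimensions shift by one, while the dynamic grid bounds are forwarded unbatched at the front of the bind.

# Illustrative values only.
axis_size = 4                 # vmapped batch size
grid = (8,)                   # original grid (its size may be a dynamic bound)
mapped_dims = ()              # original mapped dimensions

batched_grid = (axis_size, *grid)                               # -> (4, 8)
batched_mapped_dims = (0,) + tuple(a + 1 for a in mapped_dims)  # -> (0,)
assert batched_grid == (4, 8)
assert batched_mapped_dims == (0,)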
@@ -1113,6 +1113,30 @@ class PallasCallTest(PallasTPUTest):
         dynamic_kernel(4), np.full(shape, 8.0, np.float32)
     )
 
+  def test_vmap_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += 1.
+
+    @jax.jit
+    def dynamic_kernel(x, steps):
+      return pl.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(4 * 8 * 128.).reshape((4, *shape))
+    np.testing.assert_array_equal(
+        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, 4),
+        x + 8,
+    )
+
 
 class PallasUXTest(PallasTPUTest):
 
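
A standalone sketch of the usage pattern the new test exercises, assuming a TPU backend and this era's pl.BlockSpec(index_map, block_shape) signature: vmap maps over the array input while the dynamic grid bound stays unbatched (in_axes=(0, None)); per the batching rule above, mapping over the bound itself raises NotImplementedError. The code simply restates the test outside the test class.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

shape = (8, 128)

def kernel(x_ref, y_ref):
  # On the first grid step, copy the input; every step adds 1.
  @pl.when(pl.program_id(0) == 0)
  def _init():
    y_ref[...] = x_ref[...]
  y_ref[...] += 1.

@jax.jit
def dynamic_kernel(x, steps):
  # `steps` is a traced scalar, so the grid size is dynamic.
  return pl.pallas_call(
      kernel,
      grid=(steps * 2,),
      out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
      out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
  )(x)

x = jnp.arange(4 * 8 * 128.).reshape((4, *shape))
# Map over the batch of inputs; keep the grid bound unbatched (in_axes=None).
out = jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, 4)
np.testing.assert_array_equal(out, x + 8)   # 2 * 4 = 8 grid steps, +1 each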