mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas/TPU] Allow 1-sized batch dim in vmap of dynamic grid
PiperOrigin-RevId: 603518847
This commit is contained in:
parent
d76705da94
commit
a41385c860
@ -286,14 +286,32 @@ def _pallas_call_batching_rule(args, dims, *,
|
||||
interpret: bool,
|
||||
which_linear: tuple[bool, ...],
|
||||
**compiler_params: Any):
|
||||
|
||||
def _maybe_squeeze_out_bdim(
|
||||
x: jax.Array, bdim: int | batching.NotMapped
|
||||
) -> jax.Array:
|
||||
if bdim is batching.not_mapped:
|
||||
return x
|
||||
return jnp.squeeze(x, axis=bdim)
|
||||
|
||||
dynamic_grid_args, args = split_list(
|
||||
args, [grid_mapping.num_dynamic_grid_bounds]
|
||||
)
|
||||
dynamic_grid_dims, dims = split_list(
|
||||
dims, [grid_mapping.num_dynamic_grid_bounds]
|
||||
)
|
||||
if any(dim is not batching.not_mapped for dim in dynamic_grid_dims):
|
||||
raise NotImplementedError("Batched dynamic grid bounds unsupported")
|
||||
if all(
|
||||
bdim is batching.not_mapped or arg.shape[bdim] == 1
|
||||
for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims)
|
||||
):
|
||||
dynamic_grid_args = safe_map(
|
||||
_maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims)
|
||||
elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
|
||||
raise NotImplementedError(
|
||||
f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}"
|
||||
)
|
||||
else:
|
||||
pass # No dynamic grid dimensions
|
||||
del dynamic_grid_dims
|
||||
if grid_mapping.num_index_operands:
|
||||
scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
|
||||
@ -306,17 +324,15 @@ def _pallas_call_batching_rule(args, dims, *,
|
||||
bdim is batching.not_mapped or arg.shape[bdim] == 1
|
||||
for arg, bdim in zip(scalar_args, scalar_bdims)
|
||||
):
|
||||
def _squeeze_out_bdim(x: jax.Array, bdim: int | batching.NotMapped):
|
||||
if bdim is batching.not_mapped:
|
||||
return x
|
||||
return jnp.squeeze(x, axis=bdim)
|
||||
scalar_args = safe_map(_squeeze_out_bdim, scalar_args, scalar_bdims)
|
||||
scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
|
||||
scalar_bdims = [None] * len(scalar_args)
|
||||
args = (*scalar_args, *args)
|
||||
dims = (*scalar_bdims, *bdims)
|
||||
else:
|
||||
# TODO(sharadmv,apaszke): enable batching over prefetched scalar args
|
||||
raise NotImplementedError
|
||||
if not dims:
|
||||
raise NotImplementedError("vmapping pallas_call with no arguments.")
|
||||
axis_size, = {x.shape[d] for x, d in zip(args, dims)
|
||||
if d is not batching.not_mapped}
|
||||
block_mappings = grid_mapping.block_mappings
|
||||
|
@ -1113,6 +1113,52 @@ class PallasCallTest(PallasTPUTest):
|
||||
dynamic_kernel(4), np.full(shape, 8.0, np.float32)
|
||||
)
|
||||
|
||||
def test_vmap_trivial_dynamic_grid(self):
  """vmap over a size-1 batch of dynamic grid bounds matches the unbatched result."""
  block_shape = (8, 128)
  out_struct = jax.ShapeDtypeStruct(block_shape, jnp.float32)

  def increment_kernel(x_ref, y_ref):
    # Seed the output with the input on the first grid step only.
    @pl.when(pl.program_id(0) == 0)
    def _init():
      y_ref[...] = x_ref[...]

    # Every grid step adds one to the output.
    y_ref[...] += 1

  def run_once(steps, x):
    call = pl.pallas_call(
        increment_kernel,
        grid=(steps * 2,),
        out_specs=pl.BlockSpec(lambda i: (0, 0), block_shape),
        out_shape=out_struct,
    )
    return call(x)

  dynamic_kernel = jax.jit(jax.vmap(run_once))

  # Leading axis of size 1 is the (trivial) batch dimension.
  x = jnp.arange(8 * 128.).reshape((1, *block_shape))
  actual = dynamic_kernel(jnp.array([4]), x)
  # steps == 4 gives a grid of 8 steps, hence 8 increments.
  np.testing.assert_array_equal(actual, x + 8.0)
|
||||
|
||||
def test_vmap_nontrivial_dynamic_grid(self):
  """vmap over several distinct dynamic grid values raises NotImplementedError.

  Dynamic grid doesn't support vmapping over multiple distinct grid values
  at the moment.
  """
  block_shape = (8, 128)
  out_struct = jax.ShapeDtypeStruct(block_shape, jnp.float32)

  def counting_kernel(y_ref):
    # Zero the output on the first grid step only.
    @pl.when(pl.program_id(0) == 0)
    def _init():
      y_ref[...] = jnp.zeros_like(y_ref)

    # Every grid step adds one to the output.
    y_ref[...] += 1

  def run_once(steps):
    call = pl.pallas_call(
        counting_kernel,
        grid=(steps * 2,),
        out_specs=pl.BlockSpec(lambda i: (0, 0), block_shape),
        out_shape=out_struct,
    )
    return call()

  dynamic_kernel = jax.jit(jax.vmap(run_once))

  # Two different step counts form a batch of size 2 for the grid bound.
  batched_steps = jnp.array([4, 8])
  with self.assertRaises(NotImplementedError):
    dynamic_kernel(batched_steps)
|
||||
|
||||
def test_vmap_dynamic_grid(self):
|
||||
shape = (8, 128)
|
||||
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user