[Pallas/TPU] Allow 1-sized batch dim in vmap of dynamic grid

PiperOrigin-RevId: 603518847
Sharad Vikram 2024-02-01 16:42:46 -08:00 committed by jax authors
parent d76705da94
commit a41385c860
2 changed files with 69 additions and 7 deletions
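
In short: before this change, vmapping a pallas_call with dynamic grid bounds always raised NotImplementedError; now a batch dimension of size 1 on the bounds is squeezed away so each bound stays a scalar. A minimal standalone sketch of the squeeze helper the diff introduces (the imports are added here for illustration and are not part of the diff):

    import jax.numpy as jnp
    from jax.interpreters import batching

    def _maybe_squeeze_out_bdim(x, bdim):
      # Unbatched operands pass through; a 1-sized batch dim is squeezed away.
      if bdim is batching.not_mapped:
        return x
      return jnp.squeeze(x, axis=bdim)

    steps = jnp.array([4])                                # bound with a batch dim of size 1
    assert _maybe_squeeze_out_bdim(steps, 0).shape == ()  # back to a scalar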


@@ -286,14 +286,32 @@ def _pallas_call_batching_rule(args, dims, *,
     interpret: bool,
     which_linear: tuple[bool, ...],
     **compiler_params: Any):
+
+  def _maybe_squeeze_out_bdim(
+      x: jax.Array, bdim: int | batching.NotMapped
+  ) -> jax.Array:
+    if bdim is batching.not_mapped:
+      return x
+    return jnp.squeeze(x, axis=bdim)
+
   dynamic_grid_args, args = split_list(
       args, [grid_mapping.num_dynamic_grid_bounds]
   )
   dynamic_grid_dims, dims = split_list(
       dims, [grid_mapping.num_dynamic_grid_bounds]
   )
-  if any(dim is not batching.not_mapped for dim in dynamic_grid_dims):
-    raise NotImplementedError("Batched dynamic grid bounds unsupported")
+  if all(
+      bdim is batching.not_mapped or arg.shape[bdim] == 1
+      for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims)
+  ):
+    dynamic_grid_args = safe_map(
+        _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims)
+  elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
+    raise NotImplementedError(
+        f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}"
+    )
+  else:
+    pass  # No dynamic grid dimensions
   del dynamic_grid_dims
   if grid_mapping.num_index_operands:
     scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
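
The new branch above handles three cases: if every dynamic grid bound is either unmapped or carries a 1-sized batch dimension (vacuously true when there are no dynamic bounds), the trivial batch dims are squeezed out and batching proceeds; if any bound is genuinely batched, the rule still raises. A small sketch of that predicate on plain arrays, using the hypothetical name check_trivially_batched for illustration:

    import jax.numpy as jnp
    from jax.interpreters import batching

    def check_trivially_batched(bounds, bdims):
      # Mirrors the all(...) predicate in the diff: True iff every bound is
      # unmapped or has a batch dimension of size 1.
      return all(b is batching.not_mapped or x.shape[b] == 1
                 for x, b in zip(bounds, bdims))

    print(check_trivially_batched([jnp.array([4])], [0]))     # True: squeezable
    print(check_trivially_batched([jnp.array([4, 8])], [0]))  # False: would raise
    print(check_trivially_batched([], []))                    # True: no dynamic bounds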
@@ -306,17 +324,15 @@ def _pallas_call_batching_rule(args, dims, *,
         bdim is batching.not_mapped or arg.shape[bdim] == 1
         for arg, bdim in zip(scalar_args, scalar_bdims)
     ):
-      def _squeeze_out_bdim(x: jax.Array, bdim: int | batching.NotMapped):
-        if bdim is batching.not_mapped:
-          return x
-        return jnp.squeeze(x, axis=bdim)
-      scalar_args = safe_map(_squeeze_out_bdim, scalar_args, scalar_bdims)
+      scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
       scalar_bdims = [None] * len(scalar_args)
       args = (*scalar_args, *args)
       dims = (*scalar_bdims, *bdims)
     else:
       # TODO(sharadmv,apaszke): enable batching over prefetched scalar args
       raise NotImplementedError
   if not dims:
     raise NotImplementedError("vmapping pallas_call with no arguments.")
   axis_size, = {x.shape[d] for x, d in zip(args, dims)
                 if d is not batching.not_mapped}
   block_mappings = grid_mapping.block_mappings
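
One detail worth noting at the end of this hunk: the batch sizes of all mapped arguments are collected into a set, and unpacking that set into the single name axis_size doubles as a check that every mapped argument agrees on the batch size. A standalone illustration of the same idiom:

    import jax.numpy as jnp
    from jax.interpreters import batching

    args = (jnp.zeros((3, 8)), jnp.zeros((5, 3)), jnp.zeros((8,)))
    dims = (0, 1, batching.not_mapped)  # where each arg's batch dim sits
    axis_size, = {x.shape[d] for x, d in zip(args, dims)
                  if d is not batching.not_mapped}
    print(axis_size)  # 3; mismatched batch sizes would make the unpack raise ValueError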


@@ -1113,6 +1113,52 @@ class PallasCallTest(PallasTPUTest):
         dynamic_kernel(4), np.full(shape, 8.0, np.float32)
     )

+  def test_vmap_trivial_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps, x):
+      return pl.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+
+    x = jnp.arange(8 * 128.).reshape((1, *shape))
+    np.testing.assert_array_equal(dynamic_kernel(jnp.array([4]), x), x + 8.0)
+
+  def test_vmap_nontrivial_dynamic_grid(self):
+    # Dynamic grid doesn't support vmapping over multiple distinct grid
+    # values at the moment.
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps):
+      return pl.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+
+    with self.assertRaises(NotImplementedError):
+      dynamic_kernel(jnp.array([4, 8]))
+
   def test_vmap_dynamic_grid(self):
     shape = (8, 128)
     result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)