[Pallas] Add support for trivial vmap of scalar prefetch

PiperOrigin-RevId: 600898742
This commit is contained in:
Sharad Vikram 2024-01-23 13:56:29 -08:00 committed by jax authors
parent 7e80aac3ec
commit 037bc5edbc
2 changed files with 84 additions and 2 deletions

View File

@ -283,8 +283,25 @@ def _pallas_call_batching_rule(args, dims, *,
which_linear: tuple[bool, ...],
**compiler_params: Any):
if grid_mapping.num_index_operands:
scalar_batch_dims = dims[:grid_mapping.num_index_operands]
if any(bdim is not batching.not_mapped for bdim in scalar_batch_dims):
scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
# Ordinarily, adding support for scalar prefetch in vmap would involve
# modifying the block specs in a nontrivial way. However, if we are only
# vmapping over 1-sized dimensions, we can just get rid of the dimensions
# and pretend we were never vmapping over them at all.
if all(
bdim is batching.not_mapped or arg.shape[bdim] == 1
for arg, bdim in zip(scalar_args, scalar_bdims)
):
def _squeeze_out_bdim(x: jax.Array, bdim: int | batching.NotMapped):
if bdim is batching.not_mapped:
return x
return jnp.squeeze(x, axis=bdim)
scalar_args = safe_map(_squeeze_out_bdim, scalar_args, scalar_bdims)
scalar_bdims = [None] * len(scalar_args)
args = (*scalar_args, *args)
dims = (*scalar_bdims, *bdims)
else:
# TODO(sharadmv,apaszke): enable batching over prefetched scalar args
raise NotImplementedError
axis_size, = {x.shape[d] for x, d in zip(args, dims)

View File

@ -214,6 +214,71 @@ class PallasCallScalarPrefetchTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, expected)
def test_vmap_scalar_prefetch_1sized(self):
def body(_, x_ref, o_ref):
o_ref[...] = x_ref[...]
s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
def _x_transform(i, s_ref):
s = pl.load(s_ref, (i,))
return (s, 0)
s = s[None]
x = x[None]
out = jax.vmap(pl.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
in_specs=[
pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
],
out_specs=pl.BlockSpec(lambda i, _: (i, 0),
(x.shape[1] // 8, x.shape[2])),
grid=8,
),
interpret=self.interpret,
))(s, x)
np.testing.assert_allclose(
out, x.reshape((1, 8, 8, -1))[:, s].reshape(x.shape)
)
def test_nontrivial_vmap_scalar_prefetch(self):
def body(_, x_ref, o_ref):
o_ref[...] = x_ref[...]
s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
def _x_transform(i, s_ref):
s = pl.load(s_ref, (i,))
return (s, 0)
s = jnp.tile(s[None], [2, 1])
x = jnp.tile(x[None], [2, 1, 1])
with self.assertRaises(NotImplementedError):
jax.vmap(
pl.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
in_specs=[
pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
],
out_specs=pl.BlockSpec(
lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2])
),
grid=8,
),
interpret=self.interpret,
)
)(s, x)
class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
interpret: bool = True