mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[Pallas] Add support for trivial vmap of scalar prefetch
PiperOrigin-RevId: 600898742
This commit is contained in:
parent
7e80aac3ec
commit
037bc5edbc
@ -283,8 +283,25 @@ def _pallas_call_batching_rule(args, dims, *,
|
||||
which_linear: tuple[bool, ...],
|
||||
**compiler_params: Any):
|
||||
if grid_mapping.num_index_operands:
|
||||
scalar_batch_dims = dims[:grid_mapping.num_index_operands]
|
||||
if any(bdim is not batching.not_mapped for bdim in scalar_batch_dims):
|
||||
scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
|
||||
scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
|
||||
# Ordinarily, adding support for scalar prefetch in vmap would involve
|
||||
# modifying the block specs in a nontrivial way. However, if we are only
|
||||
# vmapping over 1-sized dimensions, we can just get rid of the dimensions
|
||||
# and pretend we were never vmapping over them at all.
|
||||
if all(
|
||||
bdim is batching.not_mapped or arg.shape[bdim] == 1
|
||||
for arg, bdim in zip(scalar_args, scalar_bdims)
|
||||
):
|
||||
def _squeeze_out_bdim(x: jax.Array, bdim: int | batching.NotMapped):
|
||||
if bdim is batching.not_mapped:
|
||||
return x
|
||||
return jnp.squeeze(x, axis=bdim)
|
||||
scalar_args = safe_map(_squeeze_out_bdim, scalar_args, scalar_bdims)
|
||||
scalar_bdims = [None] * len(scalar_args)
|
||||
args = (*scalar_args, *args)
|
||||
dims = (*scalar_bdims, *bdims)
|
||||
else:
|
||||
# TODO(sharadmv,apaszke): enable batching over prefetched scalar args
|
||||
raise NotImplementedError
|
||||
axis_size, = {x.shape[d] for x, d in zip(args, dims)
|
||||
|
@ -214,6 +214,71 @@ class PallasCallScalarPrefetchTest(jtu.JaxTestCase):
|
||||
|
||||
np.testing.assert_allclose(out, expected)
|
||||
|
||||
def test_vmap_scalar_prefetch_1sized(self):
|
||||
def body(_, x_ref, o_ref):
|
||||
o_ref[...] = x_ref[...]
|
||||
|
||||
s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
|
||||
x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
|
||||
|
||||
def _x_transform(i, s_ref):
|
||||
s = pl.load(s_ref, (i,))
|
||||
return (s, 0)
|
||||
|
||||
s = s[None]
|
||||
x = x[None]
|
||||
|
||||
out = jax.vmap(pl.pallas_call(
|
||||
body,
|
||||
out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=1,
|
||||
in_specs=[
|
||||
pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
|
||||
],
|
||||
out_specs=pl.BlockSpec(lambda i, _: (i, 0),
|
||||
(x.shape[1] // 8, x.shape[2])),
|
||||
grid=8,
|
||||
),
|
||||
interpret=self.interpret,
|
||||
))(s, x)
|
||||
np.testing.assert_allclose(
|
||||
out, x.reshape((1, 8, 8, -1))[:, s].reshape(x.shape)
|
||||
)
|
||||
|
||||
def test_nontrivial_vmap_scalar_prefetch(self):
|
||||
def body(_, x_ref, o_ref):
|
||||
o_ref[...] = x_ref[...]
|
||||
|
||||
s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
|
||||
x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
|
||||
|
||||
def _x_transform(i, s_ref):
|
||||
s = pl.load(s_ref, (i,))
|
||||
return (s, 0)
|
||||
|
||||
s = jnp.tile(s[None], [2, 1])
|
||||
x = jnp.tile(x[None], [2, 1, 1])
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
jax.vmap(
|
||||
pl.pallas_call(
|
||||
body,
|
||||
out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=1,
|
||||
in_specs=[
|
||||
pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
|
||||
],
|
||||
out_specs=pl.BlockSpec(
|
||||
lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2])
|
||||
),
|
||||
grid=8,
|
||||
),
|
||||
interpret=self.interpret,
|
||||
)
|
||||
)(s, x)
|
||||
|
||||
|
||||
class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
|
||||
interpret: bool = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user