From 037bc5edbcd17c24b5d2d798490f1ea6c97fe76e Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Tue, 23 Jan 2024 13:56:29 -0800
Subject: [PATCH] [Pallas] Add support for trivial vmap of scalar prefetch

PiperOrigin-RevId: 600898742
---
 jax/_src/pallas/pallas_call.py       | 21 ++++++++-
 tests/pallas/pallas_call_tpu_test.py | 65 ++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 2e3858652..6aa194c0a 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -283,8 +283,25 @@ def _pallas_call_batching_rule(args, dims, *,
                                which_linear: tuple[bool, ...],
                                **compiler_params: Any):
   if grid_mapping.num_index_operands:
-    scalar_batch_dims = dims[:grid_mapping.num_index_operands]
-    if any(bdim is not batching.not_mapped for bdim in scalar_batch_dims):
+    scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
+    scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
+    # Ordinarily, adding support for scalar prefetch in vmap would involve
+    # modifying the block specs in a nontrivial way. However, if we are only
+    # vmapping over 1-sized dimensions, we can just get rid of the dimensions
+    # and pretend we were never vmapping over them at all.
+    if all(
+        bdim is batching.not_mapped or arg.shape[bdim] == 1
+        for arg, bdim in zip(scalar_args, scalar_bdims)
+    ):
+      def _squeeze_out_bdim(x: jax.Array, bdim: int | batching.NotMapped):
+        if bdim is batching.not_mapped:
+          return x
+        return jnp.squeeze(x, axis=bdim)
+      scalar_args = safe_map(_squeeze_out_bdim, scalar_args, scalar_bdims)
+      scalar_bdims = [None] * len(scalar_args)
+      args = (*scalar_args, *args)
+      dims = (*scalar_bdims, *bdims)
+    else:
       # TODO(sharadmv,apaszke): enable batching over prefetched scalar args
       raise NotImplementedError
   axis_size, = {x.shape[d] for x, d in zip(args, dims)
diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index 55e1faf4c..54a3ef4b8 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -214,6 +214,71 @@ class PallasCallScalarPrefetchTest(jtu.JaxTestCase):
 
     np.testing.assert_allclose(out, expected)
 
+  def test_vmap_scalar_prefetch_1sized(self):
+    def body(_, x_ref, o_ref):
+      o_ref[...] = x_ref[...]
+
+    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
+    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
+
+    def _x_transform(i, s_ref):
+      s = pl.load(s_ref, (i,))
+      return (s, 0)
+
+    s = s[None]
+    x = x[None]
+
+    out = jax.vmap(pl.pallas_call(
+        body,
+        out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=1,
+            in_specs=[
+                pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
+            ],
+            out_specs=pl.BlockSpec(lambda i, _: (i, 0),
+                                   (x.shape[1] // 8, x.shape[2])),
+            grid=8,
+        ),
+        interpret=self.interpret,
+    ))(s, x)
+    np.testing.assert_allclose(
+        out, x.reshape((1, 8, 8, -1))[:, s].reshape(x.shape)
+    )
+
+  def test_nontrivial_vmap_scalar_prefetch(self):
+    def body(_, x_ref, o_ref):
+      o_ref[...] = x_ref[...]
+
+    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
+    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
+
+    def _x_transform(i, s_ref):
+      s = pl.load(s_ref, (i,))
+      return (s, 0)
+
+    s = jnp.tile(s[None], [2, 1])
+    x = jnp.tile(x[None], [2, 1, 1])
+
+    with self.assertRaises(NotImplementedError):
+      jax.vmap(
+          pl.pallas_call(
+              body,
+              out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
+              grid_spec=pltpu.PrefetchScalarGridSpec(
+                  num_scalar_prefetch=1,
+                  in_specs=[
+                      pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
+                  ],
+                  out_specs=pl.BlockSpec(
+                      lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2])
+                  ),
+                  grid=8,
+              ),
+              interpret=self.interpret,
+          )
+      )(s, x)
+
 
 class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
   interpret: bool = True
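
Usage note: with this change, jax.vmap over a pallas_call that takes scalar-prefetch
arguments works as long as every mapped scalar-prefetch dimension has size 1; the
batching rule simply squeezes those dimensions out. Below is a minimal sketch adapted
from the new test_vmap_scalar_prefetch_1sized test, not a definitive recipe; it uses
the BlockSpec calling convention of this revision (index_map first, block shape second)
and interpret=True so it can run without a TPU.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def body(_, x_ref, o_ref):
  # Copy the currently selected input block to the output block.
  o_ref[...] = x_ref[...]

def _x_transform(i, s_ref):
  # The prefetched scalars choose which block of x to read at grid step i.
  return (pl.load(s_ref, (i,)), 0)

# Scalar-prefetch operand and input, each with a leading batch dimension of size 1.
s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)[None]
x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((1, 64, 128))

out = jax.vmap(pl.pallas_call(
    body,
    out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
    grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,
        in_specs=[pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2]))],
        out_specs=pl.BlockSpec(lambda i, _: (i, 0),
                               (x.shape[1] // 8, x.shape[2])),
        grid=8,
    ),
    interpret=True,  # sketch assumption: run in interpreter mode, off-TPU
))(s, x)

# Output block i is input block s[0, i].
np.testing.assert_allclose(
    out, x.reshape((1, 8, 8, -1))[:, s[0]].reshape(x.shape))

Mapping the same call over a scalar-prefetch dimension larger than 1 (e.g. after
jnp.tile) still raises NotImplementedError, as exercised by
test_nontrivial_vmap_scalar_prefetch above.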