From e12656002f424175b2903727a7c6529077d0e217 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 13 Jun 2024 20:24:07 -0700 Subject: [PATCH] [Pallas] Don't actually vmap if we're vmapping over axis size 1 PiperOrigin-RevId: 643209848 --- jax/_src/pallas/pallas_call.py | 22 ++++++++++++++++++++-- tests/pallas/pallas_call_tpu_test.py | 1 + 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index f601820f7..1b5d1d01e 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -518,6 +518,26 @@ def _pallas_call_batching_rule( return x return jnp.squeeze(x, axis=bdim) + axis_size, = {x.shape[d] for x, d in zip(args, dims) + if d is not batching.not_mapped} + if axis_size == 1: + # Why are we even vmapping? + args = map(_maybe_squeeze_out_bdim, args, dims) + out = pallas_call_p.bind( + *args, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + which_linear=which_linear, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) + # The first num_dynamic_grid_bounds arguments are size-1 arrays that store # the size of the dynamic bounds. dynamic_grid_args, args = split_list( @@ -588,8 +608,6 @@ def _pallas_call_batching_rule( if not dims: raise NotImplementedError("vmapping pallas_call with no arguments.") - axis_size, = {x.shape[d] for x, d in zip(args, dims) - if d is not batching.not_mapped} block_mappings = grid_mapping.block_mappings avals = [v.aval for v in jaxpr.invars] # How should we pick output dimensions? This actually matters because XLA diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index de201e723..b0d651226 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -423,6 +423,7 @@ class PallasCallDynamicGridTest(PallasTPUTest): return self.pallas_call( kernel, grid=(steps * 2,), + in_specs=[pl.BlockSpec(lambda i: (0, 0), shape)], out_specs=pl.BlockSpec(lambda i: (0, 0), shape), out_shape=result_ty, )(x)