[Pallas] Don't actually vmap if we're vmapping over axis size 1

PiperOrigin-RevId: 643209848
This commit is contained in:
Sharad Vikram 2024-06-13 20:24:07 -07:00 committed by jax authors
parent a92fa547a0
commit e12656002f
2 changed files with 21 additions and 2 deletions

View File

@ -518,6 +518,26 @@ def _pallas_call_batching_rule(
return x
return jnp.squeeze(x, axis=bdim)
axis_size, = {x.shape[d] for x, d in zip(args, dims)
if d is not batching.not_mapped}
if axis_size == 1:
# Why are we even vmapping?
args = map(_maybe_squeeze_out_bdim, args, dims)
out = pallas_call_p.bind(
*args,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
compiler_params=compiler_params,
)
return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
# The first num_dynamic_grid_bounds arguments are size-1 arrays that store
# the size of the dynamic bounds.
dynamic_grid_args, args = split_list(
@ -588,8 +608,6 @@ def _pallas_call_batching_rule(
if not dims:
raise NotImplementedError("vmapping pallas_call with no arguments.")
axis_size, = {x.shape[d] for x, d in zip(args, dims)
if d is not batching.not_mapped}
block_mappings = grid_mapping.block_mappings
avals = [v.aval for v in jaxpr.invars]
# How should we pick output dimensions? This actually matters because XLA

View File

@ -423,6 +423,7 @@ class PallasCallDynamicGridTest(PallasTPUTest):
return self.pallas_call(
kernel,
grid=(steps * 2,),
in_specs=[pl.BlockSpec(lambda i: (0, 0), shape)],
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
out_shape=result_ty,
)(x)