mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[Pallas] Don't actually vmap if we're vmapping over axis size 1
PiperOrigin-RevId: 643209848
This commit is contained in:
parent
a92fa547a0
commit
e12656002f
@ -518,6 +518,26 @@ def _pallas_call_batching_rule(
|
||||
return x
|
||||
return jnp.squeeze(x, axis=bdim)
|
||||
|
||||
axis_size, = {x.shape[d] for x, d in zip(args, dims)
|
||||
if d is not batching.not_mapped}
|
||||
if axis_size == 1:
|
||||
# Why are we even vmapping?
|
||||
args = map(_maybe_squeeze_out_bdim, args, dims)
|
||||
out = pallas_call_p.bind(
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
in_shapes=in_shapes,
|
||||
out_shapes=out_shapes,
|
||||
which_linear=which_linear,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
compiler_params=compiler_params,
|
||||
)
|
||||
return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
|
||||
|
||||
# The first num_dynamic_grid_bounds arguments are size-1 arrays that store
|
||||
# the size of the dynamic bounds.
|
||||
dynamic_grid_args, args = split_list(
|
||||
@ -588,8 +608,6 @@ def _pallas_call_batching_rule(
|
||||
|
||||
if not dims:
|
||||
raise NotImplementedError("vmapping pallas_call with no arguments.")
|
||||
axis_size, = {x.shape[d] for x, d in zip(args, dims)
|
||||
if d is not batching.not_mapped}
|
||||
block_mappings = grid_mapping.block_mappings
|
||||
avals = [v.aval for v in jaxpr.invars]
|
||||
# How should we pick output dimensions? This actually matters because XLA
|
||||
|
@ -423,6 +423,7 @@ class PallasCallDynamicGridTest(PallasTPUTest):
|
||||
return self.pallas_call(
|
||||
kernel,
|
||||
grid=(steps * 2,),
|
||||
in_specs=[pl.BlockSpec(lambda i: (0, 0), shape)],
|
||||
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
|
||||
out_shape=result_ty,
|
||||
)(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user