[Pallas] Fix batching rule for kernels with scratch inputs

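Scratch inputs do not need a batching dimension: like scalar prefetch
inputs, they have no block mapping. The batching rule previously passed
every aval after the index operands to `_batch_block_mapping`, which
included the trailing scratch operands; it now slices them off first.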

PiperOrigin-RevId: 588921137
This commit is contained in:
jax authors 2023-12-07 15:02:17 -08:00
parent e423347dda
commit 1189d61bc0

View File

@@ -311,9 +311,18 @@ def _pallas_call_batching_rule(args, dims, *,
 all_dims = list(dims) + [0] * len(out_shapes)
 num_index_operands = grid_mapping.num_index_operands
+num_scratch_operands = grid_mapping.num_scratch_operands
+# Only add a batch dimension for the avals that actually have a grid mapping.
+# This excludes scalar prefetch inputs (the first in the list) and scratch
+# operands (the last in the list).
+avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
 batched_block_mappings = map(
 partial(_batch_block_mapping, grid_mapping.grid),
-avals[num_index_operands:], all_dims[num_index_operands:], block_mappings)
+avals_to_batch,
+all_dims[num_index_operands:],
+block_mappings,
+)
 batched_in_shapes = tuple(
 jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else