diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 92b28e3c2..f248063ee 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -311,9 +311,18 @@ def _pallas_call_batching_rule(args, dims, *, all_dims = list(dims) + [0] * len(out_shapes) num_index_operands = grid_mapping.num_index_operands + num_scratch_operands = grid_mapping.num_scratch_operands + + # Only add a batch dimension for the avals that actually have a grid mapping. + # This excludes scalar prefetch inputs (the first in the list) and scratch + # operands (the last in the list). + avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( partial(_batch_block_mapping, grid_mapping.grid), - avals[num_index_operands:], all_dims[num_index_operands:], block_mappings) + avals_to_batch, + all_dims[num_index_operands:], + block_mappings, + ) batched_in_shapes = tuple( jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else