From 1189d61bc086fcfb548e73235a601ec46c3623c5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 7 Dec 2023 15:02:17 -0800 Subject: [PATCH] [Pallas] Fix batching rule for kernels with scratch inputs Scratch inputs do not need a batching dimension. PiperOrigin-RevId: 588921137 --- jax/_src/pallas/pallas_call.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 92b28e3c2..f248063ee 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -311,9 +311,18 @@ def _pallas_call_batching_rule(args, dims, *, all_dims = list(dims) + [0] * len(out_shapes) num_index_operands = grid_mapping.num_index_operands + num_scratch_operands = grid_mapping.num_scratch_operands + + # Only add a batch dimension for the avals that actually have a grid mapping. + # This excludes scalar prefetch inputs (the first in the list) and scratch + # operands (the last in the list). + avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( partial(_batch_block_mapping, grid_mapping.grid), - avals[num_index_operands:], all_dims[num_index_operands:], block_mappings) + avals_to_batch, + all_dims[num_index_operands:], + block_mappings, + ) batched_in_shapes = tuple( jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else