mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas] Fix batching rule for kernels with scratch inputs
Scratch inputs do not need a batching dimension. PiperOrigin-RevId: 588921137
This commit is contained in:
parent
e423347dda
commit
1189d61bc0
@ -311,9 +311,18 @@ def _pallas_call_batching_rule(args, dims, *,
|
||||
all_dims = list(dims) + [0] * len(out_shapes)
|
||||
|
||||
num_index_operands = grid_mapping.num_index_operands
|
||||
num_scratch_operands = grid_mapping.num_scratch_operands
|
||||
|
||||
# Only add a batch dimension for the avals that actually have a grid mapping.
|
||||
# This excludes scalar prefetch inputs (the first in the list) and scratch
|
||||
# operands (the last in the list).
|
||||
avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
|
||||
batched_block_mappings = map(
|
||||
partial(_batch_block_mapping, grid_mapping.grid),
|
||||
avals[num_index_operands:], all_dims[num_index_operands:], block_mappings)
|
||||
avals_to_batch,
|
||||
all_dims[num_index_operands:],
|
||||
block_mappings,
|
||||
)
|
||||
|
||||
batched_in_shapes = tuple(
|
||||
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
|
||||
|
Loading…
x
Reference in New Issue
Block a user