Delete unused code in _dot_batch_rule

PiperOrigin-RevId: 725725676
This commit is contained in:
Gunhyun Park 2025-02-11 12:15:24 -08:00 committed by jax authors
parent 6fc1c61520
commit 7994aa82f8

View File

@ -4194,8 +4194,6 @@ def _dot_batch_rule(
lhs, rhs = unpack_args(batched_args)
lbd, rbd = unpack_dims(batch_dims)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(