mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Delete unused code in _dot_batch_rule
PiperOrigin-RevId: 725725676
This commit is contained in:
parent
6fc1c61520
commit
7994aa82f8
@ -4194,8 +4194,6 @@ def _dot_batch_rule(
|
||||
|
||||
lhs, rhs = unpack_args(batched_args)
|
||||
lbd, rbd = unpack_dims(batch_dims)
|
||||
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
|
||||
right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
|
||||
new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
|
||||
|
Loading…
x
Reference in New Issue
Block a user