mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
preserve precision config in dot_general transpose
This commit is contained in:
parent
6e55c4e7ba
commit
fb433fb9d2
@ -2088,7 +2088,8 @@ def _dot_general_transpose_lhs(g, y, dimension_numbers, precision,
|
||||
dims = ((ans_y, y_kept), (ans_batch, y_batch))
|
||||
x_contract_sorted_by_y = list(onp.take(x_contract, onp.argsort(y_contract)))
|
||||
out_axes = onp.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
|
||||
return transpose(dot_general(g, y, dims), tuple(out_axes))
|
||||
return transpose(dot_general(g, y, dims, precision=precision),
|
||||
tuple(out_axes))
|
||||
|
||||
def _dot_general_transpose_rhs(g, x, dimension_numbers, precision):
|
||||
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
||||
|
Loading…
x
Reference in New Issue
Block a user