preserve precision config in dot_general transpose

This commit is contained in:
James Bradbury 2019-10-09 16:25:37 -07:00
parent 6e55c4e7ba
commit fb433fb9d2

View File

@ -2088,7 +2088,8 @@ def _dot_general_transpose_lhs(g, y, dimension_numbers, precision,
dims = ((ans_y, y_kept), (ans_batch, y_batch))
x_contract_sorted_by_y = list(onp.take(x_contract, onp.argsort(y_contract)))
out_axes = onp.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
return transpose(dot_general(g, y, dims), tuple(out_axes))
return transpose(dot_general(g, y, dims, precision=precision),
tuple(out_axes))
def _dot_general_transpose_rhs(g, x, dimension_numbers, precision):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers