mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
transpose shouldn't transpose with identity perm
This commit is contained in:
parent
538850e271
commit
77d6fb4c01
@ -244,7 +244,11 @@ def _index_untake(axes, src, dst, *idxs):
|
||||
return dst
|
||||
|
||||
def transpose(operand, permutation):
|
||||
return transpose_p.bind(operand, permutation=tuple(permutation))
|
||||
permutation = tuple(permutation)
|
||||
if permutation == tuple(range(len(permutation))):
|
||||
return operand
|
||||
else:
|
||||
return transpose_p.bind(operand, permutation=permutation)
|
||||
|
||||
def reduce(operand, init_value, computation, dimensions):
|
||||
monoid_reducer = _get_monoid_reducer(computation, init_value)
|
||||
|
Loading…
x
Reference in New Issue
Block a user