mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:16:05 +00:00
jnp.moveaxis: fix bug when axes are integer dtype
This commit is contained in:
parent
ada6f30f59
commit
2cf8d49f5b
@ -1256,10 +1256,15 @@ def swapaxes(a, axis1, axis2):
|
||||
|
||||
@_wraps(np.moveaxis)
|
||||
def moveaxis(a, source, destination):
|
||||
if isinstance(source, int):
|
||||
source = (source,)
|
||||
if isinstance(destination, int):
|
||||
destination = (destination,)
|
||||
_check_arraylike("moveaxis", a)
|
||||
try:
|
||||
source = (operator.index(source),)
|
||||
except TypeError:
|
||||
pass
|
||||
try:
|
||||
destination = (operator.index(destination),)
|
||||
except TypeError:
|
||||
pass
|
||||
source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
|
||||
destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
|
||||
if len(source) != len(destination):
|
||||
|
Loading…
x
Reference in New Issue
Block a user