jnp.moveaxis: fix bug when axes are integer dtype

This commit is contained in:
Jake VanderPlas 2020-09-21 16:32:11 -07:00
parent ada6f30f59
commit 2cf8d49f5b

View File

@ -1256,10 +1256,15 @@ def swapaxes(a, axis1, axis2):
@_wraps(np.moveaxis)
def moveaxis(a, source, destination):
if isinstance(source, int):
source = (source,)
if isinstance(destination, int):
destination = (destination,)
_check_arraylike("moveaxis", a)
try:
source = (operator.index(source),)
except TypeError:
pass
try:
destination = (operator.index(destination),)
except TypeError:
pass
source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
if len(source) != len(destination):