Simplify jnp.trace implementation

This commit is contained in:
Lukas Geiger 2022-04-08 18:10:01 +01:00
parent 084adc7b79
commit 60c828a78a

View File

@ -2391,9 +2391,6 @@ def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
lax_internal._check_user_dtype_supported(dtype, "trace")
axis1 = _canonicalize_axis(axis1, ndim(a))
axis2 = _canonicalize_axis(axis2, ndim(a))
a_shape = shape(a)
if dtype is None:
dtype = _dtype(a)
@ -2402,10 +2399,7 @@ def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
if iinfo(dtype).bits < iinfo(default_int).bits:
dtype = default_int
# Move the axis? dimensions to the end.
perm = [i for i in range(len(a_shape)) if i != axis1 and i != axis2]
perm = perm + [axis1, axis2]
a = lax.transpose(a, perm)
a = moveaxis(a, (axis1, axis2), (-2, -1))
# Mask out the diagonal and reduce.
a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),