mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Simplify jnp.trace
implementation
This commit is contained in:
parent
084adc7b79
commit
60c828a78a
@ -2391,9 +2391,6 @@ def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
|
||||
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
|
||||
lax_internal._check_user_dtype_supported(dtype, "trace")
|
||||
|
||||
axis1 = _canonicalize_axis(axis1, ndim(a))
|
||||
axis2 = _canonicalize_axis(axis2, ndim(a))
|
||||
|
||||
a_shape = shape(a)
|
||||
if dtype is None:
|
||||
dtype = _dtype(a)
|
||||
@ -2402,10 +2399,7 @@ def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
|
||||
if iinfo(dtype).bits < iinfo(default_int).bits:
|
||||
dtype = default_int
|
||||
|
||||
# Move the axis? dimensions to the end.
|
||||
perm = [i for i in range(len(a_shape)) if i != axis1 and i != axis2]
|
||||
perm = perm + [axis1, axis2]
|
||||
a = lax.transpose(a, perm)
|
||||
a = moveaxis(a, (axis1, axis2), (-2, -1))
|
||||
|
||||
# Mask out the diagonal and reduce.
|
||||
a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
|
||||
|
Loading…
x
Reference in New Issue
Block a user