mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #10140 from jakevdp:jnp-diagonal
PiperOrigin-RevId: 439596331
This commit is contained in:
commit
fef367019b
@ -2478,16 +2478,11 @@ def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
|
||||
perm = perm + [axis1, axis2]
|
||||
a = lax.transpose(a, perm)
|
||||
|
||||
# Mask out the diagonal and reduce over one of the axes
|
||||
a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
|
||||
a, zeros_like(a))
|
||||
reduce_axis = -2 if offset < 0 else -1
|
||||
d = sum(a, axis=reduce_axis, dtype=_dtype(a))
|
||||
|
||||
# Slice out the correct diagonal size.
|
||||
diag_size = _max(0, _min(a_shape[axis1] + _min(offset, 0),
|
||||
a_shape[axis2] - _max(offset, 0)))
|
||||
return lax.slice_in_dim(d, 0, diag_size, axis=-1)
|
||||
i = arange(diag_size)
|
||||
j = arange(_abs(offset), _abs(offset) + diag_size)
|
||||
return a[..., i, j] if offset >= 0 else a[..., j, i]
|
||||
|
||||
|
||||
@_wraps(np.diag, lax_description=_ARRAY_VIEW_DOC)
|
||||
|
Loading…
x
Reference in New Issue
Block a user