mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11797 from jakevdp:matrix-rank-mapped
PiperOrigin-RevId: 466114039
This commit is contained in:
commit
a2a84c40d5
@ -111,14 +111,13 @@ def matrix_power(a, n):
|
||||
@jit
|
||||
def matrix_rank(M, tol=None):
|
||||
M, = _promote_dtypes_inexact(jnp.asarray(M))
|
||||
if M.ndim > 2:
|
||||
raise TypeError("array should have 2 or fewer dimensions")
|
||||
if M.ndim < 2:
|
||||
return jnp.any(M != 0).astype(jnp.int32)
|
||||
S = svd(M, full_matrices=False, compute_uv=False)
|
||||
if tol is None:
|
||||
tol = S.max() * np.max(M.shape).astype(S.dtype) * jnp.finfo(S.dtype).eps
|
||||
return jnp.sum(S > tol)
|
||||
tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
|
||||
tol = jnp.expand_dims(tol, np.ndim(tol))
|
||||
return jnp.sum(S > tol, axis=-1)
|
||||
|
||||
|
||||
@custom_jvp
|
||||
|
@ -924,7 +924,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_shape={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50)]
|
||||
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50),
|
||||
(3, 4, 5), (2, 3, 4, 5)]
|
||||
for dtype in float_types + complex_types))
|
||||
def testMatrixRank(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user