Merge pull request #11797 from jakevdp:matrix-rank-mapped

PiperOrigin-RevId: 466114039
This commit is contained in:
jax authors 2022-08-08 12:20:41 -07:00
commit a2a84c40d5
2 changed files with 5 additions and 5 deletions

View File

@ -111,14 +111,13 @@ def matrix_power(a, n):
@jit
def matrix_rank(M, tol=None):
M, = _promote_dtypes_inexact(jnp.asarray(M))
if M.ndim > 2:
raise TypeError("array should have 2 or fewer dimensions")
if M.ndim < 2:
return jnp.any(M != 0).astype(jnp.int32)
S = svd(M, full_matrices=False, compute_uv=False)
if tol is None:
tol = S.max() * np.max(M.shape).astype(S.dtype) * jnp.finfo(S.dtype).eps
return jnp.sum(S > tol)
tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
tol = jnp.expand_dims(tol, np.ndim(tol))
return jnp.sum(S > tol, axis=-1)
@custom_jvp

View File

@ -924,7 +924,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50)]
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50),
(3, 4, 5), (2, 3, 4, 5)]
for dtype in float_types + complex_types))
def testMatrixRank(self, shape, dtype):
rng = jtu.rand_default(self.rng())