Implement np.linalg.matrix_rank (#2008)

* Implement np.linalg.matrix_rank

* Test np.linalg.matrix_rank

* Use helper numpy testing function

* Fix issue with 1D matrix rank procedure

* Add new tests for 1D matrices and jit

* Do not check dtypes to circumvent int32 vs int64

* Include documentation for matrix_rank

* Fix ordering

* Use np.sum
This commit is contained in:
Ziyad Edher 2020-01-26 14:29:33 -05:00 committed by Stephan Hoyer
parent 632326ac5c
commit 0fca476c54
3 changed files with 33 additions and 0 deletions

View File

@ -285,6 +285,7 @@ jax.numpy.linalg
eigh
inv
matrix_power
matrix_rank
norm
qr
slogdet

View File

@ -100,6 +100,19 @@ def matrix_power(a, n):
return result
@_wraps(onp.linalg.matrix_rank)
def matrix_rank(M, tol=None):
M = _promote_arg_dtypes(np.asarray(M))
if M.ndim > 2:
raise TypeError("array should have 2 or fewer dimensions")
if M.ndim < 2:
return np.any(M != 0).astype(np.int32)
S = svd(M, full_matrices=False, compute_uv=False)
if tol is None:
tol = S.max() * np.max(M.shape) * np.finfo(S.dtype).eps
return np.sum(S > tol)
# TODO(pfau): make this work for complex types
def _jvp_slogdet(g, ans, x):
jvp_sign = np.zeros(x.shape[:-2])

View File

@ -623,6 +623,25 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(partial(np.linalg.matrix_power, n=n), args_maker,
check_dtypes=True, rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50)]
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("gpu", "tpu") # TODO(b/145608614): SVD crashes on GPU.
def testMatrixRank(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
self._CheckAgainstNumpy(onp.linalg.matrix_rank, np.linalg.matrix_rank,
args_maker, check_dtypes=False, tol=1e-3)
self._CompileAndCheck(np.linalg.matrix_rank, args_maker,
check_dtypes=False, rtol=1e-3)
# Regression test for incorrect type for eigenvalues of a complex matrix.
@jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU.
def testIssue669(self):