mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Implement np.linalg.matrix_rank (#2008)
* Implement np.linalg.matrix_rank
* Test np.linalg.matrix_rank
* Use helper numpy testing function
* Fix issue with 1D matrix rank procedure
* Add new tests for 1D matrices and jit
* Do not check dtypes to circumvent int32 vs int64
* Include documentation for matrix_rank
* Fix ordering
* Use np.sum
This commit is contained in:
parent
632326ac5c
commit
0fca476c54
@ -285,6 +285,7 @@ jax.numpy.linalg
|
||||
eigh
|
||||
inv
|
||||
matrix_power
|
||||
matrix_rank
|
||||
norm
|
||||
qr
|
||||
slogdet
|
||||
|
@ -100,6 +100,19 @@ def matrix_power(a, n):
|
||||
return result
|
||||
|
||||
|
||||
@_wraps(onp.linalg.matrix_rank)
def matrix_rank(M, tol=None):
  """Return the matrix rank of ``M`` via its singular values.

  Mirrors ``numpy.linalg.matrix_rank``: the rank is the number of singular
  values exceeding ``tol``. When ``tol`` is None, the NumPy default
  ``S.max() * max(M.shape) * eps`` is used.

  Args:
    M: array of at most 2 dimensions.
    tol: optional threshold below which singular values count as zero.

  Returns:
    An integer-valued scalar array holding the rank.

  Raises:
    TypeError: if ``M`` has more than 2 dimensions.
  """
  M = _promote_arg_dtypes(np.asarray(M))
  ndim = M.ndim
  if ndim > 2:
    raise TypeError("array should have 2 or fewer dimensions")
  if ndim < 2:
    # A scalar or vector has rank 1 exactly when some entry is nonzero.
    return np.any(M != 0).astype(np.int32)
  sigma = svd(M, full_matrices=False, compute_uv=False)
  threshold = tol
  if threshold is None:
    # NumPy's default tolerance, scaled by the largest matrix dimension.
    threshold = sigma.max() * np.max(M.shape) * np.finfo(sigma.dtype).eps
  return np.sum(sigma > threshold)
|
||||
|
||||
|
||||
# TODO(pfau): make this work for complex types
|
||||
def _jvp_slogdet(g, ans, x):
|
||||
jvp_sign = np.zeros(x.shape[:-2])
|
||||
|
@ -623,6 +623,25 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(partial(np.linalg.matrix_power, n=n), args_maker,
|
||||
check_dtypes=True, rtol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
|
||||
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50)]
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@jtu.skip_on_devices("gpu", "tpu") # TODO(b/145608614): SVD crashes on GPU.
|
||||
def testMatrixRank(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
n = shape[-1]
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
a, = args_maker()
|
||||
self._CheckAgainstNumpy(onp.linalg.matrix_rank, np.linalg.matrix_rank,
|
||||
args_maker, check_dtypes=False, tol=1e-3)
|
||||
self._CompileAndCheck(np.linalg.matrix_rank, args_maker,
|
||||
check_dtypes=False, rtol=1e-3)
|
||||
|
||||
# Regression test for incorrect type for eigenvalues of a complex matrix.
|
||||
@jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU.
|
||||
def testIssue669(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user