Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv

This commit is contained in:
Meekail Zain 2024-05-14 19:53:54 +00:00
parent 0501d3d7a0
commit 5cc255b755
3 changed files with 63 additions and 32 deletions

View File

@ -21,6 +21,10 @@ Remember to align the itemized text with the first line of an item within a list
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
`translations`, `register_translation`, `xla_destructure`,
`TranslationRule`, `TranslationContext`, `XlaOp`.
* The ``tol`` argument of {func}`jax.numpy.linalg.matrix_rank` is being
deprecated and will soon be removed. Use `rtol` instead.
* The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being
deprecated and will soon be removed. Use `rtol` instead.
## jaxlib 0.4.29

View File

@ -35,7 +35,7 @@ from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array, DTypeLike
from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg
class EighResult(NamedTuple):
@ -392,7 +392,9 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
@jit
def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
def matrix_rank(
M: ArrayLike, rtol: ArrayLike | None = None, *,
tol: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
"""Compute the rank of a matrix.
JAX implementation of :func:`numpy.linalg.matrix_rank`.
@ -402,9 +404,10 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
Args:
a: array of shape ``(..., M, N)`` whose rank is to be computed.
tol: optional array of shape ``(...)`` specifying the tolerance. Singular values
smaller than `tol` are considered to be zero. If ``tol`` is None (the default),
a reasonable default is chosen based the floating point precision of the input.
rtol: optional array of shape ``(...)`` specifying the tolerance. Singular values
smaller than `rtol * largest_singular_value` are considered to be zero. If
``rtol`` is None (the default), a reasonable default is chosen based the
floating point precision of the input.
Returns:
array of shape ``a.shape[-2]`` giving the matrix rank.
@ -412,7 +415,7 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
Notes:
The rank calculation may be inaccurate for matrices with very small singular
values or those that are numerically ill-conditioned. Consider adjusting the
``tol`` parameter or using a more specialized rank computation method in such cases.
``rtol`` parameter or using a more specialized rank computation method in such cases.
Examples:
>>> a = jnp.array([[1, 2],
@ -426,14 +429,24 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
Array(1, dtype=int32)
"""
check_arraylike("jnp.linalg.matrix_rank", M)
# TODO(micky774): deprecated 2024-5-14, remove after deprecation expires.
if not isinstance(tol, DeprecatedArg):
rtol = tol
del tol
warnings.warn(
"The tol argument for linalg.matrix_rank is deprecated using it will soon raise "
"an error. To prepare for future releases, and suppress this warning, "
"please use rtol instead.",
DeprecationWarning, stacklevel=2
)
M, = promote_dtypes_inexact(jnp.asarray(M))
if M.ndim < 2:
return (M != 0).any().astype(jnp.int32)
S = svd(M, full_matrices=False, compute_uv=False)
if tol is None:
tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
tol = jnp.expand_dims(tol, np.ndim(tol))
return reductions.sum(S > tol, axis=-1)
if rtol is None:
rtol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
rtol = jnp.expand_dims(rtol, np.ndim(rtol))
return reductions.sum(S > rtol, axis=-1)
@custom_jvp
@ -861,21 +874,21 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
return w
@partial(custom_jvp, nondiff_argnums=(1, 2))
@partial(jit, static_argnames=('hermitian',))
def pinv(a: ArrayLike, rcond: ArrayLike | None = None,
hermitian: bool = False) -> Array:
# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires.
def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
hermitian: bool = False, *,
rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
"""Compute the (Moore-Penrose) pseudo-inverse of a matrix.
JAX implementation of :func:`numpy.linalg.pinv`.
Args:
a: array of shape ``(..., M, N)`` containing matrices to pseudo-invert.
rcond: float or array_like of shape ``a.shape[:-2]``. Specifies the cutoff
rtol: float or array_like of shape ``a.shape[:-2]``. Specifies the cutoff
for small singular values.of shape ``(...,)``.
Cutoff for small singular values; singular values smaller than this value
are treated as zero. The default is determined based on the floating point
precision of the dtype.
Cutoff for small singular values; singular values smaller
``rtol * largest_singular_value`` are treated as zero. The default is
determined based on the floating point precision of the dtype.
hermitian: if True, then the input is assumed to be Hermitian, and a more
efficient algorithm is used (default: False)
@ -905,6 +918,22 @@ def pinv(a: ArrayLike, rcond: ArrayLike | None = None,
>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
Array(True, dtype=bool)
"""
if not isinstance(rcond, DeprecatedArg):
rtol = rcond
del rcond
warnings.warn(
"The rcond argument for linalg.pinv is deprecated using it will soon "
"raise an error. To prepare for future releases, and suppress this "
"warning, please use rtol instead.",
DeprecationWarning, stacklevel=2
)
return _pinv(a, rtol, hermitian)
@partial(custom_jvp, nondiff_argnums=(1, 2))
@partial(jit, static_argnames=('hermitian'))
def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) -> Array:
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
check_arraylike("jnp.linalg.pinv", a)
@ -913,31 +942,31 @@ def pinv(a: ArrayLike, rcond: ArrayLike | None = None,
if m == 0 or n == 0:
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
arr = ufuncs.conj(arr)
if rcond is None:
if rtol is None:
max_rows_cols = max(arr.shape[-2:])
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
rcond = jnp.asarray(rcond)
rtol = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
rtol = jnp.asarray(rtol)
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
# Singular values less than or equal to ``rcond * largest_singular_value``
# Singular values less than or equal to ``rtol * largest_singular_value``
# are set to zero.
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
cutoff = rcond * s[..., 0:1]
rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1))
cutoff = rtol * s[..., 0:1]
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
res = jnp.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]),
precision=lax.Precision.HIGHEST)
return lax.convert_element_type(res, arr.dtype)
@pinv.defjvp
@_pinv.defjvp
@jax.default_matmul_precision("float32")
def _pinv_jvp(rcond, hermitian, primals, tangents):
def _pinv_jvp(rtol, hermitian, primals, tangents):
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
a, = primals # m x n
a_dot, = tangents
p = pinv(a, rcond=rcond, hermitian=hermitian) # n x m
p = pinv(a, rtol=rtol, hermitian=hermitian) # n x m
if hermitian:
# svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
a = _symmetrize(a)

View File

@ -14,17 +14,15 @@
import jax
# TODO(micky774): Remove after deprecating tol-->rtol in jnp.linalg.matrix_rank
# TODO(micky774): Remove after deprecation is completed (began 2024-5-14)
def matrix_rank(x, /, *, rtol=None):
"""
Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices).
"""
return jax.numpy.linalg.matrix_rank(x, tol=rtol)
return jax.numpy.linalg.matrix_rank(x, rtol)
# TODO(micky774): Remove after deprecating rcond-->rtol in
# jnp.linalg.pinv
def pinv(x, /, *, rtol=None):
"""
Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x.
"""
return jax.numpy.linalg.pinv(x, rcond=rtol)
return jax.numpy.linalg.pinv(x, rtol)