mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv
This commit is contained in:
parent
0501d3d7a0
commit
5cc255b755
@ -21,6 +21,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
|
||||
`translations`, `register_translation`, `xla_destructure`,
|
||||
`TranslationRule`, `TranslationContext`, `XlaOp`.
|
||||
* The ``tol`` argument of {func}`jax.numpy.linalg.matrix_rank` is being
|
||||
deprecated and will soon be removed. Use `rtol` instead.
|
||||
* The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being
|
||||
deprecated and will soon be removed. Use `rtol` instead.
|
||||
|
||||
## jaxlib 0.4.29
|
||||
|
||||
|
@ -35,7 +35,7 @@ from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import reductions, ufuncs
|
||||
from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.typing import ArrayLike, Array, DTypeLike
|
||||
from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg
|
||||
|
||||
|
||||
class EighResult(NamedTuple):
|
||||
@ -392,7 +392,9 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
|
||||
|
||||
|
||||
@jit
|
||||
def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
|
||||
def matrix_rank(
|
||||
M: ArrayLike, rtol: ArrayLike | None = None, *,
|
||||
tol: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
|
||||
"""Compute the rank of a matrix.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.matrix_rank`.
|
||||
@ -402,9 +404,10 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
|
||||
|
||||
Args:
|
||||
a: array of shape ``(..., M, N)`` whose rank is to be computed.
|
||||
tol: optional array of shape ``(...)`` specifying the tolerance. Singular values
|
||||
smaller than `tol` are considered to be zero. If ``tol`` is None (the default),
|
||||
a reasonable default is chosen based the floating point precision of the input.
|
||||
rtol: optional array of shape ``(...)`` specifying the tolerance. Singular values
|
||||
smaller than `rtol * largest_singular_value` are considered to be zero. If
|
||||
``rtol`` is None (the default), a reasonable default is chosen based the
|
||||
floating point precision of the input.
|
||||
|
||||
Returns:
|
||||
array of shape ``a.shape[-2]`` giving the matrix rank.
|
||||
@ -412,7 +415,7 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
|
||||
Notes:
|
||||
The rank calculation may be inaccurate for matrices with very small singular
|
||||
values or those that are numerically ill-conditioned. Consider adjusting the
|
||||
``tol`` parameter or using a more specialized rank computation method in such cases.
|
||||
``rtol`` parameter or using a more specialized rank computation method in such cases.
|
||||
|
||||
Examples:
|
||||
>>> a = jnp.array([[1, 2],
|
||||
@ -426,14 +429,24 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array:
|
||||
Array(1, dtype=int32)
|
||||
"""
|
||||
check_arraylike("jnp.linalg.matrix_rank", M)
|
||||
# TODO(micky774): deprecated 2024-5-14, remove after deprecation expires.
|
||||
if not isinstance(tol, DeprecatedArg):
|
||||
rtol = tol
|
||||
del tol
|
||||
warnings.warn(
|
||||
"The tol argument for linalg.matrix_rank is deprecated using it will soon raise "
|
||||
"an error. To prepare for future releases, and suppress this warning, "
|
||||
"please use rtol instead.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
M, = promote_dtypes_inexact(jnp.asarray(M))
|
||||
if M.ndim < 2:
|
||||
return (M != 0).any().astype(jnp.int32)
|
||||
S = svd(M, full_matrices=False, compute_uv=False)
|
||||
if tol is None:
|
||||
tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
|
||||
tol = jnp.expand_dims(tol, np.ndim(tol))
|
||||
return reductions.sum(S > tol, axis=-1)
|
||||
if rtol is None:
|
||||
rtol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
|
||||
rtol = jnp.expand_dims(rtol, np.ndim(rtol))
|
||||
return reductions.sum(S > rtol, axis=-1)
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@ -861,21 +874,21 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
|
||||
return w
|
||||
|
||||
|
||||
@partial(custom_jvp, nondiff_argnums=(1, 2))
|
||||
@partial(jit, static_argnames=('hermitian',))
|
||||
def pinv(a: ArrayLike, rcond: ArrayLike | None = None,
|
||||
hermitian: bool = False) -> Array:
|
||||
# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires.
|
||||
def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
|
||||
hermitian: bool = False, *,
|
||||
rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
|
||||
"""Compute the (Moore-Penrose) pseudo-inverse of a matrix.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.pinv`.
|
||||
|
||||
Args:
|
||||
a: array of shape ``(..., M, N)`` containing matrices to pseudo-invert.
|
||||
rcond: float or array_like of shape ``a.shape[:-2]``. Specifies the cutoff
|
||||
rtol: float or array_like of shape ``a.shape[:-2]``. Specifies the cutoff
|
||||
for small singular values.of shape ``(...,)``.
|
||||
Cutoff for small singular values; singular values smaller than this value
|
||||
are treated as zero. The default is determined based on the floating point
|
||||
precision of the dtype.
|
||||
Cutoff for small singular values; singular values smaller
|
||||
``rtol * largest_singular_value`` are treated as zero. The default is
|
||||
determined based on the floating point precision of the dtype.
|
||||
hermitian: if True, then the input is assumed to be Hermitian, and a more
|
||||
efficient algorithm is used (default: False)
|
||||
|
||||
@ -905,6 +918,22 @@ def pinv(a: ArrayLike, rcond: ArrayLike | None = None,
|
||||
>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
if not isinstance(rcond, DeprecatedArg):
|
||||
rtol = rcond
|
||||
del rcond
|
||||
warnings.warn(
|
||||
"The rcond argument for linalg.pinv is deprecated using it will soon "
|
||||
"raise an error. To prepare for future releases, and suppress this "
|
||||
"warning, please use rtol instead.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
return _pinv(a, rtol, hermitian)
|
||||
|
||||
|
||||
@partial(custom_jvp, nondiff_argnums=(1, 2))
|
||||
@partial(jit, static_argnames=('hermitian'))
|
||||
def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) -> Array:
|
||||
# Uses same algorithm as
|
||||
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
|
||||
check_arraylike("jnp.linalg.pinv", a)
|
||||
@ -913,31 +942,31 @@ def pinv(a: ArrayLike, rcond: ArrayLike | None = None,
|
||||
if m == 0 or n == 0:
|
||||
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
|
||||
arr = ufuncs.conj(arr)
|
||||
if rcond is None:
|
||||
if rtol is None:
|
||||
max_rows_cols = max(arr.shape[-2:])
|
||||
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
|
||||
rcond = jnp.asarray(rcond)
|
||||
rtol = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
|
||||
rtol = jnp.asarray(rtol)
|
||||
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
|
||||
# Singular values less than or equal to ``rcond * largest_singular_value``
|
||||
# Singular values less than or equal to ``rtol * largest_singular_value``
|
||||
# are set to zero.
|
||||
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
|
||||
cutoff = rcond * s[..., 0:1]
|
||||
rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1))
|
||||
cutoff = rtol * s[..., 0:1]
|
||||
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
|
||||
res = jnp.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]),
|
||||
precision=lax.Precision.HIGHEST)
|
||||
return lax.convert_element_type(res, arr.dtype)
|
||||
|
||||
|
||||
@pinv.defjvp
|
||||
@_pinv.defjvp
|
||||
@jax.default_matmul_precision("float32")
|
||||
def _pinv_jvp(rcond, hermitian, primals, tangents):
|
||||
def _pinv_jvp(rtol, hermitian, primals, tangents):
|
||||
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
|
||||
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
|
||||
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
|
||||
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
|
||||
a, = primals # m x n
|
||||
a_dot, = tangents
|
||||
p = pinv(a, rcond=rcond, hermitian=hermitian) # n x m
|
||||
p = pinv(a, rtol=rtol, hermitian=hermitian) # n x m
|
||||
if hermitian:
|
||||
# svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
|
||||
a = _symmetrize(a)
|
||||
|
@ -14,17 +14,15 @@
|
||||
|
||||
import jax
|
||||
|
||||
# TODO(micky774): Remove after deprecating tol-->rtol in jnp.linalg.matrix_rank
|
||||
# TODO(micky774): Remove after deprecation is completed (began 2024-5-14)
|
||||
def matrix_rank(x, /, *, rtol=None):
|
||||
"""
|
||||
Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices).
|
||||
"""
|
||||
return jax.numpy.linalg.matrix_rank(x, tol=rtol)
|
||||
return jax.numpy.linalg.matrix_rank(x, rtol)
|
||||
|
||||
# TODO(micky774): Remove after deprecating rcond-->rtol in
|
||||
# jnp.linalg.pinv
|
||||
def pinv(x, /, *, rtol=None):
|
||||
"""
|
||||
Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x.
|
||||
"""
|
||||
return jax.numpy.linalg.pinv(x, rcond=rtol)
|
||||
return jax.numpy.linalg.pinv(x, rtol)
|
||||
|
Loading…
x
Reference in New Issue
Block a user