diff --git a/CHANGELOG.md b/CHANGELOG.md index a13a8fd5b..c0987b891 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,10 @@ Remember to align the itemized text with the first line of an item within a list * from {mod}`jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `XlaOp`. + * The ``tol`` argument of {func}`jax.numpy.linalg.matrix_rank` is being + deprecated and will soon be removed. Use `rtol` instead. + * The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being + deprecated and will soon be removed. Use `rtol` instead. ## jaxlib 0.4.29 diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 6e8c002ca..ce026311b 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -35,7 +35,7 @@ from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions, ufuncs from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike from jax._src.util import canonicalize_axis -from jax._src.typing import ArrayLike, Array, DTypeLike +from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg class EighResult(NamedTuple): @@ -392,7 +392,9 @@ def matrix_power(a: ArrayLike, n: int) -> Array: @jit -def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array: +def matrix_rank( + M: ArrayLike, rtol: ArrayLike | None = None, *, + tol: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: """Compute the rank of a matrix. JAX implementation of :func:`numpy.linalg.matrix_rank`. @@ -402,9 +404,10 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array: Args: a: array of shape ``(..., M, N)`` whose rank is to be computed. - tol: optional array of shape ``(...)`` specifying the tolerance. Singular values - smaller than `tol` are considered to be zero. If ``tol`` is None (the default), - a reasonable default is chosen based the floating point precision of the input. + rtol: optional array of shape ``(...)`` specifying the tolerance. Singular values + smaller than `rtol * largest_singular_value` are considered to be zero. If + ``rtol`` is None (the default), a reasonable default is chosen based the + floating point precision of the input. Returns: array of shape ``a.shape[-2]`` giving the matrix rank. @@ -412,7 +415,7 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array: Notes: The rank calculation may be inaccurate for matrices with very small singular values or those that are numerically ill-conditioned. Consider adjusting the - ``tol`` parameter or using a more specialized rank computation method in such cases. + ``rtol`` parameter or using a more specialized rank computation method in such cases. Examples: >>> a = jnp.array([[1, 2], @@ -426,14 +429,24 @@ def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array: Array(1, dtype=int32) """ check_arraylike("jnp.linalg.matrix_rank", M) + # TODO(micky774): deprecated 2024-5-14, remove after deprecation expires. + if not isinstance(tol, DeprecatedArg): + rtol = tol + del tol + warnings.warn( + "The tol argument for linalg.matrix_rank is deprecated using it will soon raise " + "an error. To prepare for future releases, and suppress this warning, " + "please use rtol instead.", + DeprecationWarning, stacklevel=2 + ) M, = promote_dtypes_inexact(jnp.asarray(M)) if M.ndim < 2: return (M != 0).any().astype(jnp.int32) S = svd(M, full_matrices=False, compute_uv=False) - if tol is None: - tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps - tol = jnp.expand_dims(tol, np.ndim(tol)) - return reductions.sum(S > tol, axis=-1) + if rtol is None: + rtol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps + rtol = jnp.expand_dims(rtol, np.ndim(rtol)) + return reductions.sum(S > rtol, axis=-1) @custom_jvp @@ -861,21 +874,21 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: return w -@partial(custom_jvp, nondiff_argnums=(1, 2)) -@partial(jit, static_argnames=('hermitian',)) -def pinv(a: ArrayLike, rcond: ArrayLike | None = None, - hermitian: bool = False) -> Array: +# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires. +def pinv(a: ArrayLike, rtol: ArrayLike | None = None, + hermitian: bool = False, *, + rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: """Compute the (Moore-Penrose) pseudo-inverse of a matrix. JAX implementation of :func:`numpy.linalg.pinv`. Args: a: array of shape ``(..., M, N)`` containing matrices to pseudo-invert. - rcond: float or array_like of shape ``a.shape[:-2]``. Specifies the cutoff + rtol: float or array_like of shape ``a.shape[:-2]``. Specifies the cutoff for small singular values.of shape ``(...,)``. - Cutoff for small singular values; singular values smaller than this value - are treated as zero. The default is determined based on the floating point - precision of the dtype. + Cutoff for small singular values; singular values smaller + ``rtol * largest_singular_value`` are treated as zero. The default is + determined based on the floating point precision of the dtype. hermitian: if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False) @@ -905,6 +918,22 @@ def pinv(a: ArrayLike, rcond: ArrayLike | None = None, >>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4) Array(True, dtype=bool) """ + if not isinstance(rcond, DeprecatedArg): + rtol = rcond + del rcond + warnings.warn( + "The rcond argument for linalg.pinv is deprecated using it will soon " + "raise an error. To prepare for future releases, and suppress this " + "warning, please use rtol instead.", + DeprecationWarning, stacklevel=2 + ) + + return _pinv(a, rtol, hermitian) + + +@partial(custom_jvp, nondiff_argnums=(1, 2)) +@partial(jit, static_argnames=('hermitian')) +def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) -> Array: # Uses same algorithm as # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 check_arraylike("jnp.linalg.pinv", a) @@ -913,31 +942,31 @@ def pinv(a: ArrayLike, rcond: ArrayLike | None = None, if m == 0 or n == 0: return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype) arr = ufuncs.conj(arr) - if rcond is None: + if rtol is None: max_rows_cols = max(arr.shape[-2:]) - rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) - rcond = jnp.asarray(rcond) + rtol = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) + rtol = jnp.asarray(rtol) u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian) - # Singular values less than or equal to ``rcond * largest_singular_value`` + # Singular values less than or equal to ``rtol * largest_singular_value`` # are set to zero. - rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1)) - cutoff = rcond * s[..., 0:1] + rtol = lax.expand_dims(rtol[..., jnp.newaxis], range(s.ndim - rtol.ndim - 1)) + cutoff = rtol * s[..., 0:1] s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype) res = jnp.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]), precision=lax.Precision.HIGHEST) return lax.convert_element_type(res, arr.dtype) -@pinv.defjvp +@_pinv.defjvp @jax.default_matmul_precision("float32") -def _pinv_jvp(rcond, hermitian, primals, tangents): +def _pinv_jvp(rtol, hermitian, primals, tangents): # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432. # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative) a, = primals # m x n a_dot, = tangents - p = pinv(a, rcond=rcond, hermitian=hermitian) # n x m + p = pinv(a, rtol=rtol, hermitian=hermitian) # n x m if hermitian: # svd(..., hermitian=True) symmetrizes its input, and the JVP must match. a = _symmetrize(a) diff --git a/jax/experimental/array_api/_linear_algebra_functions.py b/jax/experimental/array_api/_linear_algebra_functions.py index 8234a9f01..cc488d721 100644 --- a/jax/experimental/array_api/_linear_algebra_functions.py +++ b/jax/experimental/array_api/_linear_algebra_functions.py @@ -14,17 +14,15 @@ import jax -# TODO(micky774): Remove after deprecating tol-->rtol in jnp.linalg.matrix_rank +# TODO(micky774): Remove after deprecation is completed (began 2024-5-14) def matrix_rank(x, /, *, rtol=None): """ Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices). """ - return jax.numpy.linalg.matrix_rank(x, tol=rtol) + return jax.numpy.linalg.matrix_rank(x, rtol) -# TODO(micky774): Remove after deprecating rcond-->rtol in -# jnp.linalg.pinv def pinv(x, /, *, rtol=None): """ Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x. """ - return jax.numpy.linalg.pinv(x, rcond=rtol) + return jax.numpy.linalg.pinv(x, rtol)