diff --git a/CHANGELOG.md b/CHANGELOG.md index ee95b4ae9..22c458738 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.3.14 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main). +* Changes + * {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument + that allows selection between an LU-decomposition based implementation and + an implementation based on QR decomposition. + * {func}`jax.numpy.linalg.qr` now supports `mode="raw"`. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 35b31fd7f..8937407cf 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -18,7 +18,7 @@ from functools import partial import numpy as np import textwrap import operator -from typing import Tuple, Union, cast +from typing import Optional, Tuple, Union, cast from jax import jit, custom_jvp from jax import lax @@ -117,15 +117,8 @@ def matrix_rank(M, tol=None): @custom_jvp -@_wraps(np.linalg.slogdet) -@jit -def slogdet(a): - a, = _promote_dtypes_inexact(jnp.asarray(a)) +def _slogdet_lu(a): dtype = lax.dtype(a) - a_shape = jnp.shape(a) - if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: - msg = "Argument to slogdet() must have shape [..., n, n], got {}" - raise ValueError(msg.format(a_shape)) lu, pivot, _ = lax_linalg.lu(a) diag = jnp.diagonal(lu, axis1=-2, axis2=-1) is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1) @@ -144,7 +137,50 @@ def slogdet(a): jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)) return sign, jnp.real(logdet) -@slogdet.defjvp +@custom_jvp +def _slogdet_qr(a): + # Implementation of slogdet using QR decomposition. One reason we might prefer + # QR decomposition is that it is more amenable to a fast batched + # implementation on TPU because of the lack of row pivoting. + if jnp.issubdtype(lax.dtype(a), jnp.complexfloating): + raise NotImplementedError("slogdet method='qr' not implemented for complex " + "inputs") + n = a.shape[-1] + a, taus = lax_linalg.geqrf(a) + # The determinant of a triangular matrix is the product of its diagonal + # elements. We are working in log space, so we compute the magnitude as the + # the trace of the log-absolute values, and we compute the sign separately. + log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1) + sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1) + # The determinant of a Householder reflector is -1. So whenever we actually + # made a reflection (tau != 0), multiply the result by -1. + sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1) + return sign_diag * sign_taus, log_abs_det + +@_wraps( + np.linalg.slogdet, + extra_params=textwrap.dedent(""" + method: string, optional + One of ``lu`` or ``qr``, specifying whether the determinant should be + computed using an LU decomposition or a QR decomposition. Defaults to + LU decomposition if ``None``. + """)) +@partial(jit, static_argnames=('method',)) +def slogdet(a, *, method: Optional[str] = None): + a, = _promote_dtypes_inexact(jnp.asarray(a)) + a_shape = jnp.shape(a) + if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: + msg = "Argument to slogdet() must have shape [..., n, n], got {}" + raise ValueError(msg.format(a_shape)) + + if method is None or method == "lu": + return _slogdet_lu(a) + elif method == "qr": + return _slogdet_qr(a) + else: + raise ValueError(f"Unknown slogdet method '{method}'. Supported methods " + "are 'lu' (`None`), and 'qr'.") + def _slogdet_jvp(primals, tangents): x, = primals g, = tangents @@ -157,6 +193,8 @@ def _slogdet_jvp(primals, tangents): sign_dot = jnp.zeros_like(sign) return (sign, ans), (sign_dot, ans_dot) +_slogdet_lu.defjvp(_slogdet_jvp) +_slogdet_qr.defjvp(_slogdet_jvp) def _cofactor_solve(a, b): """Equivalent to det(a)*solve(a, b) for nonsingular mat. diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 632936b02..9bf340d98 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -176,18 +176,22 @@ class NumpyLinalgTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": - "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} + "_shape={}_method={}".format( + jtu.format_shape_dtype_string(shape, dtype), method), + "shape": shape, "dtype": dtype, "method": method} for shape in [(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200), (2, 2, 2), (2, 3, 3), (3, 2, 2)] - for dtype in float_types + complex_types)) - def testSlogdet(self, shape, dtype): + for dtype in float_types + complex_types + for method in (["lu"] if jnp.issubdtype(dtype, jnp.complexfloating) + else ["lu", "qr"]) + )) + def testSlogdet(self, shape, dtype, method): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - - self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker, + slogdet = partial(jnp.linalg.slogdet, method=method) + self._CheckAgainstNumpy(np.linalg.slogdet, slogdet, args_maker, tol=1e-3) - self._CompileAndCheck(jnp.linalg.slogdet, args_maker) + self._CompileAndCheck(slogdet, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name":