Add an implementation of jnp.linalg.slogdet based on QR decomposition.

Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.

QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!

PiperOrigin-RevId: 449271317
This commit is contained in:
Peter Hawkins 2022-05-17 11:23:10 -07:00 committed by jax authors
parent 548a6bf58b
commit 1bcb5e073c
3 changed files with 64 additions and 17 deletions

View File

@ -10,6 +10,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.14 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main).
* Changes
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
that allows selection between an LU-decomposition based implementation and
an implementation based on QR decomposition.
* {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -18,7 +18,7 @@ from functools import partial
import numpy as np
import textwrap
import operator
from typing import Tuple, Union, cast
from typing import Optional, Tuple, Union, cast
from jax import jit, custom_jvp
from jax import lax
@ -117,15 +117,8 @@ def matrix_rank(M, tol=None):
@custom_jvp
@_wraps(np.linalg.slogdet)
@jit
def slogdet(a):
a, = _promote_dtypes_inexact(jnp.asarray(a))
def _slogdet_lu(a):
dtype = lax.dtype(a)
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
raise ValueError(msg.format(a_shape))
lu, pivot, _ = lax_linalg.lu(a)
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
@ -144,7 +137,50 @@ def slogdet(a):
jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
return sign, jnp.real(logdet)
@slogdet.defjvp
@custom_jvp
def _slogdet_qr(a):
# Implementation of slogdet using QR decomposition. One reason we might prefer
# QR decomposition is that it is more amenable to a fast batched
# implementation on TPU because of the lack of row pivoting.
if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
raise NotImplementedError("slogdet method='qr' not implemented for complex "
"inputs")
n = a.shape[-1]
a, taus = lax_linalg.geqrf(a)
# The determinant of a triangular matrix is the product of its diagonal
# elements. We are working in log space, so we compute the magnitude as the
# the trace of the log-absolute values, and we compute the sign separately.
log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1)
sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1)
# The determinant of a Householder reflector is -1. So whenever we actually
# made a reflection (tau != 0), multiply the result by -1.
sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1)
return sign_diag * sign_taus, log_abs_det
@_wraps(
np.linalg.slogdet,
extra_params=textwrap.dedent("""
method: string, optional
One of ``lu`` or ``qr``, specifying whether the determinant should be
computed using an LU decomposition or a QR decomposition. Defaults to
LU decomposition if ``None``.
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a, *, method: Optional[str] = None):
a, = _promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
raise ValueError(msg.format(a_shape))
if method is None or method == "lu":
return _slogdet_lu(a)
elif method == "qr":
return _slogdet_qr(a)
else:
raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
"are 'lu' (`None`), and 'qr'.")
def _slogdet_jvp(primals, tangents):
x, = primals
g, = tangents
@ -157,6 +193,8 @@ def _slogdet_jvp(primals, tangents):
sign_dot = jnp.zeros_like(sign)
return (sign, ans), (sign_dot, ans_dot)
_slogdet_lu.defjvp(_slogdet_jvp)
_slogdet_qr.defjvp(_slogdet_jvp)
def _cofactor_solve(a, b):
"""Equivalent to det(a)*solve(a, b) for nonsingular mat.

View File

@ -176,18 +176,22 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
"_shape={}_method={}".format(
jtu.format_shape_dtype_string(shape, dtype), method),
"shape": shape, "dtype": dtype, "method": method}
for shape in [(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200),
(2, 2, 2), (2, 3, 3), (3, 2, 2)]
for dtype in float_types + complex_types))
def testSlogdet(self, shape, dtype):
for dtype in float_types + complex_types
for method in (["lu"] if jnp.issubdtype(dtype, jnp.complexfloating)
else ["lu", "qr"])
))
def testSlogdet(self, shape, dtype, method):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
slogdet = partial(jnp.linalg.slogdet, method=method)
self._CheckAgainstNumpy(np.linalg.slogdet, slogdet, args_maker,
tol=1e-3)
self._CompileAndCheck(jnp.linalg.slogdet, args_maker)
self._CompileAndCheck(slogdet, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":