mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add an implementation of jnp.linalg.slogdet based on QR decomposition.
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation. QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way! PiperOrigin-RevId: 449271317
This commit is contained in:
parent
548a6bf58b
commit
1bcb5e073c
@ -10,6 +10,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jax 0.3.14 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main).
|
||||
* Changes
|
||||
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
|
||||
that allows selection between an LU-decomposition based implementation and
|
||||
an implementation based on QR decomposition.
|
||||
* {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
@ -18,7 +18,7 @@ from functools import partial
|
||||
import numpy as np
|
||||
import textwrap
|
||||
import operator
|
||||
from typing import Tuple, Union, cast
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
from jax import jit, custom_jvp
|
||||
from jax import lax
|
||||
@ -117,15 +117,8 @@ def matrix_rank(M, tol=None):
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@_wraps(np.linalg.slogdet)
|
||||
@jit
|
||||
def slogdet(a):
|
||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||
def _slogdet_lu(a):
|
||||
dtype = lax.dtype(a)
|
||||
a_shape = jnp.shape(a)
|
||||
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
|
||||
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
|
||||
raise ValueError(msg.format(a_shape))
|
||||
lu, pivot, _ = lax_linalg.lu(a)
|
||||
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
|
||||
is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
|
||||
@ -144,7 +137,50 @@ def slogdet(a):
|
||||
jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
|
||||
return sign, jnp.real(logdet)
|
||||
|
||||
@slogdet.defjvp
|
||||
@custom_jvp
|
||||
def _slogdet_qr(a):
|
||||
# Implementation of slogdet using QR decomposition. One reason we might prefer
|
||||
# QR decomposition is that it is more amenable to a fast batched
|
||||
# implementation on TPU because of the lack of row pivoting.
|
||||
if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
|
||||
raise NotImplementedError("slogdet method='qr' not implemented for complex "
|
||||
"inputs")
|
||||
n = a.shape[-1]
|
||||
a, taus = lax_linalg.geqrf(a)
|
||||
# The determinant of a triangular matrix is the product of its diagonal
|
||||
# elements. We are working in log space, so we compute the magnitude as the
|
||||
# the trace of the log-absolute values, and we compute the sign separately.
|
||||
log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1)
|
||||
sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1)
|
||||
# The determinant of a Householder reflector is -1. So whenever we actually
|
||||
# made a reflection (tau != 0), multiply the result by -1.
|
||||
sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1)
|
||||
return sign_diag * sign_taus, log_abs_det
|
||||
|
||||
@_wraps(
|
||||
np.linalg.slogdet,
|
||||
extra_params=textwrap.dedent("""
|
||||
method: string, optional
|
||||
One of ``lu`` or ``qr``, specifying whether the determinant should be
|
||||
computed using an LU decomposition or a QR decomposition. Defaults to
|
||||
LU decomposition if ``None``.
|
||||
"""))
|
||||
@partial(jit, static_argnames=('method',))
|
||||
def slogdet(a, *, method: Optional[str] = None):
|
||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||
a_shape = jnp.shape(a)
|
||||
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
|
||||
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
|
||||
raise ValueError(msg.format(a_shape))
|
||||
|
||||
if method is None or method == "lu":
|
||||
return _slogdet_lu(a)
|
||||
elif method == "qr":
|
||||
return _slogdet_qr(a)
|
||||
else:
|
||||
raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
|
||||
"are 'lu' (`None`), and 'qr'.")
|
||||
|
||||
def _slogdet_jvp(primals, tangents):
|
||||
x, = primals
|
||||
g, = tangents
|
||||
@ -157,6 +193,8 @@ def _slogdet_jvp(primals, tangents):
|
||||
sign_dot = jnp.zeros_like(sign)
|
||||
return (sign, ans), (sign_dot, ans_dot)
|
||||
|
||||
_slogdet_lu.defjvp(_slogdet_jvp)
|
||||
_slogdet_qr.defjvp(_slogdet_jvp)
|
||||
|
||||
def _cofactor_solve(a, b):
|
||||
"""Equivalent to det(a)*solve(a, b) for nonsingular mat.
|
||||
|
@ -176,18 +176,22 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
"_shape={}_method={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), method),
|
||||
"shape": shape, "dtype": dtype, "method": method}
|
||||
for shape in [(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200),
|
||||
(2, 2, 2), (2, 3, 3), (3, 2, 2)]
|
||||
for dtype in float_types + complex_types))
|
||||
def testSlogdet(self, shape, dtype):
|
||||
for dtype in float_types + complex_types
|
||||
for method in (["lu"] if jnp.issubdtype(dtype, jnp.complexfloating)
|
||||
else ["lu", "qr"])
|
||||
))
|
||||
def testSlogdet(self, shape, dtype, method):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
|
||||
slogdet = partial(jnp.linalg.slogdet, method=method)
|
||||
self._CheckAgainstNumpy(np.linalg.slogdet, slogdet, args_maker,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.slogdet, args_maker)
|
||||
self._CompileAndCheck(slogdet, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
|
Loading…
x
Reference in New Issue
Block a user