mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[TPU] Switch the default eigendecomposition implementation on TPU to use QDWH-eig.
Adds a new non-differentiable primitive `eigh_jacobi` that calls the XLA Jacobi eigh implementation for use inside the TPU QDWH-eigh lowering rule. PiperOrigin-RevId: 451471088
This commit is contained in:
parent
0553f9ed06
commit
5ccdcc5cc6
@ -24,6 +24,7 @@ import jax._src.numpy.lax_numpy as jnp
|
||||
import jax._src.numpy.linalg as jnp_linalg
|
||||
from jax import lax
|
||||
from jax._src.lax import qdwh
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax.stack import Stack
|
||||
|
||||
|
||||
@ -384,16 +385,19 @@ def _eigh_work(H, n, termination_size=256):
|
||||
return blocks[:, 0], eigenvectors
|
||||
|
||||
|
||||
def eigh(H, *, precision="float32", termination_size=256, n=None):
|
||||
def eigh(H, *, precision="float32", termination_size=256, n=None,
|
||||
sort_eigenvalues=True):
|
||||
""" Computes the eigendecomposition of the symmetric/Hermitian matrix H.
|
||||
|
||||
Args:
|
||||
H: The `n x n` Hermitian input.
|
||||
H: The `n x n` Hermitian input, padded to `N x N`.
|
||||
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
|
||||
symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
|
||||
termination_size: Recursion ends once the blocks reach this linear size.
|
||||
n: the true (dynamic) size of the matrix.
|
||||
sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
|
||||
highest.
|
||||
Returns:
|
||||
vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
|
||||
vals: The `n` eigenvalues of `H`.
|
||||
vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
|
||||
of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
|
||||
to numerical error.
|
||||
@ -403,7 +407,10 @@ def eigh(H, *, precision="float32", termination_size=256, n=None):
|
||||
raise TypeError(f"Input H of shape {H.shape} must be square.")
|
||||
|
||||
if N <= termination_size:
|
||||
return jnp_linalg.eigh(H)
|
||||
if n is not None:
|
||||
H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
|
||||
return lax_linalg.eigh_jacobi(
|
||||
H, sort_eigenvalues=sort_eigenvalues)
|
||||
|
||||
# TODO(phawkins): consider rounding N up to a larger size to maximize reuse
|
||||
# between matrices.
|
||||
@ -412,7 +419,8 @@ def eigh(H, *, precision="float32", termination_size=256, n=None):
|
||||
with jax.default_matmul_precision(precision):
|
||||
eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)
|
||||
eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan)
|
||||
sort_idxs = jnp.argsort(eig_vals)
|
||||
eig_vals = eig_vals[sort_idxs]
|
||||
eig_vecs = eig_vecs[:, sort_idxs]
|
||||
if sort_eigenvalues:
|
||||
sort_idxs = jnp.argsort(eig_vals)
|
||||
eig_vals = eig_vals[sort_idxs]
|
||||
eig_vecs = eig_vecs[:, sort_idxs]
|
||||
return eig_vals, eig_vecs
|
||||
|
@ -35,6 +35,8 @@ from jax.core import Primitive, ShapedArray, raise_to_shaped
|
||||
from jax._src.lax.lax import (
|
||||
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
|
||||
_input_dtype)
|
||||
from jax._src.lax import control_flow
|
||||
from jax._src.lax import eigh as lax_eigh
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import svd as lax_svd
|
||||
import jax._src.lib
|
||||
@ -574,18 +576,55 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
|
||||
|
||||
# Symmetric/Hermitian eigendecomposition
|
||||
|
||||
|
||||
def eigh_jacobi(x, *, lower: bool = True, sort_eigenvalues: bool = True):
|
||||
"""Helper Jacobi eigendecomposition implemented by XLA.
|
||||
|
||||
Used as a subroutine of QDWH-eig on TPU."""
|
||||
w, v = eigh_jacobi_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||
return w, v
|
||||
|
||||
def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
|
||||
w, v = xla.apply_primitive(eigh_jacobi_p, operand, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
return w, v
|
||||
|
||||
def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
if isinstance(operand, ShapedArray):
|
||||
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
||||
raise ValueError(
|
||||
"Argument to symmetric eigendecomposition must have shape [..., n, n],"
|
||||
"got shape {}".format(operand.shape))
|
||||
|
||||
batch_dims = operand.shape[:-2]
|
||||
n = operand.shape[-1]
|
||||
w = operand.update(shape=batch_dims + (n,),
|
||||
dtype=lax_internal._complex_basetype(operand.dtype))
|
||||
v = operand.update(shape=batch_dims + (n, n))
|
||||
else:
|
||||
w, v = operand, operand
|
||||
return w, v
|
||||
|
||||
def _eigh_jacobi_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
|
||||
sort_eigenvalues):
|
||||
operand_aval, = avals_in
|
||||
if operand_aval.shape[-1] == 0:
|
||||
return [xops.Real(xops.Reshape(operand, operand_aval.shape[:-1])), operand]
|
||||
v, w = xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||
return w, v
|
||||
|
||||
eigh_jacobi_p = Primitive('eigh_jacobi')
|
||||
eigh_jacobi_p.multiple_results = True
|
||||
eigh_jacobi_p.def_impl(_eigh_jacobi_impl)
|
||||
eigh_jacobi_p.def_abstract_eval(_eigh_jacobi_abstract_eval)
|
||||
xla.register_translation(eigh_jacobi_p, _eigh_jacobi_translation_rule)
|
||||
|
||||
|
||||
def _eigh_impl(operand, *, lower, sort_eigenvalues):
|
||||
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
return v, w
|
||||
|
||||
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
|
||||
sort_eigenvalues):
|
||||
operand_aval, = avals_in
|
||||
if operand_aval.shape[-1] == 0:
|
||||
return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
|
||||
return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||
|
||||
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
if isinstance(operand, ShapedArray):
|
||||
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
||||
@ -625,6 +664,45 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
|
||||
w, _nan_like_mhlo(w_aval))
|
||||
return [v, w]
|
||||
|
||||
def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
|
||||
*_, m, n = x.shape
|
||||
assert m == n, (m, n)
|
||||
|
||||
termination_size = 256
|
||||
|
||||
if m <= termination_size:
|
||||
eig_vals, eig_vecs = eigh_jacobi(x, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
return eig_vecs, eig_vals
|
||||
|
||||
def eigh_qdwh(x):
|
||||
if len(x.shape) > 2:
|
||||
return control_flow.map(eigh_qdwh, x)
|
||||
|
||||
# We should only look at elements from the lower/upper triangle. Reflects
|
||||
# that triangle into the other triangle to form a Hermitian matrix.
|
||||
if lower:
|
||||
mask = jnp.tri(n, k=0, dtype=bool)
|
||||
else:
|
||||
mask = jnp.logical_not(jnp.tri(n, k=-1, dtype=bool))
|
||||
if dtypes.issubdtype(x.dtype, jnp.complexfloating):
|
||||
re = lax.select(mask, lax.real(x), _T(lax.real(x)))
|
||||
if lower:
|
||||
im_mask = jnp.tri(n, k=-1, dtype=bool)
|
||||
else:
|
||||
im_mask = jnp.logical_not(jnp.tri(n, k=0, dtype=bool))
|
||||
im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x)))
|
||||
im = lax.select(mask, im, -_T(im))
|
||||
x = lax.complex(re, im)
|
||||
else:
|
||||
x = lax.select(mask, x, _T(x))
|
||||
|
||||
return lax_eigh.eigh(x, sort_eigenvalues=sort_eigenvalues,
|
||||
termination_size=termination_size)
|
||||
|
||||
eig_vals, eig_vecs = eigh_qdwh(x)
|
||||
return eig_vecs, eig_vals
|
||||
|
||||
def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
|
||||
# Derivative for eigh in the simplest case of distinct eigenvalues.
|
||||
# This is classic nondegenerate perurbation theory, but also see
|
||||
@ -663,7 +741,6 @@ eigh_p = Primitive('eigh')
|
||||
eigh_p.multiple_results = True
|
||||
eigh_p.def_impl(_eigh_impl)
|
||||
eigh_p.def_abstract_eval(_eigh_abstract_eval)
|
||||
xla.register_translation(eigh_p, _eigh_translation_rule)
|
||||
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
|
||||
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
|
||||
|
||||
@ -685,6 +762,10 @@ if solver_apis is not None:
|
||||
eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
|
||||
platform='gpu')
|
||||
|
||||
mlir.register_lowering(
|
||||
eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
|
||||
platform='tpu')
|
||||
|
||||
|
||||
triangular_solve_dtype_rule = partial(
|
||||
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
|
||||
|
@ -1012,6 +1012,7 @@ tf_not_yet_impl = [
|
||||
"xla_pmap",
|
||||
"geqrf",
|
||||
"orgqr",
|
||||
"eigh_jacobi",
|
||||
]
|
||||
|
||||
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
|
||||
|
@ -332,7 +332,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.format_shape_dtype_string((n,n), dtype), lower,
|
||||
sort_eigenvalues),
|
||||
"n": n, "dtype": dtype, "lower": lower}
|
||||
for n in [0, 4, 5, 50]
|
||||
for n in [0, 4, 5, 50, 512]
|
||||
for dtype in float_types + complex_types
|
||||
for lower in [True, False]
|
||||
for sort_eigenvalues in [True, False]))
|
||||
@ -459,7 +459,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
{"testcase_name":
|
||||
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in [(1, 1), (4, 4), (5, 5)]
|
||||
for shape in [(1, 1), (4, 4), (5, 5), (300, 300)]
|
||||
for dtype in float_types + complex_types))
|
||||
def testEighBatching(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -467,8 +467,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
args = rng(shape, dtype)
|
||||
args = (args + np.conj(T(args))) / 2
|
||||
ws, vs = vmap(jsp.linalg.eigh)(args)
|
||||
self.assertTrue(np.all(np.linalg.norm(
|
||||
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
|
||||
norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs))
|
||||
self.assertTrue(norm < 3e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
|
Loading…
x
Reference in New Issue
Block a user