[TPU] Switch the default eigendecomposition implementation on TPU to use QDWH-eig.

Adds a new non-differentiable primitive `eigh_jacobi` that calls the XLA Jacobi eigh implementation, for use inside the TPU QDWH-eig lowering rule.

PiperOrigin-RevId: 451471088
Author: Peter Hawkins, 2022-05-27 13:50:07 -07:00 (committed by jax authors)
parent 0553f9ed06
commit 5ccdcc5cc6
4 changed files with 110 additions and 20 deletions
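
Before the per-file diffs, a quick way to see the change from user code: `jnp.linalg.eigh` is unchanged at the API level, but on TPU a matrix larger than the 256x256 termination size now decomposes via QDWH-eig, while small blocks fall through to the new Jacobi primitive. A minimal, backend-agnostic sketch (sizes and tolerance are illustrative, not from the commit):

```python
# Minimal sketch: jnp.linalg.eigh on a matrix larger than the 256x256
# termination size. On TPU this now lowers through QDWH-eig; the residual
# check itself runs on any backend.
import jax
import jax.numpy as jnp
import numpy as np

a = jax.random.normal(jax.random.PRNGKey(0), (300, 300), dtype=jnp.float32)
h = (a + a.T) / 2  # eigh assumes a symmetric/Hermitian input

w, v = jnp.linalg.eigh(h)
# H @ v should equal w * v up to numerical error.
residual = np.linalg.norm(h @ v - w[None, :] * v)
assert residual < 1e-2 * np.linalg.norm(h)
```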

jax/_src/lax/eigh.py

@@ -24,6 +24,7 @@ import jax._src.numpy.lax_numpy as jnp
 import jax._src.numpy.linalg as jnp_linalg
 from jax import lax
 from jax._src.lax import qdwh
+from jax._src.lax import linalg as lax_linalg
 from jax._src.lax.stack import Stack
@@ -384,16 +385,19 @@ def _eigh_work(H, n, termination_size=256):
   return blocks[:, 0], eigenvectors


-def eigh(H, *, precision="float32", termination_size=256, n=None):
+def eigh(H, *, precision="float32", termination_size=256, n=None,
+         sort_eigenvalues=True):
   """ Computes the eigendecomposition of the symmetric/Hermitian matrix H.

   Args:
-    H: The `n x n` Hermitian input.
+    H: The `n x n` Hermitian input, padded to `N x N`.
     precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
     symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
     termination_size: Recursion ends once the blocks reach this linear size.
     n: the true (dynamic) size of the matrix.
+    sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
+      highest.

   Returns:
-    vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
+    vals: The `n` eigenvalues of `H`.
     vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
       of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
       to numerical error.
@@ -403,7 +407,10 @@ def eigh(H, *, precision="float32", termination_size=256, n=None):
     raise TypeError(f"Input H of shape {H.shape} must be square.")

   if N <= termination_size:
-    return jnp_linalg.eigh(H)
+    if n is not None:
+      H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
+    return lax_linalg.eigh_jacobi(
+        H, sort_eigenvalues=sort_eigenvalues)

   # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
   # between matrices.
@@ -412,7 +419,8 @@ def eigh(H, *, precision="float32", termination_size=256, n=None):
   with jax.default_matmul_precision(precision):
     eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)
   eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan)
-  sort_idxs = jnp.argsort(eig_vals)
-  eig_vals = eig_vals[sort_idxs]
-  eig_vecs = eig_vecs[:, sort_idxs]
+  if sort_eigenvalues:
+    sort_idxs = jnp.argsort(eig_vals)
+    eig_vals = eig_vals[sort_idxs]
+    eig_vecs = eig_vecs[:, sort_idxs]
   return eig_vals, eig_vecs
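
For reference, a hedged sketch of how the new `sort_eigenvalues` flag behaves on this helper. `jax._src.lax.eigh` is a private module (the path is inferred from the diff), so the import may move between releases:

```python
# Hedged sketch, not from the commit: exercising the new `sort_eigenvalues`
# flag on the internal QDWH-eig helper. Private API; import path may change.
import jax.numpy as jnp
from jax._src.lax import eigh as lax_eigh

h = jnp.diag(jnp.array([3.0, 1.0, 2.0], dtype=jnp.float32))
# Default: eigenvalues sorted ascending, here [1., 2., 3.].
w_sorted, _ = lax_eigh.eigh(h)
# With sorting disabled, values come back in whatever order the solver
# produced them (this matrix is below the termination size, so the Jacobi
# path is taken).
w_unsorted, _ = lax_eigh.eigh(h, sort_eigenvalues=False)
```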

jax/_src/lax/linalg.py

@@ -35,6 +35,8 @@ from jax.core import Primitive, ShapedArray, raise_to_shaped
 from jax._src.lax.lax import (
   standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
   _input_dtype)
+from jax._src.lax import control_flow
+from jax._src.lax import eigh as lax_eigh
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import svd as lax_svd
 import jax._src.lib
@@ -574,18 +576,55 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
 # Symmetric/Hermitian eigendecomposition

+def eigh_jacobi(x, *, lower: bool = True, sort_eigenvalues: bool = True):
+  """Helper Jacobi eigendecomposition implemented by XLA.
+
+  Used as a subroutine of QDWH-eig on TPU."""
+  w, v = eigh_jacobi_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
+  return w, v
+
+def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
+  w, v = xla.apply_primitive(eigh_jacobi_p, operand, lower=lower,
+                             sort_eigenvalues=sort_eigenvalues)
+  return w, v
+
+def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
+  if isinstance(operand, ShapedArray):
+    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
+      raise ValueError(
+        "Argument to symmetric eigendecomposition must have shape [..., n, n], "
+        "got shape {}".format(operand.shape))
+
+    batch_dims = operand.shape[:-2]
+    n = operand.shape[-1]
+    w = operand.update(shape=batch_dims + (n,),
+                       dtype=lax_internal._complex_basetype(operand.dtype))
+    v = operand.update(shape=batch_dims + (n, n))
+  else:
+    w, v = operand, operand
+  return w, v
+
+def _eigh_jacobi_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
+                                  sort_eigenvalues):
+  operand_aval, = avals_in
+  if operand_aval.shape[-1] == 0:
+    return [xops.Real(xops.Reshape(operand, operand_aval.shape[:-1])), operand]
+  v, w = xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
+  return w, v
+
+eigh_jacobi_p = Primitive('eigh_jacobi')
+eigh_jacobi_p.multiple_results = True
+eigh_jacobi_p.def_impl(_eigh_jacobi_impl)
+eigh_jacobi_p.def_abstract_eval(_eigh_jacobi_abstract_eval)
+xla.register_translation(eigh_jacobi_p, _eigh_jacobi_translation_rule)
+
 def _eigh_impl(operand, *, lower, sort_eigenvalues):
   v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
                              sort_eigenvalues=sort_eigenvalues)
   return v, w

-def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
-                           sort_eigenvalues):
-  operand_aval, = avals_in
-  if operand_aval.shape[-1] == 0:
-    return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
-  return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
-
 def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
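
The abstract-eval rule above pins down the output contract: eigenvalues are real (`_complex_basetype` of the operand dtype) with shape `[..., n]`, and eigenvectors keep the operand dtype with shape `[..., n, n]`. A small sketch checking this with `jax.eval_shape` (`eigh_jacobi` is a private helper, per its docstring):

```python
# Sketch of the contract encoded by _eigh_jacobi_abstract_eval. The import
# is private API, taken from this diff.
import jax
import jax.numpy as jnp
from jax._src.lax.linalg import eigh_jacobi

x = jnp.zeros((4, 8, 8), dtype=jnp.complex64)  # a batch of Hermitian 8x8
w, v = jax.eval_shape(eigh_jacobi, x)          # abstract eval only, no FLOPs
print(w)  # ShapeDtypeStruct(shape=(4, 8), dtype=float32)
print(v)  # ShapeDtypeStruct(shape=(4, 8, 8), dtype=complex64)
```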
@@ -625,6 +664,45 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
                    w, _nan_like_mhlo(w_aval))
   return [v, w]

+def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
+  *_, m, n = x.shape
+  assert m == n, (m, n)
+
+  termination_size = 256
+  if m <= termination_size:
+    eig_vals, eig_vecs = eigh_jacobi(x, lower=lower,
+                                     sort_eigenvalues=sort_eigenvalues)
+    return eig_vecs, eig_vals
+
+  def eigh_qdwh(x):
+    if len(x.shape) > 2:
+      return control_flow.map(eigh_qdwh, x)
+
+    # We should only look at elements from the lower/upper triangle. Reflect
+    # that triangle into the other triangle to form a Hermitian matrix.
+    if lower:
+      mask = jnp.tri(n, k=0, dtype=bool)
+    else:
+      mask = jnp.logical_not(jnp.tri(n, k=-1, dtype=bool))
+    if dtypes.issubdtype(x.dtype, jnp.complexfloating):
+      re = lax.select(mask, lax.real(x), _T(lax.real(x)))
+      if lower:
+        im_mask = jnp.tri(n, k=-1, dtype=bool)
+      else:
+        im_mask = jnp.logical_not(jnp.tri(n, k=0, dtype=bool))
+      im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x)))
+      im = lax.select(mask, im, -_T(im))
+      x = lax.complex(re, im)
+    else:
+      x = lax.select(mask, x, _T(x))
+
+    return lax_eigh.eigh(x, sort_eigenvalues=sort_eigenvalues,
+                         termination_size=termination_size)
+
+  eig_vals, eig_vecs = eigh_qdwh(x)
+  return eig_vecs, eig_vals
+
 def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
   # Derivative for eigh in the simplest case of distinct eigenvalues.
   # This is classic nondegenerate perturbation theory, but also see
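
The masking block inside `eigh_qdwh` trusts only one triangle of the input and mirrors it into the other, so QDWH-eig always sees an exactly Hermitian matrix. A standalone sketch of the real-valued, `lower=True` case (the `reflect_lower` helper is illustrative, not part of the commit):

```python
# Standalone sketch of the triangle reflection above, for the real,
# lower=True case. `reflect_lower` is an illustrative helper only.
import jax.numpy as jnp

def reflect_lower(x):
  n = x.shape[-1]
  mask = jnp.tri(n, k=0, dtype=bool)  # True on and below the diagonal
  # Keep the lower triangle; overwrite the upper triangle with the
  # transpose of the lower triangle, making the result symmetric.
  return jnp.where(mask, x, x.T)

x = jnp.array([[1.0, 99.0],
               [2.0,  3.0]])  # upper triangle holds garbage
print(reflect_lower(x))  # [[1. 2.]
                         #  [2. 3.]]
```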
@@ -663,7 +741,6 @@ eigh_p = Primitive('eigh')
 eigh_p.multiple_results = True
 eigh_p.def_impl(_eigh_impl)
 eigh_p.def_abstract_eval(_eigh_abstract_eval)
-xla.register_translation(eigh_p, _eigh_translation_rule)
 ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
 batching.primitive_batchers[eigh_p] = _eigh_batching_rule
@@ -685,6 +762,10 @@ if solver_apis is not None:
       eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
       platform='gpu')

+mlir.register_lowering(
+    eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
+    platform='tpu')
+

 triangular_solve_dtype_rule = partial(
     naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
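
The hunk above wires the new path in with `mlir.lower_fun`, which lowers a primitive by tracing a JAX-level implementation instead of emitting MLIR by hand. A toy sketch of the same registration pattern (the `double_p` primitive is invented for illustration; only the registration calls mirror the diff):

```python
# Toy sketch of the registration pattern used for eigh_p above. `double_p`
# is an invented primitive used purely to show the shape of the pattern.
from jax.core import Primitive
from jax.interpreters import mlir, xla

double_p = Primitive('double')
double_p.def_impl(lambda x: xla.apply_primitive(double_p, x))
double_p.def_abstract_eval(lambda x: x)  # output aval matches the input

# Lower by tracing a JAX-level function, exactly as _eigh_tpu_impl is
# wired in for eigh_p on the 'tpu' platform.
mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: x + x, multiple_results=False),
    platform='tpu')
```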

jax/experimental/jax2tf/jax2tf.py

@@ -1012,6 +1012,7 @@ tf_not_yet_impl = [
     "xla_pmap",
     "geqrf",
     "orgqr",
+    "eigh_jacobi",
 ]

 tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient

tests/linalg_test.py

@@ -332,7 +332,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
           jtu.format_shape_dtype_string((n,n), dtype), lower,
           sort_eigenvalues),
        "n": n, "dtype": dtype, "lower": lower}
-      for n in [0, 4, 5, 50]
+      for n in [0, 4, 5, 50, 512]
       for dtype in float_types + complex_types
       for lower in [True, False]
       for sort_eigenvalues in [True, False]))
@@ -459,7 +459,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       {"testcase_name":
        f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
        "shape": shape, "dtype": dtype}
-      for shape in [(1, 1), (4, 4), (5, 5)]
+      for shape in [(1, 1), (4, 4), (5, 5), (300, 300)]
       for dtype in float_types + complex_types))
  def testEighBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
@@ -467,8 +467,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     args = rng(shape, dtype)
     args = (args + np.conj(T(args))) / 2
     ws, vs = vmap(jsp.linalg.eigh)(args)
-    self.assertTrue(np.all(np.linalg.norm(
-        np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
+    norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs))
+    self.assertTrue(norm < 3e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":