Add a sort_eigenvalues option to lax.linalg.eigh().

An upcoming change to add a more scalable QDWH-based TPU symmetric eigendecomposition requires that we can obtain the TPU eigenvalues unsorted. The option already exists in XLA, so we simply need to plumb it through to the lax primitive.

PiperOrigin-RevId: 448047584
This commit is contained in:
Peter Hawkins 2022-05-11 11:45:28 -07:00 committed by jax authors
parent 52ad3e6682
commit 590b9161fe
4 changed files with 73 additions and 25 deletions

View File

@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...main).
* Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -95,8 +95,9 @@ def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
def eigh(x, lower: bool = True, symmetrize_input: bool = True):
"""Eigendecomposition of a Hermitian matrix.
def eigh(x, lower: bool = True, symmetrize_input: bool = True,
sort_eigenvalues: bool = True, ):
r"""Eigendecomposition of a Hermitian matrix.
Computes the eigenvectors and eigenvalues of a complex Hermitian or real
symmetric square matrix.
@ -109,7 +110,10 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
triangle given by ``lower`` is accessed; the other triangle is ignored and
not accessed.
symmetrize_input: If ``True``, the matrix is symmetrized before the
eigendecomposition by computing :math:`\\frac{1}{2}(x + x^H)`.
eigendecomposition by computing :math:`\frac{1}{2}(x + x^H)`.
sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
order. If ``False`` the eigenvalues are returned in an
implementation-defined order.
Returns:
A tuple ``(v, w)``.
@ -123,7 +127,7 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
"""
if symmetrize_input:
x = symmetrize(x)
v, w = eigh_p.bind(x, lower=lower)
v, w = eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
return v, w
@ -515,17 +519,19 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
# Symmetric/Hermitian eigendecomposition
def eigh_impl(operand, lower):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
def _eigh_impl(operand, *, lower, sort_eigenvalues):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues)
return v, w
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower):
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
sort_eigenvalues):
operand_aval, = avals_in
if operand_aval.shape[-1] == 0:
return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
return xops.Eigh(operand, lower=lower)
return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
def eigh_abstract_eval(operand, lower):
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
if isinstance(operand, ShapedArray):
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
raise ValueError(
@ -541,7 +547,9 @@ def eigh_abstract_eval(operand, lower):
v, w = operand, operand
return v, w
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
sort_eigenvalues):
del sort_eigenvalues # The CPU/GPU implementations always sort.
operand_aval, = ctx.avals_in
v_aval, w_aval = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
@ -562,7 +570,7 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
w, _nan_like_mhlo(w_aval))
return [v, w]
def eigh_jvp_rule(primals, tangents, lower):
def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
# Derivative for eigh in the simplest case of distinct eigenvalues.
# This is classic nondegenerate perurbation theory, but also see
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
@ -574,7 +582,8 @@ def eigh_jvp_rule(primals, tangents, lower):
a, = primals
a_dot, = tangents
v, w_real = eigh_p.bind(symmetrize(a), lower=lower)
v, w_real = eigh_p.bind(symmetrize(a), lower=lower,
sort_eigenvalues=sort_eigenvalues)
# for complex numbers we need eigenvalues to be full dtype of v, a:
w = w_real.astype(a.dtype)
@ -589,19 +598,19 @@ def eigh_jvp_rule(primals, tangents, lower):
dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
return (v, w_real), (dv, dw)
def eigh_batching_rule(batched_args, batch_dims, lower):
def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return eigh_p.bind(x, lower=lower), (0, 0)
return eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues), (0, 0)
eigh_p = Primitive('eigh')
eigh_p.multiple_results = True
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
eigh_p.def_impl(_eigh_impl)
eigh_p.def_abstract_eval(_eigh_abstract_eval)
xla.register_translation(eigh_p, _eigh_translation_rule)
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
batching.primitive_batchers[eigh_p] = eigh_batching_rule
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),

View File

@ -2519,7 +2519,9 @@ def _eig(operand: TfVal, compute_left_eigenvectors: bool,
tf_impl[lax.linalg.eig_p] = _eig
def _eigh(operand: TfVal, lower: bool, _in_avals, _out_aval):
def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
_out_aval):
del sort_eigenvalues
if operand.shape[-1] == 0:
v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
else:

View File

@ -321,12 +321,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_lower={}".format(
jtu.format_shape_dtype_string((n,n), dtype), lower),
{"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
jtu.format_shape_dtype_string((n,n), dtype), lower,
sort_eigenvalues),
"n": n, "dtype": dtype, "lower": lower}
for n in [0, 4, 5, 50]
for dtype in float_types + complex_types
for lower in [True, False]))
for lower in [True, False]
for sort_eigenvalues in [True, False]))
def testEigh(self, n, dtype, lower):
rng = jtu.rand_default(self.rng())
tol = 1e-3
@ -1565,8 +1567,40 @@ class ScipyLinalgTest(jtu.JaxTestCase):
class LaxLinalgTest(jtu.JaxTestCase):
"""Tests for lax.linalg primitives."""
def run_test(self, alpha, beta):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
jtu.format_shape_dtype_string((n,n), dtype), lower,
sort_eigenvalues),
"n": n, "dtype": dtype, "lower": lower,
"sort_eigenvalues": sort_eigenvalues}
for n in [0, 4, 5, 50]
for dtype in float_types + complex_types
for lower in [True, False]
for sort_eigenvalues in [True, False]))
def testEigh(self, n, dtype, lower, sort_eigenvalues):
rng = jtu.rand_default(self.rng())
tol = 1e-3
args_maker = lambda: [rng((n, n), dtype)]
a, = args_maker()
a = (a + np.conj(a.T)) / 2
v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a),
lower=lower, symmetrize_input=False,
sort_eigenvalues=sort_eigenvalues)
w = np.asarray(w)
v = np.asarray(v)
self.assertLessEqual(
np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
tol * np.linalg.norm(a))
w_expected, v_expected = np.linalg.eigh(np.asarray(a))
self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
rtol=1e-4)
def run_eigh_tridiagonal_test(self, alpha, beta):
n = alpha.shape[-1]
# scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
# this we call the slower numpy.linalg.eigh.
@ -1592,7 +1626,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
alpha = a * np.ones([n], dtype=dtype)
beta = b * np.ones([n - 1], dtype=dtype)
self.run_test(alpha, beta)
self.run_eigh_tridiagonal_test(alpha, beta)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_n={n}_dtype={dtype.__name__}",
@ -1602,7 +1636,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
def testRandomUniform(self, n, dtype):
alpha = jtu.rand_uniform(self.rng())((n,), dtype)
beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
self.run_test(alpha, beta)
self.run_eigh_tridiagonal_test(alpha, beta)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype.__name__}",