From 590b9161fe9d9efbfa79122b17ca9b31fa1d0a83 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 11 May 2022 11:45:28 -0700
Subject: [PATCH] Add a sort_eigenvalues option to lax.linalg.eigh().

An upcoming change that adds a more scalable QDWH-based symmetric
eigendecomposition on TPU requires that we be able to obtain the
eigenvalues unsorted. The option already exists in XLA, so we simply
need to plumb it through to the lax primitive.

PiperOrigin-RevId: 448047584
---
 CHANGELOG.md                      |  3 ++
 jax/_src/lax/linalg.py            | 45 ++++++++++++++++++++----------
 jax/experimental/jax2tf/jax2tf.py |  4 ++-
 tests/linalg_test.py              | 38 +++++++++++++++++++++++---
 4 files changed, 68 insertions(+), 22 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6f42b033a..7db696a8b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 
 ## jax 0.3.11 (Unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...main).
+* Changes
+  * {func}`jax.lax.linalg.eigh` now accepts an optional `sort_eigenvalues`
+    argument that allows users to opt out of eigenvalue sorting on TPU.
 
 ## jaxlib 0.3.11 (Unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 4520a6c61..c9a757fa8 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -95,8 +95,9 @@ def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
   return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
                     compute_right_eigenvectors=compute_right_eigenvectors)
 
-def eigh(x, lower: bool = True, symmetrize_input: bool = True):
-  """Eigendecomposition of a Hermitian matrix.
+def eigh(x, lower: bool = True, symmetrize_input: bool = True,
+         sort_eigenvalues: bool = True):
+  r"""Eigendecomposition of a Hermitian matrix.
 
   Computes the eigenvectors and eigenvalues of a complex Hermitian or real
   symmetric square matrix.
@@ -109,7 +110,10 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
       triangle given by ``lower`` is accessed; the other triangle is ignored
       and not accessed.
     symmetrize_input: If ``True``, the matrix is symmetrized before the
-      eigendecomposition by computing :math:`\\frac{1}{2}(x + x^H)`.
+      eigendecomposition by computing :math:`\frac{1}{2}(x + x^H)`.
+    sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
+      order. If ``False``, the eigenvalues are returned in an
+      implementation-defined order.
 
   Returns:
     A tuple ``(v, w)``.
@@ -123,7 +127,7 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
   """
   if symmetrize_input:
     x = symmetrize(x)
-  v, w = eigh_p.bind(x, lower=lower)
+  v, w = eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
   return v, w
 
 
@@ -515,17 +519,19 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
 
 # Symmetric/Hermitian eigendecomposition
 
-def eigh_impl(operand, lower):
-  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
+def _eigh_impl(operand, *, lower, sort_eigenvalues):
+  v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
+                             sort_eigenvalues=sort_eigenvalues)
   return v, w
 
-def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower):
+def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
+                           sort_eigenvalues):
   operand_aval, = avals_in
   if operand_aval.shape[-1] == 0:
     return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
-  return xops.Eigh(operand, lower=lower)
+  return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
 
-def eigh_abstract_eval(operand, lower):
+def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
       raise ValueError(
@@ -541,7 +547,9 @@ def eigh_abstract_eval(operand, lower):
     v, w = operand, operand
   return v, w
 
-def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
+def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
+                           sort_eigenvalues):
+  del sort_eigenvalues  # The CPU/GPU implementations always sort.
   operand_aval, = ctx.avals_in
   v_aval, w_aval = ctx.avals_out
   batch_dims = operand_aval.shape[:-2]
@@ -562,7 +570,7 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
                        w, _nan_like_mhlo(w_aval))
   return [v, w]
 
-def eigh_jvp_rule(primals, tangents, lower):
+def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
   # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
@@ -574,7 +582,8 @@
   a, = primals
   a_dot, = tangents
 
-  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)
+  v, w_real = eigh_p.bind(symmetrize(a), lower=lower,
+                          sort_eigenvalues=sort_eigenvalues)
 
   # for complex numbers we need eigenvalues to be full dtype of v, a:
   w = w_real.astype(a.dtype)
@@ -589,19 +598,19 @@
   dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
   return (v, w_real), (dv, dw)
 
-def eigh_batching_rule(batched_args, batch_dims, lower):
+def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
   x, = batched_args
   bd, = batch_dims
   x = batching.moveaxis(x, bd, 0)
-  return eigh_p.bind(x, lower=lower), (0, 0)
+  return eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues), (0, 0)
 
 
 eigh_p = Primitive('eigh')
 eigh_p.multiple_results = True
-eigh_p.def_impl(eigh_impl)
-eigh_p.def_abstract_eval(eigh_abstract_eval)
+eigh_p.def_impl(_eigh_impl)
+eigh_p.def_abstract_eval(_eigh_abstract_eval)
 xla.register_translation(eigh_p, _eigh_translation_rule)
-ad.primitive_jvps[eigh_p] = eigh_jvp_rule
-batching.primitive_batchers[eigh_p] = eigh_batching_rule
+ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
+batching.primitive_batchers[eigh_p] = _eigh_batching_rule
 mlir.register_lowering(
     eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index aebf6ca11..8c3336c29 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -2519,7 +2519,9 @@ def _eig(operand: TfVal, compute_left_eigenvectors: bool,
 
 tf_impl[lax.linalg.eig_p] = _eig
 
-def _eigh(operand: TfVal, lower: bool, _in_avals, _out_aval):
+def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
+          _out_aval):
+  del sort_eigenvalues  # tf.linalg.eigh always returns sorted eigenvalues.
   if operand.shape[-1] == 0:
     v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
   else:
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 57cdbd864..6612b6515 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -1565,8 +1565,40 @@ class ScipyLinalgTest(jtu.JaxTestCase):
 
 
 class LaxLinalgTest(jtu.JaxTestCase):
+  """Tests for lax.linalg primitives."""
 
-  def run_test(self, alpha, beta):
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
+          jtu.format_shape_dtype_string((n,n), dtype), lower,
+          sort_eigenvalues),
+       "n": n, "dtype": dtype, "lower": lower,
+       "sort_eigenvalues": sort_eigenvalues}
+      for n in [0, 4, 5, 50]
+      for dtype in float_types + complex_types
+      for lower in [True, False]
+      for sort_eigenvalues in [True, False]))
+  def testEigh(self, n, dtype, lower, sort_eigenvalues):
+    rng = jtu.rand_default(self.rng())
+    tol = 1e-3
+    args_maker = lambda: [rng((n, n), dtype)]
+
+    a, = args_maker()
+    a = (a + np.conj(a.T)) / 2
+    v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a),
+                           lower=lower, symmetrize_input=False,
+                           sort_eigenvalues=sort_eigenvalues)
+    w = np.asarray(w)
+    v = np.asarray(v)
+    self.assertLessEqual(
+        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
+    self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
+                         tol * np.linalg.norm(a))
+
+    w_expected, _ = np.linalg.eigh(np.asarray(a))
+    self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
+                        rtol=1e-4)
+
+  def run_eigh_tridiagonal_test(self, alpha, beta):
     n = alpha.shape[-1]
     # scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
     # this we call the slower numpy.linalg.eigh.
@@ -1592,7 +1624,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
     for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
       alpha = a * np.ones([n], dtype=dtype)
       beta = b * np.ones([n - 1], dtype=dtype)
-      self.run_test(alpha, beta)
+      self.run_eigh_tridiagonal_test(alpha, beta)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": f"_n={n}_dtype={dtype.__name__}",
@@ -1602,7 +1634,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
        "n": n, "dtype": dtype}
       for n in [2, 3, 7, 8, 100]
       for dtype in float_types))
   def testRandomUniform(self, n, dtype):
     alpha = jtu.rand_uniform(self.rng())((n,), dtype)
     beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
-    self.run_test(alpha, beta)
+    self.run_eigh_tridiagonal_test(alpha, beta)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": f"_dtype={dtype.__name__}",
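
Reviewer note (illustration only, not part of the patch): a minimal usage sketch of the new flag, assuming nothing beyond the API added above plus standard jax/numpy.

# Sketch: opting out of eigenvalue sorting with lax.linalg.eigh.
import numpy as np
import jax.numpy as jnp
from jax import lax

rng = np.random.RandomState(0)
a = rng.randn(4, 4).astype(np.float32)
a = (a + a.T) / 2  # eigh expects a symmetric (Hermitian) input

# Default behavior: eigenvalues are returned in ascending order.
v_sorted, w_sorted = lax.linalg.eigh(a)

# With sort_eigenvalues=False the order is implementation-defined.
v, w = lax.linalg.eigh(a, sort_eigenvalues=False)

# One argsort recovers the sorted spectrum; v[:, order] would reorder the
# eigenvectors to match.
order = jnp.argsort(w)
np.testing.assert_allclose(w[order], w_sorted, rtol=1e-4)

Skipping the sort lets a decomposition that does not naturally produce ordered eigenvalues (such as the QDWH-based TPU path mentioned in the commit message) avoid an extra sorting step; callers that need ordering can recover it as above.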
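A second sketch (again illustrative, not part of the patch) numerically checks the first-order perturbation formula that _eigh_jvp_rule implements for distinct eigenvalues: the eigenvalue tangent is dw_i = v_i^H (da) v_i.

# Sketch: check the eigh JVP against nondegenerate perturbation theory.
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

a = jnp.diag(jnp.array([1.0, 2.0, 4.0]))   # symmetric, distinct eigenvalues
da = np.random.RandomState(1).randn(3, 3).astype(np.float32)
da = jnp.asarray((da + da.T) / 2)          # symmetric perturbation direction

(v, w), (dv, dw) = jax.jvp(lax.linalg.eigh, (a,), (da,))

# The eigenvalue tangent equals the diagonal of v^T (da) v.
np.testing.assert_allclose(dw, jnp.diagonal(v.T @ da @ v), rtol=1e-4)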