diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6f42b033a..7db696a8b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 
 ## jax 0.3.11 (Unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...main).
+* Changes
+  * {func}`jax.lax.linalg.eigh` now accepts an optional `sort_eigenvalues`
+    argument that allows users to opt out of eigenvalue sorting on TPU.
 
 ## jaxlib 0.3.11 (Unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 4520a6c61..c9a757fa8 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -95,8 +95,9 @@ def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
   return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
                     compute_right_eigenvectors=compute_right_eigenvectors)
 
-def eigh(x, lower: bool = True, symmetrize_input: bool = True):
-  """Eigendecomposition of a Hermitian matrix.
+def eigh(x, lower: bool = True, symmetrize_input: bool = True,
+         sort_eigenvalues: bool = True):
+  r"""Eigendecomposition of a Hermitian matrix.
 
   Computes the eigenvectors and eigenvalues of a complex Hermitian or real
   symmetric square matrix.
@@ -109,7 +110,10 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
       triangle given by ``lower`` is accessed; the other triangle is ignored
       and not accessed.
     symmetrize_input: If ``True``, the matrix is symmetrized before the
-      eigendecomposition by computing :math:`\\frac{1}{2}(x + x^H)`.
+      eigendecomposition by computing :math:`\frac{1}{2}(x + x^H)`.
+    sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
+      order. If ``False``, the eigenvalues are returned in an
+      implementation-defined order.
 
   Returns:
     A tuple ``(v, w)``.
@@ -123,7 +127,7 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
   """
   if symmetrize_input:
     x = symmetrize(x)
-  v, w = eigh_p.bind(x, lower=lower)
+  v, w = eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
   return v, w
 
 
@@ -515,17 +519,19 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
 
 # Symmetric/Hermitian eigendecomposition
 
-def eigh_impl(operand, lower):
-  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
+def _eigh_impl(operand, *, lower, sort_eigenvalues):
+  v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
+                             sort_eigenvalues=sort_eigenvalues)
   return v, w
 
-def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower):
+def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
+                           sort_eigenvalues):
   operand_aval, = avals_in
   if operand_aval.shape[-1] == 0:
     return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
-  return xops.Eigh(operand, lower=lower)
+  return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
 
-def eigh_abstract_eval(operand, lower):
+def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
       raise ValueError(
@@ -541,7 +547,9 @@ def eigh_abstract_eval(operand, lower):
       v, w = operand, operand
     return v, w
 
-def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
+def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
+                           sort_eigenvalues):
+  del sort_eigenvalues  # The CPU/GPU implementations always sort.
   operand_aval, = ctx.avals_in
   v_aval, w_aval = ctx.avals_out
   batch_dims = operand_aval.shape[:-2]
@@ -562,7 +570,7 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
                    w, _nan_like_mhlo(w_aval))
   return [v, w]
 
-def eigh_jvp_rule(primals, tangents, lower):
+def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
   # Derivative for eigh in the simplest case of distinct eigenvalues.
   # This is classic nondegenerate perurbation theory, but also see
   # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
@@ -574,7 +582,8 @@ def eigh_jvp_rule(primals, tangents, lower):
 
   a, = primals
   a_dot, = tangents
-  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)
+  v, w_real = eigh_p.bind(symmetrize(a), lower=lower,
+                          sort_eigenvalues=sort_eigenvalues)
 
   # for complex numbers we need eigenvalues to be full dtype of v, a:
   w = w_real.astype(a.dtype)
@@ -589,19 +598,19 @@ def eigh_jvp_rule(primals, tangents, lower):
   dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
   return (v, w_real), (dv, dw)
 
-def eigh_batching_rule(batched_args, batch_dims, lower):
+def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
   x, = batched_args
   bd, = batch_dims
   x = batching.moveaxis(x, bd, 0)
-  return eigh_p.bind(x, lower=lower), (0, 0)
+  return eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues), (0, 0)
 
 eigh_p = Primitive('eigh')
 eigh_p.multiple_results = True
-eigh_p.def_impl(eigh_impl)
-eigh_p.def_abstract_eval(eigh_abstract_eval)
+eigh_p.def_impl(_eigh_impl)
+eigh_p.def_abstract_eval(_eigh_abstract_eval)
 xla.register_translation(eigh_p, _eigh_translation_rule)
-ad.primitive_jvps[eigh_p] = eigh_jvp_rule
-batching.primitive_batchers[eigh_p] = eigh_batching_rule
+ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
+batching.primitive_batchers[eigh_p] = _eigh_batching_rule
 
 mlir.register_lowering(
     eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index aebf6ca11..8c3336c29 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -2519,7 +2519,9 @@ def _eig(operand: TfVal, compute_left_eigenvectors: bool,
 tf_impl[lax.linalg.eig_p] = _eig
 
 
-def _eigh(operand: TfVal, lower: bool, _in_avals, _out_aval):
+def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
+          _out_aval):
+  del sort_eigenvalues
   if operand.shape[-1] == 0:
     v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
   else:
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 57cdbd864..6612b6515 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -1565,8 +1565,40 @@ class ScipyLinalgTest(jtu.JaxTestCase):
 
 
 class LaxLinalgTest(jtu.JaxTestCase):
+  """Tests for lax.linalg primitives."""
 
-  def run_test(self, alpha, beta):
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
+          jtu.format_shape_dtype_string((n,n), dtype), lower,
+          sort_eigenvalues),
+       "n": n, "dtype": dtype, "lower": lower,
+       "sort_eigenvalues": sort_eigenvalues}
+      for n in [0, 4, 5, 50]
+      for dtype in float_types + complex_types
+      for lower in [True, False]
+      for sort_eigenvalues in [True, False]))
+  def testEigh(self, n, dtype, lower, sort_eigenvalues):
+    rng = jtu.rand_default(self.rng())
+    tol = 1e-3
+    args_maker = lambda: [rng((n, n), dtype)]
+
+    a, = args_maker()
+    a = (a + np.conj(a.T)) / 2
+    v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a),
+                           lower=lower, symmetrize_input=False,
+                           sort_eigenvalues=sort_eigenvalues)
+    w = np.asarray(w)
+    v = np.asarray(v)
+    self.assertLessEqual(
+        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
+    self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
+                         tol * np.linalg.norm(a))
+
+    w_expected, v_expected = np.linalg.eigh(np.asarray(a))
+    self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
+                        rtol=1e-4)
+
+  def run_eigh_tridiagonal_test(self, alpha, beta):
     n = alpha.shape[-1]
     # scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
     # this we call the slower numpy.linalg.eigh.
@@ -1592,7 +1624,7 @@
     for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
       alpha = a * np.ones([n], dtype=dtype)
       beta = b * np.ones([n - 1], dtype=dtype)
-      self.run_test(alpha, beta)
+      self.run_eigh_tridiagonal_test(alpha, beta)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": f"_n={n}_dtype={dtype.__name__}",
@@ -1602,7 +1634,7 @@
   def testRandomUniform(self, n, dtype):
     alpha = jtu.rand_uniform(self.rng())((n,), dtype)
     beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
-    self.run_test(alpha, beta)
+    self.run_eigh_tridiagonal_test(alpha, beta)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": f"_dtype={dtype.__name__}",
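
Not part of the patch itself: a minimal usage sketch of the new argument, assuming the change above is applied. With `sort_eigenvalues=False` the eigenvalue order is implementation-defined (the opt-out mainly saves work on TPU; the CPU/GPU lowerings still sort), so results should only be compared up to a permutation.

```python
import jax.numpy as jnp
from jax import lax

# A small real symmetric (hence Hermitian) matrix.
a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])

# Default behavior: eigenvalues are returned in ascending order.
v_sorted, w_sorted = lax.linalg.eigh(a)

# Opt out of sorting: the eigenvalue order is implementation-defined,
# so only compare the two results up to a permutation.
v_unsorted, w_unsorted = lax.linalg.eigh(a, sort_eigenvalues=False)

assert jnp.allclose(jnp.sort(w_unsorted), w_sorted)
```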