mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a sort_eigenvalues option to lax.linalg.eigh().
An upcoming change to add a more scalable QDWH-based TPU symmetric eigendecomposition requires that we can obtain the TPU eigenvalues unsorted. The option already exists in XLA, so we simply need to plumb it through to the lax primitive. PiperOrigin-RevId: 448047584
This commit is contained in:
parent
52ad3e6682
commit
590b9161fe
@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jax 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...main).
|
||||
* Changes
|
||||
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
|
||||
that allows users to opt out of eigenvalue sorting on TPU.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
@ -95,8 +95,9 @@ def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
|
||||
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
||||
compute_right_eigenvectors=compute_right_eigenvectors)
|
||||
|
||||
def eigh(x, lower: bool = True, symmetrize_input: bool = True):
|
||||
"""Eigendecomposition of a Hermitian matrix.
|
||||
def eigh(x, lower: bool = True, symmetrize_input: bool = True,
|
||||
sort_eigenvalues: bool = True, ):
|
||||
r"""Eigendecomposition of a Hermitian matrix.
|
||||
|
||||
Computes the eigenvectors and eigenvalues of a complex Hermitian or real
|
||||
symmetric square matrix.
|
||||
@ -109,7 +110,10 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
|
||||
triangle given by ``lower`` is accessed; the other triangle is ignored and
|
||||
not accessed.
|
||||
symmetrize_input: If ``True``, the matrix is symmetrized before the
|
||||
eigendecomposition by computing :math:`\\frac{1}{2}(x + x^H)`.
|
||||
eigendecomposition by computing :math:`\frac{1}{2}(x + x^H)`.
|
||||
sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
|
||||
order. If ``False`` the eigenvalues are returned in an
|
||||
implementation-defined order.
|
||||
|
||||
Returns:
|
||||
A tuple ``(v, w)``.
|
||||
@ -123,7 +127,7 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
|
||||
"""
|
||||
if symmetrize_input:
|
||||
x = symmetrize(x)
|
||||
v, w = eigh_p.bind(x, lower=lower)
|
||||
v, w = eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||
return v, w
|
||||
|
||||
|
||||
@ -515,17 +519,19 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
|
||||
|
||||
# Symmetric/Hermitian eigendecomposition
|
||||
|
||||
def eigh_impl(operand, lower):
|
||||
v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
|
||||
def _eigh_impl(operand, *, lower, sort_eigenvalues):
|
||||
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
return v, w
|
||||
|
||||
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower):
|
||||
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
|
||||
sort_eigenvalues):
|
||||
operand_aval, = avals_in
|
||||
if operand_aval.shape[-1] == 0:
|
||||
return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
|
||||
return xops.Eigh(operand, lower=lower)
|
||||
return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||
|
||||
def eigh_abstract_eval(operand, lower):
|
||||
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
if isinstance(operand, ShapedArray):
|
||||
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
||||
raise ValueError(
|
||||
@ -541,7 +547,9 @@ def eigh_abstract_eval(operand, lower):
|
||||
v, w = operand, operand
|
||||
return v, w
|
||||
|
||||
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
|
||||
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
|
||||
sort_eigenvalues):
|
||||
del sort_eigenvalues # The CPU/GPU implementations always sort.
|
||||
operand_aval, = ctx.avals_in
|
||||
v_aval, w_aval = ctx.avals_out
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
@ -562,7 +570,7 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
|
||||
w, _nan_like_mhlo(w_aval))
|
||||
return [v, w]
|
||||
|
||||
def eigh_jvp_rule(primals, tangents, lower):
|
||||
def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
|
||||
# Derivative for eigh in the simplest case of distinct eigenvalues.
|
||||
# This is classic nondegenerate perurbation theory, but also see
|
||||
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
|
||||
@ -574,7 +582,8 @@ def eigh_jvp_rule(primals, tangents, lower):
|
||||
a, = primals
|
||||
a_dot, = tangents
|
||||
|
||||
v, w_real = eigh_p.bind(symmetrize(a), lower=lower)
|
||||
v, w_real = eigh_p.bind(symmetrize(a), lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
|
||||
# for complex numbers we need eigenvalues to be full dtype of v, a:
|
||||
w = w_real.astype(a.dtype)
|
||||
@ -589,19 +598,19 @@ def eigh_jvp_rule(primals, tangents, lower):
|
||||
dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
|
||||
return (v, w_real), (dv, dw)
|
||||
|
||||
def eigh_batching_rule(batched_args, batch_dims, lower):
|
||||
def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
return eigh_p.bind(x, lower=lower), (0, 0)
|
||||
return eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues), (0, 0)
|
||||
|
||||
eigh_p = Primitive('eigh')
|
||||
eigh_p.multiple_results = True
|
||||
eigh_p.def_impl(eigh_impl)
|
||||
eigh_p.def_abstract_eval(eigh_abstract_eval)
|
||||
eigh_p.def_impl(_eigh_impl)
|
||||
eigh_p.def_abstract_eval(_eigh_abstract_eval)
|
||||
xla.register_translation(eigh_p, _eigh_translation_rule)
|
||||
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
|
||||
batching.primitive_batchers[eigh_p] = eigh_batching_rule
|
||||
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
|
||||
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
|
||||
|
||||
mlir.register_lowering(
|
||||
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
|
||||
|
@ -2519,7 +2519,9 @@ def _eig(operand: TfVal, compute_left_eigenvectors: bool,
|
||||
tf_impl[lax.linalg.eig_p] = _eig
|
||||
|
||||
|
||||
def _eigh(operand: TfVal, lower: bool, _in_avals, _out_aval):
|
||||
def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
|
||||
_out_aval):
|
||||
del sort_eigenvalues
|
||||
if operand.shape[-1] == 0:
|
||||
v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
|
||||
else:
|
||||
|
@ -321,12 +321,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_n={}_lower={}".format(
|
||||
jtu.format_shape_dtype_string((n,n), dtype), lower),
|
||||
{"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
|
||||
jtu.format_shape_dtype_string((n,n), dtype), lower,
|
||||
sort_eigenvalues),
|
||||
"n": n, "dtype": dtype, "lower": lower}
|
||||
for n in [0, 4, 5, 50]
|
||||
for dtype in float_types + complex_types
|
||||
for lower in [True, False]))
|
||||
for lower in [True, False]
|
||||
for sort_eigenvalues in [True, False]))
|
||||
def testEigh(self, n, dtype, lower):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
tol = 1e-3
|
||||
@ -1565,8 +1567,40 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
class LaxLinalgTest(jtu.JaxTestCase):
|
||||
"""Tests for lax.linalg primitives."""
|
||||
|
||||
def run_test(self, alpha, beta):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
|
||||
jtu.format_shape_dtype_string((n,n), dtype), lower,
|
||||
sort_eigenvalues),
|
||||
"n": n, "dtype": dtype, "lower": lower,
|
||||
"sort_eigenvalues": sort_eigenvalues}
|
||||
for n in [0, 4, 5, 50]
|
||||
for dtype in float_types + complex_types
|
||||
for lower in [True, False]
|
||||
for sort_eigenvalues in [True, False]))
|
||||
def testEigh(self, n, dtype, lower, sort_eigenvalues):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
tol = 1e-3
|
||||
args_maker = lambda: [rng((n, n), dtype)]
|
||||
|
||||
a, = args_maker()
|
||||
a = (a + np.conj(a.T)) / 2
|
||||
v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a),
|
||||
lower=lower, symmetrize_input=False,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
w = np.asarray(w)
|
||||
v = np.asarray(v)
|
||||
self.assertLessEqual(
|
||||
np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
|
||||
self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
|
||||
tol * np.linalg.norm(a))
|
||||
|
||||
w_expected, v_expected = np.linalg.eigh(np.asarray(a))
|
||||
self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
|
||||
rtol=1e-4)
|
||||
|
||||
def run_eigh_tridiagonal_test(self, alpha, beta):
|
||||
n = alpha.shape[-1]
|
||||
# scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
|
||||
# this we call the slower numpy.linalg.eigh.
|
||||
@ -1592,7 +1626,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
||||
for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
|
||||
alpha = a * np.ones([n], dtype=dtype)
|
||||
beta = b * np.ones([n - 1], dtype=dtype)
|
||||
self.run_test(alpha, beta)
|
||||
self.run_eigh_tridiagonal_test(alpha, beta)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_n={n}_dtype={dtype.__name__}",
|
||||
@ -1602,7 +1636,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
||||
def testRandomUniform(self, n, dtype):
|
||||
alpha = jtu.rand_uniform(self.rng())((n,), dtype)
|
||||
beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
|
||||
self.run_test(alpha, beta)
|
||||
self.run_eigh_tridiagonal_test(alpha, beta)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={dtype.__name__}",
|
||||
|
Loading…
x
Reference in New Issue
Block a user