mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a sort_eigenvalues option to lax.linalg.eigh().
An upcoming change to add a more scalable QDWH-based TPU symmetric eigendecomposition requires that we can obtain the TPU eigenvalues unsorted. The option already exists in XLA, so we simply need to plumb it through to the lax primitive. PiperOrigin-RevId: 448047584
This commit is contained in:
parent
52ad3e6682
commit
590b9161fe
@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
|
|
||||||
## jax 0.3.11 (Unreleased)
|
## jax 0.3.11 (Unreleased)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...main).
|
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...main).
|
||||||
|
* Changes
|
||||||
|
* {func}`jax.lax.linalg.eigh` now accepts an optional `sort_eigenvalues` argument
|
||||||
|
that allows users to opt out of eigenvalue sorting on TPU.
|
||||||
|
|
||||||
## jaxlib 0.3.11 (Unreleased)
|
## jaxlib 0.3.11 (Unreleased)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||||
|
@ -95,8 +95,9 @@ def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
|
|||||||
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
||||||
compute_right_eigenvectors=compute_right_eigenvectors)
|
compute_right_eigenvectors=compute_right_eigenvectors)
|
||||||
|
|
||||||
def eigh(x, lower: bool = True, symmetrize_input: bool = True):
|
def eigh(x, lower: bool = True, symmetrize_input: bool = True,
|
||||||
"""Eigendecomposition of a Hermitian matrix.
|
sort_eigenvalues: bool = True, ):
|
||||||
|
r"""Eigendecomposition of a Hermitian matrix.
|
||||||
|
|
||||||
Computes the eigenvectors and eigenvalues of a complex Hermitian or real
|
Computes the eigenvectors and eigenvalues of a complex Hermitian or real
|
||||||
symmetric square matrix.
|
symmetric square matrix.
|
||||||
@ -109,7 +110,10 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
|
|||||||
triangle given by ``lower`` is accessed; the other triangle is ignored and
|
triangle given by ``lower`` is accessed; the other triangle is ignored and
|
||||||
not accessed.
|
not accessed.
|
||||||
symmetrize_input: If ``True``, the matrix is symmetrized before the
|
symmetrize_input: If ``True``, the matrix is symmetrized before the
|
||||||
eigendecomposition by computing :math:`\\frac{1}{2}(x + x^H)`.
|
eigendecomposition by computing :math:`\frac{1}{2}(x + x^H)`.
|
||||||
|
sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
|
||||||
|
order. If ``False``, the eigenvalues are returned in an
|
||||||
|
implementation-defined order.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple ``(v, w)``.
|
A tuple ``(v, w)``.
|
||||||
@ -123,7 +127,7 @@ def eigh(x, lower: bool = True, symmetrize_input: bool = True):
|
|||||||
"""
|
"""
|
||||||
if symmetrize_input:
|
if symmetrize_input:
|
||||||
x = symmetrize(x)
|
x = symmetrize(x)
|
||||||
v, w = eigh_p.bind(x, lower=lower)
|
v, w = eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||||
return v, w
|
return v, w
|
||||||
|
|
||||||
|
|
||||||
@ -515,17 +519,19 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
|
|||||||
|
|
||||||
# Symmetric/Hermitian eigendecomposition
|
# Symmetric/Hermitian eigendecomposition
|
||||||
|
|
||||||
def eigh_impl(operand, lower):
|
def _eigh_impl(operand, *, lower, sort_eigenvalues):
|
||||||
v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
|
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
|
||||||
|
sort_eigenvalues=sort_eigenvalues)
|
||||||
return v, w
|
return v, w
|
||||||
|
|
||||||
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower):
|
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
|
||||||
|
sort_eigenvalues):
|
||||||
operand_aval, = avals_in
|
operand_aval, = avals_in
|
||||||
if operand_aval.shape[-1] == 0:
|
if operand_aval.shape[-1] == 0:
|
||||||
return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
|
return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
|
||||||
return xops.Eigh(operand, lower=lower)
|
return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||||
|
|
||||||
def eigh_abstract_eval(operand, lower):
|
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||||
if isinstance(operand, ShapedArray):
|
if isinstance(operand, ShapedArray):
|
||||||
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -541,7 +547,9 @@ def eigh_abstract_eval(operand, lower):
|
|||||||
v, w = operand, operand
|
v, w = operand, operand
|
||||||
return v, w
|
return v, w
|
||||||
|
|
||||||
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
|
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
|
||||||
|
sort_eigenvalues):
|
||||||
|
del sort_eigenvalues # The CPU/GPU implementations always sort.
|
||||||
operand_aval, = ctx.avals_in
|
operand_aval, = ctx.avals_in
|
||||||
v_aval, w_aval = ctx.avals_out
|
v_aval, w_aval = ctx.avals_out
|
||||||
batch_dims = operand_aval.shape[:-2]
|
batch_dims = operand_aval.shape[:-2]
|
||||||
@ -562,7 +570,7 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower):
|
|||||||
w, _nan_like_mhlo(w_aval))
|
w, _nan_like_mhlo(w_aval))
|
||||||
return [v, w]
|
return [v, w]
|
||||||
|
|
||||||
def eigh_jvp_rule(primals, tangents, lower):
|
def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
|
||||||
# Derivative for eigh in the simplest case of distinct eigenvalues.
|
# Derivative for eigh in the simplest case of distinct eigenvalues.
|
||||||
# This is classic nondegenerate perturbation theory, but also see
|
# This is classic nondegenerate perturbation theory, but also see
|
||||||
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
|
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
|
||||||
@ -574,7 +582,8 @@ def eigh_jvp_rule(primals, tangents, lower):
|
|||||||
a, = primals
|
a, = primals
|
||||||
a_dot, = tangents
|
a_dot, = tangents
|
||||||
|
|
||||||
v, w_real = eigh_p.bind(symmetrize(a), lower=lower)
|
v, w_real = eigh_p.bind(symmetrize(a), lower=lower,
|
||||||
|
sort_eigenvalues=sort_eigenvalues)
|
||||||
|
|
||||||
# for complex numbers we need eigenvalues to be full dtype of v, a:
|
# for complex numbers we need eigenvalues to be full dtype of v, a:
|
||||||
w = w_real.astype(a.dtype)
|
w = w_real.astype(a.dtype)
|
||||||
@ -589,19 +598,19 @@ def eigh_jvp_rule(primals, tangents, lower):
|
|||||||
dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
|
dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
|
||||||
return (v, w_real), (dv, dw)
|
return (v, w_real), (dv, dw)
|
||||||
|
|
||||||
def eigh_batching_rule(batched_args, batch_dims, lower):
|
def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
|
||||||
x, = batched_args
|
x, = batched_args
|
||||||
bd, = batch_dims
|
bd, = batch_dims
|
||||||
x = batching.moveaxis(x, bd, 0)
|
x = batching.moveaxis(x, bd, 0)
|
||||||
return eigh_p.bind(x, lower=lower), (0, 0)
|
return eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues), (0, 0)
|
||||||
|
|
||||||
eigh_p = Primitive('eigh')
|
eigh_p = Primitive('eigh')
|
||||||
eigh_p.multiple_results = True
|
eigh_p.multiple_results = True
|
||||||
eigh_p.def_impl(eigh_impl)
|
eigh_p.def_impl(_eigh_impl)
|
||||||
eigh_p.def_abstract_eval(eigh_abstract_eval)
|
eigh_p.def_abstract_eval(_eigh_abstract_eval)
|
||||||
xla.register_translation(eigh_p, _eigh_translation_rule)
|
xla.register_translation(eigh_p, _eigh_translation_rule)
|
||||||
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
|
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
|
||||||
batching.primitive_batchers[eigh_p] = eigh_batching_rule
|
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
|
||||||
|
|
||||||
mlir.register_lowering(
|
mlir.register_lowering(
|
||||||
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
|
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
|
||||||
|
@ -2519,7 +2519,9 @@ def _eig(operand: TfVal, compute_left_eigenvectors: bool,
|
|||||||
tf_impl[lax.linalg.eig_p] = _eig
|
tf_impl[lax.linalg.eig_p] = _eig
|
||||||
|
|
||||||
|
|
||||||
def _eigh(operand: TfVal, lower: bool, _in_avals, _out_aval):
|
def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
|
||||||
|
_out_aval):
|
||||||
|
del sort_eigenvalues
|
||||||
if operand.shape[-1] == 0:
|
if operand.shape[-1] == 0:
|
||||||
v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
|
v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
|
||||||
else:
|
else:
|
||||||
|
@ -321,12 +321,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
|||||||
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
|
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": "_n={}_lower={}".format(
|
{"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
|
||||||
jtu.format_shape_dtype_string((n,n), dtype), lower),
|
jtu.format_shape_dtype_string((n,n), dtype), lower,
|
||||||
|
sort_eigenvalues),
|
||||||
"n": n, "dtype": dtype, "lower": lower}
|
"n": n, "dtype": dtype, "lower": lower}
|
||||||
for n in [0, 4, 5, 50]
|
for n in [0, 4, 5, 50]
|
||||||
for dtype in float_types + complex_types
|
for dtype in float_types + complex_types
|
||||||
for lower in [True, False]))
|
for lower in [True, False]
|
||||||
|
for sort_eigenvalues in [True, False]))
|
||||||
def testEigh(self, n, dtype, lower):
|
def testEigh(self, n, dtype, lower):
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
tol = 1e-3
|
tol = 1e-3
|
||||||
@ -1565,8 +1567,40 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
|
|
||||||
class LaxLinalgTest(jtu.JaxTestCase):
|
class LaxLinalgTest(jtu.JaxTestCase):
|
||||||
|
"""Tests for lax.linalg primitives."""
|
||||||
|
|
||||||
def run_test(self, alpha, beta):
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
|
{"testcase_name": "_n={}_lower={}_sort_eigenvalues={}".format(
|
||||||
|
jtu.format_shape_dtype_string((n,n), dtype), lower,
|
||||||
|
sort_eigenvalues),
|
||||||
|
"n": n, "dtype": dtype, "lower": lower,
|
||||||
|
"sort_eigenvalues": sort_eigenvalues}
|
||||||
|
for n in [0, 4, 5, 50]
|
||||||
|
for dtype in float_types + complex_types
|
||||||
|
for lower in [True, False]
|
||||||
|
for sort_eigenvalues in [True, False]))
|
||||||
|
def testEigh(self, n, dtype, lower, sort_eigenvalues):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
tol = 1e-3
|
||||||
|
args_maker = lambda: [rng((n, n), dtype)]
|
||||||
|
|
||||||
|
a, = args_maker()
|
||||||
|
a = (a + np.conj(a.T)) / 2
|
||||||
|
v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a),
|
||||||
|
lower=lower, symmetrize_input=False,
|
||||||
|
sort_eigenvalues=sort_eigenvalues)
|
||||||
|
w = np.asarray(w)
|
||||||
|
v = np.asarray(v)
|
||||||
|
self.assertLessEqual(
|
||||||
|
np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
|
||||||
|
self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
|
||||||
|
tol * np.linalg.norm(a))
|
||||||
|
|
||||||
|
w_expected, v_expected = np.linalg.eigh(np.asarray(a))
|
||||||
|
self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
|
||||||
|
rtol=1e-4)
|
||||||
|
|
||||||
|
def run_eigh_tridiagonal_test(self, alpha, beta):
|
||||||
n = alpha.shape[-1]
|
n = alpha.shape[-1]
|
||||||
# scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
|
# scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
|
||||||
# this we call the slower numpy.linalg.eigh.
|
# this we call the slower numpy.linalg.eigh.
|
||||||
@ -1592,7 +1626,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
|||||||
for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
|
for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
|
||||||
alpha = a * np.ones([n], dtype=dtype)
|
alpha = a * np.ones([n], dtype=dtype)
|
||||||
beta = b * np.ones([n - 1], dtype=dtype)
|
beta = b * np.ones([n - 1], dtype=dtype)
|
||||||
self.run_test(alpha, beta)
|
self.run_eigh_tridiagonal_test(alpha, beta)
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": f"_n={n}_dtype={dtype.__name__}",
|
{"testcase_name": f"_n={n}_dtype={dtype.__name__}",
|
||||||
@ -1602,7 +1636,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
|||||||
def testRandomUniform(self, n, dtype):
|
def testRandomUniform(self, n, dtype):
|
||||||
alpha = jtu.rand_uniform(self.rng())((n,), dtype)
|
alpha = jtu.rand_uniform(self.rng())((n,), dtype)
|
||||||
beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
|
beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
|
||||||
self.run_test(alpha, beta)
|
self.run_eigh_tridiagonal_test(alpha, beta)
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": f"_dtype={dtype.__name__}",
|
{"testcase_name": f"_dtype={dtype.__name__}",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user