Fix rank promotion error in JVP of batched eigh.

PiperOrigin-RevId: 730475017
Dan Foreman-Mackey 2025-02-24 09:08:07 -08:00 committed by jax authors
parent 9d421c9149
commit 6bd99207d5
2 changed files with 7 additions and 1 deletion


@@ -1208,7 +1208,8 @@ def _eigh_jvp_rule(
   w = w_real.astype(a.dtype)
   eye_n = lax_internal._eye(a.dtype, (n, n))
   # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
-  Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n
+  with config.numpy_rank_promotion("allow"):
+    Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n
   # eigh impl doesn't support batch dims, but future-proof the grad.
   dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                 precision=lax.Precision.HIGHEST)
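
The with-block is needed because eye_n has shape (n, n) while the batched eigenvalues w carry leading batch dimensions, so eye_n + w[..., np.newaxis, :] adds arrays of different rank; with jax_numpy_rank_promotion set to "raise", that promotion used to error inside the JVP rule. Below is a rough standalone sketch of the same pattern, not the library code: the shapes are made up for illustration, and it uses the public jax.numpy_rank_promotion context manager in place of the internal config object.

import jax
import jax.numpy as jnp
import numpy as np

# Opt in to strict rank checking, as the JAX test suite does.
jax.config.update("jax_numpy_rank_promotion", "raise")

batch, n = 10, 3
w = jnp.arange(batch * n, dtype=jnp.float32).reshape(batch, n)  # batched eigenvalues
eye_n = jnp.eye(n, dtype=jnp.float32)                           # rank 2: (n, n)

# (n, n) + (batch, 1, n) mixes ranks 2 and 3, which "raise" mode rejects,
# so the promotion is allowed only for this block, mirroring the fix above.
with jax.numpy_rank_promotion("allow"):
  fmat = 1.0 / (eye_n + w[..., np.newaxis, :] - w[..., np.newaxis]) - eye_n
print(fmat.shape)  # (10, 3, 3)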


@@ -614,6 +614,11 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     jtu.assert_dot_precision(
         lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))
 
+  def testEighGradRankPromotion(self):
+    rng = jtu.rand_default(self.rng())
+    a = rng((10, 3, 3), np.float32)
+    jvp(jnp.linalg.eigh, (a,), (a,))  # doesn't crash
+
   @jtu.sample_product(
       shape=[(1, 1), (4, 4), (5, 5), (300, 300)],
       dtype=float_types + complex_types,
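
The new test is essentially a minimal reproduction: the failure only surfaces when rank promotion is set to "raise", which the JAX test harness configures by default, hence the bare "doesn't crash" check. A hedged end-to-end sketch along the same lines, with an explicit config update since it runs outside the test harness:

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_numpy_rank_promotion", "raise")

a = np.random.default_rng(0).standard_normal((10, 3, 3)).astype(np.float32)

# Before this change the batched JVP hit a rank-promotion error inside
# _eigh_jvp_rule; with the scoped "allow" it returns batched outputs.
primals, tangents = jax.jvp(jnp.linalg.eigh, (a,), (a,))
w, v = primals
print(w.shape, v.shape)  # (10, 3) (10, 3, 3)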