mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix rank promotion error in JVP of batched eigh.
PiperOrigin-RevId: 730475017
This commit is contained in:
parent
9d421c9149
commit
6bd99207d5
@ -1208,7 +1208,8 @@ def _eigh_jvp_rule(
|
||||
w = w_real.astype(a.dtype)
|
||||
eye_n = lax_internal._eye(a.dtype, (n, n))
|
||||
# carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
|
||||
Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n
|
||||
with config.numpy_rank_promotion("allow"):
|
||||
Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n
|
||||
# eigh impl doesn't support batch dims, but future-proof the grad.
|
||||
dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
|
||||
precision=lax.Precision.HIGHEST)
|
||||
|
@ -614,6 +614,11 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.assert_dot_precision(
|
||||
lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))
|
||||
|
||||
def testEighGradRankPromotion(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
a = rng((10, 3, 3), np.float32)
|
||||
jvp(jnp.linalg.eigh, (a,), (a,)) # doesn't crash
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (4, 4), (5, 5), (300, 300)],
|
||||
dtype=float_types + complex_types,
|
||||
|
Loading…
x
Reference in New Issue
Block a user