Skip failing tests

This commit is contained in:
Charles Hofer 2025-01-06 16:40:38 +00:00
parent a1734fd31f
commit 307f0db702

View File

@ -327,6 +327,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1E-6)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-8)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
l_max=[1, 2, 3, 6],
shape=[(5,), (10,)],
@ -349,6 +350,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
atol=3e-3, check_dtypes=False)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
l_max=[3, 4, 6, 32],
shape=[(2,), (3,), (4,), (64,)],
@ -381,6 +383,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
rtol=1e-5, atol=1e-5, check_dtypes=False)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmAccuracy(self):
m = jnp.arange(-3, 3)[:, None]