Remove GPU test with unreasonably large memory footprint.

PiperOrigin-RevId: 695717589
This commit is contained in:
Dan Foreman-Mackey 2024-11-12 07:02:16 -08:00 committed by jax authors
parent 21e98b5ce4
commit a99ccd9341

View File

@ -1450,14 +1450,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self.assertAllClose(ls, actual_ls, rtol=5e-6)
self.assertAllClose(us, actual_us)
@jtu.skip_on_devices("cpu", "tpu")
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testBatchedLuOverflow(self):
# see https://github.com/jax-ml/jax/issues/24843
x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32)
lu, _, _ = lax.linalg.lu(x)
self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9))
@jtu.skip_on_devices("cpu", "tpu")
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")