diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index 38d6cccd8..77ee057d2 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -106,6 +106,9 @@ class LaxBackedScipyFftTests(jtu.JaxTestCase): dtype=real_dtypes, norm=[None, 'ortho'], ) + # TODO(phawkins): these tests are failing on T4 GPUs in CI with a + # CUDA_ERROR_ILLEGAL_ADDRESS. + @jtu.skip_on_devices("cuda") def testiDctn(self, shape, dtype, s, axes, norm): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),)