diff --git a/tests/BUILD b/tests/BUILD index 0ca66be5f..1bf103875 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -766,13 +766,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], - # The following cases are disabled because they time out in Google's CI, mostly because the - # CUDA kernels in Torch take a very long time to compile. - disable_configs = [ - "gpu_p100", # Pytorch P100 build times out in Google's CI. - "gpu_a100", # Pytorch A100 build times out in Google's CI. - "gpu_h100", # Pytorch H100 build times out in Google's CI. - ], enable_backends = [ "cpu", "gpu", @@ -1440,10 +1433,8 @@ jax_multiplatform_test( "tpu": 20, }, tags = [ - "noasan", # Times out, TODO(b/314760446): test failures on Sapphire Rapids. + "noasan", # Times out "nodebug", # Times out. - "nomsan", # TODO(b/314760446): test failures on Sapphire Rapids. - "notsan", # TODO(b/314760446): test failures on Sapphire Rapids. ], deps = [ "//jax:internal_test_harnesses",