diff --git a/tests/BUILD b/tests/BUILD index 1e64d98db..21b26aa64 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -416,6 +416,9 @@ jax_test( backend_tags = { "cpu": ["notsan"], # Test times out. }, + disable_configs = [ + "gpu_h100", # TODO(b/328050517): ptxas compilation failures + ], shard_count = { "cpu": 40, "gpu": 40, @@ -731,6 +734,9 @@ jax_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], disable_backends = ["tpu"], + disable_configs = [ + "gpu_h100", # Pytorch H100 build times out in Google's CI. + ], tags = [ # PyTorch leaks dlpack metadata https://github.com/pytorch/pytorch/issues/117058, and # compilation times out on CPU. @@ -858,6 +864,9 @@ jax_test( "optonly", ], }, + disable_configs = [ + "gpu_h100", # TODO(phawkins): numerical failure on h100 + ], shard_count = { "cpu": 40, "gpu": 40, @@ -938,6 +947,9 @@ jax_test( "cpu_x32": ["--jax_num_generated_cases=40"], "gpu": ["--jax_num_generated_cases=40"], }, + disable_configs = [ + "gpu_h100", # TODO(b/328050517): ptxas compilation failures + ], shard_count = { "cpu": 50, "gpu": 50, @@ -967,6 +979,9 @@ jax_test( "noasan", # Times out under asan. ], }, + disable_configs = [ + "gpu_h100", # TODO(b/328050517): ptxas compilation failures + ], shard_count = { "cpu": 5, "gpu": 20, @@ -1223,6 +1238,9 @@ jax_test( backend_variant_args = { "tpu_pjrt_c_api": ["--jax_num_generated_cases=1"], }, + disable_configs = [ + "gpu_h100", # TODO(b/328050517): ptxas compilation failures + ], enable_configs = [ "gpu", "cpu",