diff --git a/tests/BUILD b/tests/BUILD index 27ecbe957..cc247ac76 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -635,7 +635,7 @@ jax_test( shard_count = { "cpu": 30, "gpu": 30, - "tpu": 30, + "tpu": 40, "iree": 30, }, tags = ["noasan"], # Times out @@ -652,7 +652,8 @@ jax_test( "notsan", ], "tpu": [ - "noasan", # Times out under asan/tsan. + "noasan", # Times out under asan/msan/tsan. + "nomsan", "notsan", "optonly", ], @@ -719,6 +720,9 @@ jax_test( jax_test( name = "scipy_stats_test", srcs = ["scipy_stats_test.py"], + backend_tags = { + "tpu": ["nomsan"], # Times out + }, shard_count = { "cpu": 40, "gpu": 30, @@ -1029,7 +1033,7 @@ jax_test( shard_count = { "cpu": 20, "gpu": 10, - "tpu": 10, + "tpu": 20, }, ) @@ -1037,7 +1041,7 @@ jax_test( name = "shard_map_test", srcs = ["shard_map_test.py"], shard_count = { - "cpu": 20, + "cpu": 30, "gpu": 10, "tpu": 10, },