[JAX] Enable/disable tests that timed out in CI.

Reenable pmap_test since it was recently sped up.

PiperOrigin-RevId: 491650701
This commit is contained in:
Peter Hawkins 2022-11-29 08:58:19 -08:00 committed by jax authors
parent 21bab5efab
commit 7495a9e370

View File

@ -555,15 +555,6 @@ jax_test(
jax_test(
name = "pmap_test",
srcs = ["pmap_test.py"],
backend_tags = {
"tpu": [
"noasan", # Times out.
"nomsan", # Times out.
"nodebug", # Times out.
"notsan", # Times out.
],
"cpu": ["notsan"], # Times out
},
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 30,
@ -694,11 +685,13 @@ jax_test(
"cpu": [
"noasan", # Test times out under asan.
],
# TPU test times out under asan/msan/tsan (b/260710050)
"tpu": [
"noasan",
"nomsan",
"notsan",
"optonly",
], # Test times out under asan/tsan.
],
},
shard_count = {
"cpu": 40,