diff --git a/tests/BUILD b/tests/BUILD index 0df288d1f..29776452d 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1307,13 +1307,6 @@ jax_test( jax_test( name = "shape_poly_test", srcs = ["shape_poly_test.py"], - backend_tags = { - "tpu": [ - "noasan", - "nomsan", - "notsan", - ], # Times out - }, disable_configs = [ "gpu_a100", # TODO(b/269593297): matmul precision issues ], @@ -1326,6 +1319,11 @@ jax_test( "gpu": 4, "tpu": 4, }, + tags = [ + "noasan", # Times out + "nomsan", # Times out + "notsan", # Times out + ], deps = [ "//jax:internal_test_harnesses", "//jax/experimental/export",