diff --git a/tests/BUILD b/tests/BUILD index adea49cac..b5a99b254 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -688,6 +688,11 @@ jax_test( jax_test( name = "nn_test", srcs = ["nn_test.py"], + backend_tags = { + "gpu": [ + "noasan", # Times out under asan. + ], + }, shard_count = { "cpu": 10, "tpu": 10,