diff --git a/tests/BUILD b/tests/BUILD index 83cbc368f..ebc02e05f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -269,6 +269,9 @@ jax_test( jax_test( name = "layout_test", srcs = ["layout_test.py"], + backend_tags = { + "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. + }, tags = ["multiaccelerator"], ) @@ -318,6 +321,9 @@ jax_test( jax_test( name = "array_test", srcs = ["array_test.py"], + backend_tags = { + "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. + }, tags = ["multiaccelerator"], deps = [ "//jax:experimental",