diff --git a/tests/BUILD b/tests/BUILD index 49dbf0512..0cc6ed6d9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -997,6 +997,7 @@ jax_multiplatform_test( "cpu": ["--jax_num_generated_cases=40"], "cpu_x32": ["--jax_num_generated_cases=40"], "gpu": ["--jax_num_generated_cases=40"], + "tpu": ["--jax_num_generated_cases=40"], }, shard_count = { "cpu": 50,