diff --git a/tests/BUILD b/tests/BUILD index 16ce56189..49475f902 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -755,6 +755,13 @@ jax_test( ], "tpu": ["optonly"], }, + # Use fewer cases to prevent timeouts. + backend_variant_args = { + "cpu": ["--jax_num_generated_cases=40"], + "cpu_x32": ["--jax_num_generated_cases=40"], + "cpu_no_jax_array": ["--jax_num_generated_cases=40"], + "gpu": ["--jax_num_generated_cases=40"], + }, shard_count = { "cpu": 50, "gpu": 50,