diff --git a/tests/BUILD b/tests/BUILD index 489af3925..4b36d048a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -731,8 +731,8 @@ jax_test( }, enable_configs = ["cpu_jit_pjit_api_merge"], shard_count = { - "cpu": 40, - "gpu": 40, + "cpu": 50, + "gpu": 50, "tpu": 50, "iree": 10, },