diff --git a/tests/BUILD b/tests/BUILD index e9b2988a9..1c052c6f2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -318,7 +318,7 @@ jax_test( srcs = ["lax_control_flow_test.py"], shard_count = { "cpu": 30, - "gpu": 30, + "gpu": 40, "tpu": 30, "iree": 10, },