diff --git a/tests/BUILD b/tests/BUILD index 88a8de6dc..d95eb8f12 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -317,9 +317,9 @@ jax_test( name = "lax_control_flow_test", srcs = ["lax_control_flow_test.py"], shard_count = { - "cpu": 10, + "cpu": 20, "gpu": 20, - "tpu": 10, + "tpu": 20, "iree": 10, }, ) @@ -529,9 +529,9 @@ jax_test( name = "pmap_test", srcs = ["pmap_test.py"], shard_count = { - "cpu": 15, + "cpu": 30, "gpu": 30, - "tpu": 15, + "tpu": 30, }, tags = ["multiaccelerator"], deps = [