diff --git a/tests/BUILD b/tests/BUILD index 3fcdc0872..2716631c4 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -526,9 +526,9 @@ jax_test( name = "pmap_test", srcs = ["pmap_test.py"], shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, + "cpu": 15, + "gpu": 30, + "tpu": 15, }, tags = ["multiaccelerator"], deps = [