diff --git a/tests/BUILD b/tests/BUILD index 6eae1fc28..04aeff375 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1171,9 +1171,9 @@ jax_test( name = "shard_map_test", srcs = ["shard_map_test.py"], shard_count = { - "cpu": 30, + "cpu": 50, "gpu": 10, - "tpu": 40, + "tpu": 50, }, tags = ["multiaccelerator"], deps = [