diff --git a/tests/BUILD b/tests/BUILD index b0609eb43..6c990df86 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -375,7 +375,7 @@ jax_test( }, shard_count = { "cpu": 10, - "gpu": 20, + "gpu": 10, "tpu": 10, "iree": 10, }, @@ -655,7 +655,7 @@ jax_test( }, shard_count = { "cpu": 10, - "gpu": 10, + "gpu": 20, "tpu": 10, "iree": 10, },