diff --git a/tests/BUILD b/tests/BUILD index 49394523b..64b41fc58 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -124,6 +124,7 @@ jax_test( }, shard_count = { "tpu": 20, + "cpu": 20, }, )