diff --git a/tests/BUILD b/tests/BUILD index dd80984b9..e422722b3 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -700,6 +700,7 @@ jax_test( srcs = ["sparsify_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], shard_count = { + "cpu": 5, "gpu": 20, "tpu": 10, },