diff --git a/tests/BUILD b/tests/BUILD index 1f18042a0..7a65df699 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -457,7 +457,7 @@ jax_test( ], }, shard_count = { - "cpu": 20, + "cpu": 25, "gpu": 40, "tpu": 10, "iree": 20, @@ -683,6 +683,7 @@ jax_test( args = ["--jax_bcoo_cusparse_lowering=true"], shard_count = { "gpu": 20, + "tpu": 10, }, deps = [ "//jax:experimental_sparse",