diff --git a/tests/BUILD b/tests/BUILD index 62cb5d41c..f489c0855 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1437,7 +1437,10 @@ jax_test( "tpu", "cpu", ], - shard_count = 4, + shard_count = { + "gpu": 4, + }, + tags = ["multiaccelerator"], deps = [ "//jax:fused_attention_stablehlo", ],