add multiaccelerator tag to test

This commit is contained in:
Cjkkkk 2024-03-15 13:22:45 -07:00
parent 4ac1503bd5
commit 85cbe05f25

View File

@ -1437,7 +1437,10 @@ jax_test(
"tpu",
"cpu",
],
shard_count = 4,
shard_count = {
"gpu": 4,
},
tags = ["multiaccelerator"],
deps = [
"//jax:fused_attention_stablehlo",
],