diff --git a/tests/BUILD b/tests/BUILD index 4f71899f7..5c378715a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -675,6 +675,7 @@ jax_test( srcs = ["nn_test.py"], shard_count = { "tpu": 10, + "gpu": 10, }, )