diff --git a/tests/BUILD b/tests/BUILD index 2fbed0601..d1fb4dcc7 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -235,8 +235,8 @@ jax_test( }, enable_configs = [ "gpu_2gpu_shardy", - "tpu_df_2x2_shardy", - "tpu_pf_2x2_shardy", + "tpu_v3_2x2_shardy", + "tpu_v4_2x2_shardy", ], shard_count = { "cpu": 5, @@ -1426,7 +1426,7 @@ jax_test( name = "export_test", srcs = ["export_test.py"], enable_configs = [ - "tpu_df_2x2", + "tpu_v3_2x2", ], tags = [], deps = [