diff --git a/tests/BUILD b/tests/BUILD index 91c5b5a2a..2d31b04d3 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -999,8 +999,7 @@ jax_test( ], enable_configs = [ "gpu_a100", - # TODO(b/337303303): re-enable the test - # "gpu_h100", + "gpu_h100", ], deps = [ "//jax:experimental_sparse",