diff --git a/tests/BUILD b/tests/BUILD index 9c8ca9310..62fb9fb8a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -542,6 +542,21 @@ jax_test( ] + py_deps("numpy"), ) +jax_test( + name = "lax_metal_test", + srcs = ["lax_metal_test.py"], + tags = ["notap"], + disable_backends = [ + "cpu", + "gpu", + "tpu", + ], + deps = [ + "//jax:internal_test_util", + "//jax:lax_reference", + ] + py_deps("numpy"), +) + jax_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"],