Include mpmath as a bazel dependency of lax_test.

This test has additional test cases that require mpmath.

PiperOrigin-RevId: 693464078
This commit is contained in:
Peter Hawkins 2024-11-05 13:42:22 -08:00 committed by jax authors
parent 563ecdf2a2
commit ea1e879577
2 changed files with 2 additions and 1 deletions

View File

@ -67,6 +67,7 @@ _py_deps = {
"flatbuffers": ["@pypi_flatbuffers//:pkg"],
"hypothesis": ["@pypi_hypothesis//:pkg"],
"matplotlib": ["@pypi_matplotlib//:pkg"],
"mpmath": [],
"opt_einsum": ["@pypi_opt_einsum//:pkg"],
"pil": ["@pypi_pillow//:pkg"],
"portpicker": ["@pypi_portpicker//:pkg"],

View File

@ -565,7 +565,7 @@ jax_multiplatform_test(
deps = [
"//jax:internal_test_util",
"//jax:lax_reference",
] + py_deps("numpy"),
] + py_deps("numpy") + py_deps("mpmath"),
)
jax_multiplatform_test(