Refactor BUILD files to avoid individually naming Python dependencies.

Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
This commit is contained in:
Peter Hawkins 2022-08-05 07:48:40 -07:00 committed by jax authors
parent f0b6478b3e
commit b865111996
5 changed files with 25 additions and 36 deletions

View File

@ -17,15 +17,12 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load(
"//jaxlib:jax.bzl",
"absl_logging_py_deps",
"absl_testing_py_deps",
"jax_extra_deps",
"jax_internal_packages",
"jax_test_util_visibility",
"numpy_py_deps",
"py_deps",
"py_library_providing_imports_info",
"pytype_library",
"scipy_py_deps",
)
licenses(["notice"])
@ -73,7 +70,7 @@ py_library(
] + jax_test_util_visibility,
deps = [
":jax",
] + absl_testing_py_deps + numpy_py_deps,
] + py_deps("absl/testing") + py_deps("numpy"),
)
py_library_providing_imports_info(
@ -118,7 +115,7 @@ py_library_providing_imports_info(
":enable_jaxlib_build": [":jaxlib_deps"],
"//conditions:default": [],
}) +
numpy_py_deps + scipy_py_deps + jax_extra_deps,
py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
)
py_library(
@ -137,7 +134,7 @@ py_library_providing_imports_info(
visibility = ["//visibility:public"],
deps = [
":jax",
] + absl_logging_py_deps + numpy_py_deps,
] + py_deps("absl/logging") + py_deps("numpy"),
)
pytype_library(

View File

@ -15,8 +15,7 @@
load(
"//jaxlib:jax.bzl",
"jax2tf_deps",
"numpy_py_deps",
"tensorflow_py_deps",
"py_deps",
)
licenses(["notice"]) # Apache 2
@ -44,5 +43,5 @@ py_library(
srcs_version = "PY3",
deps = [
"//jax",
] + numpy_py_deps + tensorflow_py_deps + jax2tf_deps,
] + py_deps("numpy") + py_deps("tensorflow") + jax2tf_deps,
)

View File

@ -14,7 +14,7 @@
load(
"//jaxlib:jax.bzl",
"tensorflow_py_deps",
"py_deps",
)
licenses(["notice"])
@ -39,5 +39,5 @@ py_library(
deps = [
"//jax",
"//jax/experimental/jax2tf",
] + tensorflow_py_deps,
] + py_deps("tensorflow"),
)

View File

@ -36,14 +36,13 @@ jax_internal_packages = []
jax_test_util_visibility = []
loops_visibility = []
absl_logging_py_deps = []
absl_testing_py_deps = []
cloudpickle_py_deps = []
numpy_py_deps = []
pil_py_deps = []
portpicker_py_deps = []
scipy_py_deps = []
tensorflow_py_deps = []
def py_deps(_package):
"""Returns the Bazel deps for Python package `package`."""
# We assume the user has installed all dependencies in their Python environment.
# This indirection exists because in Google's internal build we build
# dependencies from source with Bazel, but that's not something most people would want.
return []
jax_extra_deps = []
jax2tf_deps = []

View File

@ -14,16 +14,11 @@
load(
"//jaxlib:jax.bzl",
"absl_logging_py_deps",
"cloudpickle_py_deps",
"jax_generate_backend_suites",
"jax_test",
"jax_test_file_visibility",
"pil_py_deps",
"portpicker_py_deps",
"py_deps",
"pytype_library",
"scipy_py_deps",
"tensorflow_py_deps",
)
licenses(["notice"]) # Apache 2
@ -53,7 +48,7 @@ jax_test(
name = "array_interoperability_test",
srcs = ["array_interoperability_test.py"],
disable_backends = ["tpu"],
deps = tensorflow_py_deps,
deps = py_deps("tensorflow"),
)
jax_test(
@ -111,7 +106,7 @@ py_test(
deps = [
"//jax",
"//jax:test_util",
] + portpicker_py_deps,
] + py_deps("portpicker"),
)
jax_test(
@ -155,8 +150,7 @@ jax_test(
},
deps = [
"//jax:experimental_sparse",
"//third_party/py/matplotlib",
],
] + py_deps("matplotlib"),
)
jax_test(
@ -241,7 +235,7 @@ jax_test(
"tpu": 10,
"iree": 10,
},
deps = pil_py_deps + tensorflow_py_deps,
deps = py_deps("pil") + py_deps("tensorflow"),
)
jax_test(
@ -290,7 +284,7 @@ py_test(
"//jax:test_util",
"//jax/experimental/jax2tf",
"//jax/tools:jax_to_ir",
] + tensorflow_py_deps,
] + py_deps("tensorflow"),
)
jax_test(
@ -519,7 +513,7 @@ jax_test(
srcs = ["pickle_test.py"],
deps = [
"//jax:experimental",
] + cloudpickle_py_deps,
] + py_deps("cloudpickle"),
)
jax_test(
@ -674,7 +668,7 @@ jax_test(
},
deps = [
"//jax:experimental_sparse",
] + scipy_py_deps,
] + py_deps("scipy"),
)
jax_test(
@ -750,7 +744,7 @@ py_test(
deps = [
"//jax",
"//jax:test_util",
] + absl_logging_py_deps,
] + py_deps("absl/logging"),
)
py_test(
@ -820,7 +814,7 @@ jax_test(
deps = [
"//jax:experimental_host_callback",
"//jax:ode",
] + tensorflow_py_deps,
] + py_deps("tensorflow"),
)
jax_test(