mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules. Fix build failure with dangling matplotlib reference. PiperOrigin-RevId: 465562141
This commit is contained in:
parent
f0b6478b3e
commit
b865111996
11
jax/BUILD
11
jax/BUILD
@ -17,15 +17,12 @@
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"absl_logging_py_deps",
|
||||
"absl_testing_py_deps",
|
||||
"jax_extra_deps",
|
||||
"jax_internal_packages",
|
||||
"jax_test_util_visibility",
|
||||
"numpy_py_deps",
|
||||
"py_deps",
|
||||
"py_library_providing_imports_info",
|
||||
"pytype_library",
|
||||
"scipy_py_deps",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
@ -73,7 +70,7 @@ py_library(
|
||||
] + jax_test_util_visibility,
|
||||
deps = [
|
||||
":jax",
|
||||
] + absl_testing_py_deps + numpy_py_deps,
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
@ -118,7 +115,7 @@ py_library_providing_imports_info(
|
||||
":enable_jaxlib_build": [":jaxlib_deps"],
|
||||
"//conditions:default": [],
|
||||
}) +
|
||||
numpy_py_deps + scipy_py_deps + jax_extra_deps,
|
||||
py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -137,7 +134,7 @@ py_library_providing_imports_info(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":jax",
|
||||
] + absl_logging_py_deps + numpy_py_deps,
|
||||
] + py_deps("absl/logging") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
|
@ -15,8 +15,7 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"jax2tf_deps",
|
||||
"numpy_py_deps",
|
||||
"tensorflow_py_deps",
|
||||
"py_deps",
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
@ -44,5 +43,5 @@ py_library(
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//jax",
|
||||
] + numpy_py_deps + tensorflow_py_deps + jax2tf_deps,
|
||||
] + py_deps("numpy") + py_deps("tensorflow") + jax2tf_deps,
|
||||
)
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"tensorflow_py_deps",
|
||||
"py_deps",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
@ -39,5 +39,5 @@ py_library(
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax/experimental/jax2tf",
|
||||
] + tensorflow_py_deps,
|
||||
] + py_deps("tensorflow"),
|
||||
)
|
||||
|
@ -36,14 +36,13 @@ jax_internal_packages = []
|
||||
jax_test_util_visibility = []
|
||||
loops_visibility = []
|
||||
|
||||
absl_logging_py_deps = []
|
||||
absl_testing_py_deps = []
|
||||
cloudpickle_py_deps = []
|
||||
numpy_py_deps = []
|
||||
pil_py_deps = []
|
||||
portpicker_py_deps = []
|
||||
scipy_py_deps = []
|
||||
tensorflow_py_deps = []
|
||||
def py_deps(_package):
|
||||
"""Returns the Bazel deps for Python package `package`."""
|
||||
|
||||
# We assume the user has installed all dependencies in their Python environment.
|
||||
# This indirection exists because in Google's internal build we build
|
||||
# dependencies from source with Bazel, but that's not something most people would want.
|
||||
return []
|
||||
|
||||
jax_extra_deps = []
|
||||
jax2tf_deps = []
|
||||
|
26
tests/BUILD
26
tests/BUILD
@ -14,16 +14,11 @@
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"absl_logging_py_deps",
|
||||
"cloudpickle_py_deps",
|
||||
"jax_generate_backend_suites",
|
||||
"jax_test",
|
||||
"jax_test_file_visibility",
|
||||
"pil_py_deps",
|
||||
"portpicker_py_deps",
|
||||
"py_deps",
|
||||
"pytype_library",
|
||||
"scipy_py_deps",
|
||||
"tensorflow_py_deps",
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
@ -53,7 +48,7 @@ jax_test(
|
||||
name = "array_interoperability_test",
|
||||
srcs = ["array_interoperability_test.py"],
|
||||
disable_backends = ["tpu"],
|
||||
deps = tensorflow_py_deps,
|
||||
deps = py_deps("tensorflow"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -111,7 +106,7 @@ py_test(
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
] + portpicker_py_deps,
|
||||
] + py_deps("portpicker"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -155,8 +150,7 @@ jax_test(
|
||||
},
|
||||
deps = [
|
||||
"//jax:experimental_sparse",
|
||||
"//third_party/py/matplotlib",
|
||||
],
|
||||
] + py_deps("matplotlib"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -241,7 +235,7 @@ jax_test(
|
||||
"tpu": 10,
|
||||
"iree": 10,
|
||||
},
|
||||
deps = pil_py_deps + tensorflow_py_deps,
|
||||
deps = py_deps("pil") + py_deps("tensorflow"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -290,7 +284,7 @@ py_test(
|
||||
"//jax:test_util",
|
||||
"//jax/experimental/jax2tf",
|
||||
"//jax/tools:jax_to_ir",
|
||||
] + tensorflow_py_deps,
|
||||
] + py_deps("tensorflow"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -519,7 +513,7 @@ jax_test(
|
||||
srcs = ["pickle_test.py"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
] + cloudpickle_py_deps,
|
||||
] + py_deps("cloudpickle"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -674,7 +668,7 @@ jax_test(
|
||||
},
|
||||
deps = [
|
||||
"//jax:experimental_sparse",
|
||||
] + scipy_py_deps,
|
||||
] + py_deps("scipy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -750,7 +744,7 @@ py_test(
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
] + absl_logging_py_deps,
|
||||
] + py_deps("absl/logging"),
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -820,7 +814,7 @@ jax_test(
|
||||
deps = [
|
||||
"//jax:experimental_host_callback",
|
||||
"//jax:ode",
|
||||
] + tensorflow_py_deps,
|
||||
] + py_deps("tensorflow"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user