Make //jax:tpu_custom_call respect --//jax:build_jaxlib=false

Otherwise jaxlib is partially built and doesn't work properly.
This commit is contained in:
Skye Wanderman-Milne 2023-07-26 15:50:42 +00:00
parent f66d3cf016
commit d0b65f2ab8

View File

@ -572,11 +572,15 @@ pytype_strict_library(
":core",
":jax",
"//jax/_src/lib",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:stablehlo_dialect",
] + py_deps("numpy") + py_deps("absl/flags"),
] + select({
":enable_jaxlib_build": [
"//jaxlib/mlir:ir",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:stablehlo_dialect",
],
"//conditions:default": [],
}) + py_deps("numpy") + py_deps("absl/flags"),
)
pytype_strict_library(