mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Make //jax:tpu_custom_call respect --//jax:build_jaxlib=false
Otherwise jaxlib is partially built and doesn't work properly.
This commit is contained in:
parent
f66d3cf016
commit
d0b65f2ab8
14
jax/BUILD
14
jax/BUILD
@ -572,11 +572,15 @@ pytype_strict_library(
|
||||
":core",
|
||||
":jax",
|
||||
"//jax/_src/lib",
|
||||
"//jaxlib/mlir:ir",
|
||||
"//jaxlib/mlir:mhlo_dialect",
|
||||
"//jaxlib/mlir:pass_manager",
|
||||
"//jaxlib/mlir:stablehlo_dialect",
|
||||
] + py_deps("numpy") + py_deps("absl/flags"),
|
||||
] + select({
|
||||
":enable_jaxlib_build": [
|
||||
"//jaxlib/mlir:ir",
|
||||
"//jaxlib/mlir:mhlo_dialect",
|
||||
"//jaxlib/mlir:pass_manager",
|
||||
"//jaxlib/mlir:stablehlo_dialect",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + py_deps("numpy") + py_deps("absl/flags"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
|
Loading…
x
Reference in New Issue
Block a user