mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Replace xla_extension symlink with genrule that makes xla_extension module accessible from jax._src.lib.
The runfiles of the original targets were lost when the symlinked files were used. This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When `xla_extension` is simlinked, the content of the runfiles is lost. With `genrule` the content of the runfiles is preserved. PiperOrigin-RevId: 632508121
This commit is contained in:
parent
c2d78abfa3
commit
0267ed0ba9
21
jaxlib/BUILD
21
jaxlib/BUILD
@ -18,7 +18,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_mosaic_gpu",
|
||||
"if_windows",
|
||||
"py_library_providing_imports_info",
|
||||
"pybind_extension",
|
||||
"pytype_library",
|
||||
@ -31,6 +30,13 @@ package(
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
# This makes xla_extension module accessible from jax._src.lib.
|
||||
genrule(
|
||||
name = "xla_extension_py",
|
||||
outs = ["xla_extension.py"],
|
||||
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "jaxlib",
|
||||
srcs = [
|
||||
@ -47,8 +53,8 @@ py_library_providing_imports_info(
|
||||
"lapack.py",
|
||||
":version",
|
||||
":xla_client",
|
||||
":xla_extension_py",
|
||||
],
|
||||
data = [":xla_extension"],
|
||||
lib_rule = pytype_library,
|
||||
deps = [
|
||||
":cpu_feature_guard",
|
||||
@ -71,6 +77,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
"//jaxlib/mosaic",
|
||||
"//jaxlib/triton",
|
||||
"@xla//xla/python:xla_extension",
|
||||
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
|
||||
)
|
||||
|
||||
@ -88,16 +95,6 @@ symlink_files(
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "xla_extension",
|
||||
srcs = if_windows(
|
||||
["@xla//xla/python:xla_extension.pyd"],
|
||||
["@xla//xla/python:xla_extension.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"README.md",
|
||||
"setup.py",
|
||||
|
@ -15,6 +15,7 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
@ -70,9 +71,17 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
""")
|
||||
python = sys.executable
|
||||
assert "python" in python
|
||||
env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
|
||||
if os.getenv("PYTHONPATH"):
|
||||
env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
|
||||
if os.getenv("LD_LIBRARY_PATH"):
|
||||
env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
|
||||
# Make sure C++ logging is at default level for the test process.
|
||||
proc = subprocess.run([python, "-c", program], capture_output=True,
|
||||
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
|
||||
proc = subprocess.run(
|
||||
[python, "-c", program],
|
||||
capture_output=True,
|
||||
env=env_variables,
|
||||
)
|
||||
|
||||
lines = proc.stdout.split(b"\n")
|
||||
lines.extend(proc.stderr.split(b"\n"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user