Replace xla_extension symlink with genrule that makes xla_extension module accessible from jax._src.lib.

The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When `xla_extension` is simlinked, the content of the runfiles is lost. With `genrule` the content of the runfiles is preserved.

PiperOrigin-RevId: 632508121
This commit is contained in:
jax authors 2024-05-10 08:47:22 -07:00
parent c2d78abfa3
commit 0267ed0ba9
2 changed files with 20 additions and 14 deletions

View File

@ -18,7 +18,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_building_mosaic_gpu",
"if_windows",
"py_library_providing_imports_info",
"pybind_extension",
"pytype_library",
@ -31,6 +30,13 @@ package(
default_visibility = ["//:__subpackages__"],
)
# This makes xla_extension module accessible from jax._src.lib.
genrule(
name = "xla_extension_py",
outs = ["xla_extension.py"],
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
)
py_library_providing_imports_info(
name = "jaxlib",
srcs = [
@ -47,8 +53,8 @@ py_library_providing_imports_info(
"lapack.py",
":version",
":xla_client",
":xla_extension_py",
],
data = [":xla_extension"],
lib_rule = pytype_library,
deps = [
":cpu_feature_guard",
@ -71,6 +77,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic",
"//jaxlib/triton",
"@xla//xla/python:xla_extension",
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
)
@ -88,16 +95,6 @@ symlink_files(
flatten = True,
)
symlink_files(
name = "xla_extension",
srcs = if_windows(
["@xla//xla/python:xla_extension.pyd"],
["@xla//xla/python:xla_extension.so"],
),
dst = ".",
flatten = True,
)
exports_files([
"README.md",
"setup.py",

View File

@ -15,6 +15,7 @@
import contextlib
import io
import logging
import os
import platform
import subprocess
import sys
@ -70,9 +71,17 @@ class LoggingTest(jtu.JaxTestCase):
""")
python = sys.executable
assert "python" in python
env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
if os.getenv("PYTHONPATH"):
env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
if os.getenv("LD_LIBRARY_PATH"):
env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
# Make sure C++ logging is at default level for the test process.
proc = subprocess.run([python, "-c", program], capture_output=True,
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
proc = subprocess.run(
[python, "-c", program],
capture_output=True,
env=env_variables,
)
lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))