From 0267ed0ba9584bbc137792361b53aa80e9c4d306 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 10 May 2024 08:47:22 -0700 Subject: [PATCH] Replace xla_extension symlink with genrule that makes xla_extension module accessible from jax._src.lib. The runfiles of the original targets were lost when the symlinked files were used. This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When `xla_extension` is simlinked, the content of the runfiles is lost. With `genrule` the content of the runfiles is preserved. PiperOrigin-RevId: 632508121 --- jaxlib/BUILD | 21 +++++++++------------ tests/logging_test.py | 13 +++++++++++-- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5b5080895..2c59acdcf 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -18,7 +18,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_building_mosaic_gpu", - "if_windows", "py_library_providing_imports_info", "pybind_extension", "pytype_library", @@ -31,6 +30,13 @@ package( default_visibility = ["//:__subpackages__"], ) +# This makes xla_extension module accessible from jax._src.lib. +genrule( + name = "xla_extension_py", + outs = ["xla_extension.py"], + cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", +) + py_library_providing_imports_info( name = "jaxlib", srcs = [ @@ -47,8 +53,8 @@ py_library_providing_imports_info( "lapack.py", ":version", ":xla_client", + ":xla_extension_py", ], - data = [":xla_extension"], lib_rule = pytype_library, deps = [ ":cpu_feature_guard", @@ -71,6 +77,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", + "@xla//xla/python:xla_extension", ] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]), ) @@ -88,16 +95,6 @@ symlink_files( flatten = True, ) -symlink_files( - name = "xla_extension", - srcs = if_windows( - ["@xla//xla/python:xla_extension.pyd"], - ["@xla//xla/python:xla_extension.so"], - ), - dst = ".", - flatten = True, -) - exports_files([ "README.md", "setup.py", diff --git a/tests/logging_test.py b/tests/logging_test.py index 6b02432ce..454b525f2 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,6 +15,7 @@ import contextlib import io import logging +import os import platform import subprocess import sys @@ -70,9 +71,17 @@ class LoggingTest(jtu.JaxTestCase): """) python = sys.executable assert "python" in python + env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} + if os.getenv("PYTHONPATH"): + env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") + if os.getenv("LD_LIBRARY_PATH"): + env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") # Make sure C++ logging is at default level for the test process. - proc = subprocess.run([python, "-c", program], capture_output=True, - env={"TF_CPP_MIN_LOG_LEVEL": "1"}) + proc = subprocess.run( + [python, "-c", program], + capture_output=True, + env=env_variables, + ) lines = proc.stdout.split(b"\n") lines.extend(proc.stderr.split(b"\n"))