From c3cab2e3d3044eb77b94b9119f16aadc0abd014d Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 10 May 2024 12:55:17 -0700 Subject: [PATCH] Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3 PiperOrigin-RevId: 632579101 --- jaxlib/BUILD | 21 +++++++++------------ jaxlib/tools/build_wheel.py | 2 +- tests/logging_test.py | 13 +++++++++++-- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5b5080895..2c59acdcf 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -18,7 +18,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_building_mosaic_gpu", - "if_windows", "py_library_providing_imports_info", "pybind_extension", "pytype_library", @@ -31,6 +30,13 @@ package( default_visibility = ["//:__subpackages__"], ) +# This makes xla_extension module accessible from jax._src.lib. +genrule( + name = "xla_extension_py", + outs = ["xla_extension.py"], + cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", +) + py_library_providing_imports_info( name = "jaxlib", srcs = [ @@ -47,8 +53,8 @@ py_library_providing_imports_info( "lapack.py", ":version", ":xla_client", + ":xla_extension_py", ], - data = [":xla_extension"], lib_rule = pytype_library, deps = [ ":cpu_feature_guard", @@ -71,6 +77,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", + "@xla//xla/python:xla_extension", ] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]), ) @@ -88,16 +95,6 @@ symlink_files( flatten = True, ) -symlink_files( - name = "xla_extension", - srcs = if_windows( - ["@xla//xla/python:xla_extension.pyd"], - ["@xla//xla/python:xla_extension.so"], - ), - dst = ".", - flatten = True, -) - exports_files([ "README.md", "setup.py", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 651b7889e..703bc26e3 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -205,7 +205,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): "__main__/jaxlib/gpu_sparse.py", "__main__/jaxlib/version.py", "__main__/jaxlib/xla_client.py", - f"__main__/jaxlib/xla_extension.{pyext}", + f"xla/xla/python/xla_extension.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing diff --git a/tests/logging_test.py b/tests/logging_test.py index 6b02432ce..454b525f2 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,6 +15,7 @@ import contextlib import io import logging +import os import platform import subprocess import sys @@ -70,9 +71,17 @@ class LoggingTest(jtu.JaxTestCase): """) python = sys.executable assert "python" in python + env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} + if os.getenv("PYTHONPATH"): + env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") + if os.getenv("LD_LIBRARY_PATH"): + env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") # Make sure C++ logging is at default level for the test process. - proc = subprocess.run([python, "-c", program], capture_output=True, - env={"TF_CPP_MIN_LOG_LEVEL": "1"}) + proc = subprocess.run( + [python, "-c", program], + capture_output=True, + env=env_variables, + ) lines = proc.stdout.split(b"\n") lines.extend(proc.stderr.split(b"\n"))