Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3

PiperOrigin-RevId: 632579101
This commit is contained in:
jax authors 2024-05-10 12:55:17 -07:00
parent c231cd51eb
commit c3cab2e3d3
3 changed files with 21 additions and 15 deletions

View File

@ -18,7 +18,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_building_mosaic_gpu",
"if_windows",
"py_library_providing_imports_info",
"pybind_extension",
"pytype_library",
@ -31,6 +30,13 @@ package(
default_visibility = ["//:__subpackages__"],
)
# This makes xla_extension module accessible from jax._src.lib.
genrule(
name = "xla_extension_py",
outs = ["xla_extension.py"],
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
)
py_library_providing_imports_info(
name = "jaxlib",
srcs = [
@ -47,8 +53,8 @@ py_library_providing_imports_info(
"lapack.py",
":version",
":xla_client",
":xla_extension_py",
],
data = [":xla_extension"],
lib_rule = pytype_library,
deps = [
":cpu_feature_guard",
@ -71,6 +77,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic",
"//jaxlib/triton",
"@xla//xla/python:xla_extension",
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
)
@ -88,16 +95,6 @@ symlink_files(
flatten = True,
)
symlink_files(
name = "xla_extension",
srcs = if_windows(
["@xla//xla/python:xla_extension.pyd"],
["@xla//xla/python:xla_extension.so"],
),
dst = ".",
flatten = True,
)
exports_files([
"README.md",
"setup.py",

View File

@ -205,7 +205,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
"__main__/jaxlib/gpu_sparse.py",
"__main__/jaxlib/version.py",
"__main__/jaxlib/xla_client.py",
f"__main__/jaxlib/xla_extension.{pyext}",
f"xla/xla/python/xla_extension.{pyext}",
],
)
# This file is required by PEP-561. It marks jaxlib as package containing

View File

@ -15,6 +15,7 @@
import contextlib
import io
import logging
import os
import platform
import subprocess
import sys
@ -70,9 +71,17 @@ class LoggingTest(jtu.JaxTestCase):
""")
python = sys.executable
assert "python" in python
env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
if os.getenv("PYTHONPATH"):
env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
if os.getenv("LD_LIBRARY_PATH"):
env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
# Make sure C++ logging is at default level for the test process.
proc = subprocess.run([python, "-c", program], capture_output=True,
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
proc = subprocess.run(
[python, "-c", program],
capture_output=True,
env=env_variables,
)
lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))