Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306

PiperOrigin-RevId: 632548226
This commit is contained in:
Peter Hawkins 2024-05-10 11:05:47 -07:00 committed by jax authors
parent 586568f4fe
commit 6c425338d2
2 changed files with 14 additions and 20 deletions

View File

@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_building_mosaic_gpu",
"if_windows",
"py_library_providing_imports_info",
"pybind_extension",
"pytype_library",
@ -30,13 +31,6 @@ package(
default_visibility = ["//:__subpackages__"],
)
# This makes xla_extension module accessible from jax._src.lib.
genrule(
name = "xla_extension_py",
outs = ["xla_extension.py"],
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
)
py_library_providing_imports_info(
name = "jaxlib",
srcs = [
@ -53,8 +47,8 @@ py_library_providing_imports_info(
"lapack.py",
":version",
":xla_client",
":xla_extension_py",
],
data = [":xla_extension"],
lib_rule = pytype_library,
deps = [
":cpu_feature_guard",
@ -77,7 +71,6 @@ py_library_providing_imports_info(
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic",
"//jaxlib/triton",
"@xla//xla/python:xla_extension",
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
)
@ -95,6 +88,16 @@ symlink_files(
flatten = True,
)
symlink_files(
name = "xla_extension",
srcs = if_windows(
["@xla//xla/python:xla_extension.pyd"],
["@xla//xla/python:xla_extension.so"],
),
dst = ".",
flatten = True,
)
exports_files([
"README.md",
"setup.py",

View File

@ -15,7 +15,6 @@
import contextlib
import io
import logging
import os
import platform
import subprocess
import sys
@ -71,17 +70,9 @@ class LoggingTest(jtu.JaxTestCase):
""")
python = sys.executable
assert "python" in python
env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
if os.getenv("PYTHONPATH"):
env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
if os.getenv("LD_LIBRARY_PATH"):
env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
# Make sure C++ logging is at default level for the test process.
proc = subprocess.run(
[python, "-c", program],
capture_output=True,
env=env_variables,
)
proc = subprocess.run([python, "-c", program], capture_output=True,
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))