From 6c425338d20c0c9be3fc69d2f07ababf79c881d3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 10 May 2024 11:05:47 -0700 Subject: [PATCH] Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306 PiperOrigin-RevId: 632548226 --- jaxlib/BUILD | 21 ++++++++++++--------- tests/logging_test.py | 13 ++----------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 2c59acdcf..5b5080895 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_building_mosaic_gpu", + "if_windows", "py_library_providing_imports_info", "pybind_extension", "pytype_library", @@ -30,13 +31,6 @@ package( default_visibility = ["//:__subpackages__"], ) -# This makes xla_extension module accessible from jax._src.lib. -genrule( - name = "xla_extension_py", - outs = ["xla_extension.py"], - cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", -) - py_library_providing_imports_info( name = "jaxlib", srcs = [ @@ -53,8 +47,8 @@ py_library_providing_imports_info( "lapack.py", ":version", ":xla_client", - ":xla_extension_py", ], + data = [":xla_extension"], lib_rule = pytype_library, deps = [ ":cpu_feature_guard", @@ -77,7 +71,6 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", - "@xla//xla/python:xla_extension", ] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]), ) @@ -95,6 +88,16 @@ symlink_files( flatten = True, ) +symlink_files( + name = "xla_extension", + srcs = if_windows( + ["@xla//xla/python:xla_extension.pyd"], + ["@xla//xla/python:xla_extension.so"], + ), + dst = ".", + flatten = True, +) + exports_files([ "README.md", "setup.py", diff --git a/tests/logging_test.py b/tests/logging_test.py index 454b525f2..6b02432ce 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,7 +15,6 @@ import contextlib import io import logging -import os import platform import subprocess import sys @@ -71,17 +70,9 @@ class LoggingTest(jtu.JaxTestCase): """) python = sys.executable assert "python" in python - env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} - if os.getenv("PYTHONPATH"): - env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") - if os.getenv("LD_LIBRARY_PATH"): - env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") # Make sure C++ logging is at default level for the test process. - proc = subprocess.run( - [python, "-c", program], - capture_output=True, - env=env_variables, - ) + proc = subprocess.run([python, "-c", program], capture_output=True, + env={"TF_CPP_MIN_LOG_LEVEL": "1"}) lines = proc.stdout.split(b"\n") lines.extend(proc.stderr.split(b"\n"))