mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3
PiperOrigin-RevId: 632579101
This commit is contained in:
parent
c231cd51eb
commit
c3cab2e3d3
21
jaxlib/BUILD
21
jaxlib/BUILD
@ -18,7 +18,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_mosaic_gpu",
|
||||
"if_windows",
|
||||
"py_library_providing_imports_info",
|
||||
"pybind_extension",
|
||||
"pytype_library",
|
||||
@ -31,6 +30,13 @@ package(
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
# This makes xla_extension module accessible from jax._src.lib.
|
||||
genrule(
|
||||
name = "xla_extension_py",
|
||||
outs = ["xla_extension.py"],
|
||||
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "jaxlib",
|
||||
srcs = [
|
||||
@ -47,8 +53,8 @@ py_library_providing_imports_info(
|
||||
"lapack.py",
|
||||
":version",
|
||||
":xla_client",
|
||||
":xla_extension_py",
|
||||
],
|
||||
data = [":xla_extension"],
|
||||
lib_rule = pytype_library,
|
||||
deps = [
|
||||
":cpu_feature_guard",
|
||||
@ -71,6 +77,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
"//jaxlib/mosaic",
|
||||
"//jaxlib/triton",
|
||||
"@xla//xla/python:xla_extension",
|
||||
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
|
||||
)
|
||||
|
||||
@ -88,16 +95,6 @@ symlink_files(
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "xla_extension",
|
||||
srcs = if_windows(
|
||||
["@xla//xla/python:xla_extension.pyd"],
|
||||
["@xla//xla/python:xla_extension.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"README.md",
|
||||
"setup.py",
|
||||
|
@ -205,7 +205,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
|
||||
"__main__/jaxlib/gpu_sparse.py",
|
||||
"__main__/jaxlib/version.py",
|
||||
"__main__/jaxlib/xla_client.py",
|
||||
f"__main__/jaxlib/xla_extension.{pyext}",
|
||||
f"xla/xla/python/xla_extension.{pyext}",
|
||||
],
|
||||
)
|
||||
# This file is required by PEP-561. It marks jaxlib as package containing
|
||||
|
@ -15,6 +15,7 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
@ -70,9 +71,17 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
""")
|
||||
python = sys.executable
|
||||
assert "python" in python
|
||||
env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
|
||||
if os.getenv("PYTHONPATH"):
|
||||
env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
|
||||
if os.getenv("LD_LIBRARY_PATH"):
|
||||
env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
|
||||
# Make sure C++ logging is at default level for the test process.
|
||||
proc = subprocess.run([python, "-c", program], capture_output=True,
|
||||
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
|
||||
proc = subprocess.run(
|
||||
[python, "-c", program],
|
||||
capture_output=True,
|
||||
env=env_variables,
|
||||
)
|
||||
|
||||
lines = proc.stdout.split(b"\n")
|
||||
lines.extend(proc.stderr.split(b"\n"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user