mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306
PiperOrigin-RevId: 632548226
This commit is contained in:
parent
586568f4fe
commit
6c425338d2
21
jaxlib/BUILD
21
jaxlib/BUILD
@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_mosaic_gpu",
|
||||
"if_windows",
|
||||
"py_library_providing_imports_info",
|
||||
"pybind_extension",
|
||||
"pytype_library",
|
||||
@ -30,13 +31,6 @@ package(
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
# This makes xla_extension module accessible from jax._src.lib.
|
||||
genrule(
|
||||
name = "xla_extension_py",
|
||||
outs = ["xla_extension.py"],
|
||||
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "jaxlib",
|
||||
srcs = [
|
||||
@ -53,8 +47,8 @@ py_library_providing_imports_info(
|
||||
"lapack.py",
|
||||
":version",
|
||||
":xla_client",
|
||||
":xla_extension_py",
|
||||
],
|
||||
data = [":xla_extension"],
|
||||
lib_rule = pytype_library,
|
||||
deps = [
|
||||
":cpu_feature_guard",
|
||||
@ -77,7 +71,6 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
"//jaxlib/mosaic",
|
||||
"//jaxlib/triton",
|
||||
"@xla//xla/python:xla_extension",
|
||||
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
|
||||
)
|
||||
|
||||
@ -95,6 +88,16 @@ symlink_files(
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "xla_extension",
|
||||
srcs = if_windows(
|
||||
["@xla//xla/python:xla_extension.pyd"],
|
||||
["@xla//xla/python:xla_extension.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"README.md",
|
||||
"setup.py",
|
||||
|
@ -15,7 +15,6 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
@ -71,17 +70,9 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
""")
|
||||
python = sys.executable
|
||||
assert "python" in python
|
||||
env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
|
||||
if os.getenv("PYTHONPATH"):
|
||||
env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
|
||||
if os.getenv("LD_LIBRARY_PATH"):
|
||||
env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
|
||||
# Make sure C++ logging is at default level for the test process.
|
||||
proc = subprocess.run(
|
||||
[python, "-c", program],
|
||||
capture_output=True,
|
||||
env=env_variables,
|
||||
)
|
||||
proc = subprocess.run([python, "-c", program], capture_output=True,
|
||||
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
|
||||
|
||||
lines = proc.stdout.split(b"\n")
|
||||
lines.extend(proc.stderr.split(b"\n"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user