Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.

Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
This commit is contained in:
Peter Hawkins 2022-07-01 12:31:16 -07:00 committed by jax authors
parent 68b9eaf0ee
commit 7c49864fdf
5 changed files with 28 additions and 16 deletions

View File

@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "a99890443df024e52d9c7b075e9916250c6cc6b778d62c384b7dcd1903d8f4f1",
strip_prefix = "tensorflow-d250676d7776cfbca38e8690b75e1376afecf58d",
sha256 = "634e5ee7fba57ba8e1e6af18c8521d9849d631248979f0fa4824e772ecabdb79",
strip_prefix = "tensorflow-6e406f0bd8d08be3fa43daa306e8565cdb29a546",
urls = [
"https://github.com/tensorflow/tensorflow/archive/d250676d7776cfbca38e8690b75e1376afecf58d.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/6e406f0bd8d08be3fa43daa306e8565cdb29a546.tar.gz",
],
)

View File

@ -17,7 +17,7 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_windows")
load("//jaxlib:jax.bzl", "if_windows")
licenses(["notice"]) # Apache 2

View File

@ -117,15 +117,6 @@ def patch_copy_xla_extension_stubs(dst_dir):
f.write(src)
def patch_copy_xla_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
f.write(src)
def patch_copy_tpu_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
src = f.read()
@ -186,6 +177,8 @@ def prepare_wheel(sources_path):
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py")
copy_to_jaxlib("__main__/jaxlib/version.py")
copy_to_jaxlib("__main__/jaxlib/xla_client.py")
copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}")
cuda_dir = os.path.join(jaxlib_dir, "cuda")
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
@ -242,7 +235,6 @@ def prepare_wheel(sources_path):
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirTransforms.{pyext}", dst_dir=mlir_libs_dir)
copy_to_jaxlib(f"org_tensorflow/tensorflow/compiler/xla/python/xla_extension.{pyext}")
if _is_windows():
copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir)
elif _is_mac():
@ -250,7 +242,6 @@ def prepare_wheel(sources_path):
else:
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
patch_copy_xla_extension_stubs(jaxlib_dir)
patch_copy_xla_client_py(jaxlib_dir)
if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"):
copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")

View File

@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"flatbuffer_cc_library",
"if_windows",
"pybind_extension",
)
@ -37,7 +38,9 @@ py_library(
"mhlo_helpers.py",
"pocketfft.py",
":version",
":xla_client",
],
data = [":xla_extension"],
deps = [
":_lapack",
":_pocketfft",
@ -61,6 +64,23 @@ symlink_files(
flatten = True,
)
symlink_files(
name = "xla_client",
srcs = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
dst = ".",
flatten = True,
)
symlink_files(
name = "xla_extension",
srcs = if_windows(
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.pyd"],
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.so"],
),
dst = ".",
flatten = True,
)
exports_files([
"setup.py",
"setup.cfg",

View File

@ -15,7 +15,7 @@
"""Bazel macros used by the JAX build."""
load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", _pyx_library = "pyx_library")
load("@org_tensorflow//tensorflow:tensorflow.bzl", _pybind_extension = "pybind_extension")
load("@org_tensorflow//tensorflow:tensorflow.bzl", _if_windows = "if_windows", _pybind_extension = "pybind_extension")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
@ -29,6 +29,7 @@ pyx_library = _pyx_library
pybind_extension = _pybind_extension
if_cuda_is_configured = _if_cuda_is_configured
if_rocm_is_configured = _if_rocm_is_configured
if_windows = _if_windows
flatbuffer_cc_library = _flatbuffer_cc_library
def py_extension(name, srcs, copts, deps):