mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel. Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used. PiperOrigin-RevId: 458522398
This commit is contained in:
parent
68b9eaf0ee
commit
7c49864fdf
@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
# and update the sha256 with the result.
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
sha256 = "a99890443df024e52d9c7b075e9916250c6cc6b778d62c384b7dcd1903d8f4f1",
|
||||
strip_prefix = "tensorflow-d250676d7776cfbca38e8690b75e1376afecf58d",
|
||||
sha256 = "634e5ee7fba57ba8e1e6af18c8521d9849d631248979f0fa4824e772ecabdb79",
|
||||
strip_prefix = "tensorflow-6e406f0bd8d08be3fa43daa306e8565cdb29a546",
|
||||
urls = [
|
||||
"https://github.com/tensorflow/tensorflow/archive/d250676d7776cfbca38e8690b75e1376afecf58d.tar.gz",
|
||||
"https://github.com/tensorflow/tensorflow/archive/6e406f0bd8d08be3fa43daa306e8565cdb29a546.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_windows")
|
||||
load("//jaxlib:jax.bzl", "if_windows")
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
|
@ -117,15 +117,6 @@ def patch_copy_xla_extension_stubs(dst_dir):
|
||||
f.write(src)
|
||||
|
||||
|
||||
def patch_copy_xla_client_py(dst_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
|
||||
src = f.read()
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
|
||||
"from . import xla_extension as _xla")
|
||||
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
|
||||
f.write(src)
|
||||
|
||||
|
||||
def patch_copy_tpu_client_py(dst_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
|
||||
src = f.read()
|
||||
@ -186,6 +177,8 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/version.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/xla_client.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}")
|
||||
|
||||
cuda_dir = os.path.join(jaxlib_dir, "cuda")
|
||||
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
|
||||
@ -242,7 +235,6 @@ def prepare_wheel(sources_path):
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirTransforms.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(f"org_tensorflow/tensorflow/compiler/xla/python/xla_extension.{pyext}")
|
||||
if _is_windows():
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir)
|
||||
elif _is_mac():
|
||||
@ -250,7 +242,6 @@ def prepare_wheel(sources_path):
|
||||
else:
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
patch_copy_xla_client_py(jaxlib_dir)
|
||||
|
||||
if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"):
|
||||
copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")
|
||||
|
20
jaxlib/BUILD
20
jaxlib/BUILD
@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"flatbuffer_cc_library",
|
||||
"if_windows",
|
||||
"pybind_extension",
|
||||
)
|
||||
|
||||
@ -37,7 +38,9 @@ py_library(
|
||||
"mhlo_helpers.py",
|
||||
"pocketfft.py",
|
||||
":version",
|
||||
":xla_client",
|
||||
],
|
||||
data = [":xla_extension"],
|
||||
deps = [
|
||||
":_lapack",
|
||||
":_pocketfft",
|
||||
@ -61,6 +64,23 @@ symlink_files(
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "xla_client",
|
||||
srcs = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "xla_extension",
|
||||
srcs = if_windows(
|
||||
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.pyd"],
|
||||
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"setup.py",
|
||||
"setup.cfg",
|
||||
|
@ -15,7 +15,7 @@
|
||||
"""Bazel macros used by the JAX build."""
|
||||
|
||||
load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", _pyx_library = "pyx_library")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", _pybind_extension = "pybind_extension")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", _if_windows = "if_windows", _pybind_extension = "pybind_extension")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
|
||||
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
|
||||
@ -29,6 +29,7 @@ pyx_library = _pyx_library
|
||||
pybind_extension = _pybind_extension
|
||||
if_cuda_is_configured = _if_cuda_is_configured
|
||||
if_rocm_is_configured = _if_rocm_is_configured
|
||||
if_windows = _if_windows
|
||||
flatbuffer_cc_library = _flatbuffer_cc_library
|
||||
|
||||
def py_extension(name, srcs, copts, deps):
|
||||
|
Loading…
x
Reference in New Issue
Block a user