mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
This commit is contained in:
commit
42ef649e65
18
.bazelrc
18
.bazelrc
@ -1,6 +1,10 @@
|
||||
############################################################################
|
||||
# All default build options below.
|
||||
|
||||
# Required by OpenXLA
|
||||
# https://github.com/openxla/xla/issues/1323
|
||||
build --nocheck_visibility
|
||||
|
||||
# Sets the default Apple platform to macOS.
|
||||
build --apple_platform_type=macos
|
||||
build --macos_minimum_os=10.14
|
||||
@ -35,9 +39,9 @@ build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
|
||||
|
||||
# Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled,
|
||||
# these values are overridden.
|
||||
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false
|
||||
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false
|
||||
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false
|
||||
build --@xla//xla/python:enable_gpu=false
|
||||
build --@xla//xla/python:enable_tpu=false
|
||||
build --@xla//xla/python:enable_plugin_device=false
|
||||
|
||||
###########################################################################
|
||||
|
||||
@ -65,12 +69,12 @@ build:cuda --repo_env TF_NEED_CUDA=1
|
||||
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_52,sm_60,sm_70,compute_80"
|
||||
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
|
||||
build:cuda --@local_config_cuda//:enable_cuda
|
||||
build:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true
|
||||
build:cuda --@xla//xla/python:enable_gpu=true
|
||||
build:cuda --define=xla_python_enable_gpu=true
|
||||
|
||||
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
||||
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
|
||||
build:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true
|
||||
build:rocm --@xla//xla/python:enable_gpu=true
|
||||
build:rocm --define=xla_python_enable_gpu=true
|
||||
build:rocm --repo_env TF_NEED_ROCM=1
|
||||
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
|
||||
@ -113,10 +117,10 @@ build:macos --config=posix
|
||||
# Suppress all warning messages.
|
||||
build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
||||
|
||||
build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true
|
||||
build:tpu --@xla//xla/python:enable_tpu=true
|
||||
build:tpu --define=with_tpu_support=true
|
||||
|
||||
build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true
|
||||
build:plugin_device --@xla//xla/python:enable_plugin_device=true
|
||||
|
||||
#########################################################################
|
||||
# RBE config options below.
|
||||
|
42
WORKSPACE
42
WORKSPACE
@ -1,16 +1,16 @@
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
# To update TensorFlow to a new revision,
|
||||
# To update XLA to a new revision,
|
||||
# a) update URL and strip_prefix to the new git commit hash
|
||||
# b) get the sha256 hash of the commit by running:
|
||||
# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
|
||||
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
|
||||
# and update the sha256 with the result.
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
sha256 = "08fd0ab0b672510229ad2fff276a3634f205fc539fa16a5bdeeaaccd881ece27",
|
||||
strip_prefix = "tensorflow-2aaeef25361311b21b9e81e992edff94bcb6bae3",
|
||||
name = "xla",
|
||||
sha256 = "9f39af4d81d2c8bd52b47f4ef37dfd6642c6950112e4d9d3d95ae19982c46eba",
|
||||
strip_prefix = "xla-0f31407ee498e6dba242d03f8d382ebcfcc61790",
|
||||
urls = [
|
||||
"https://github.com/tensorflow/tensorflow/archive/2aaeef25361311b21b9e81e992edff94bcb6bae3.tar.gz",
|
||||
"https://github.com/openxla/xla/archive/0f31407ee498e6dba242d03f8d382ebcfcc61790.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
@ -19,26 +19,32 @@ http_archive(
|
||||
# local checkout by either:
|
||||
# a) overriding the TF repository on the build.py command line by passing a flag
|
||||
# like:
|
||||
# python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
|
||||
# python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
|
||||
# or
|
||||
# b) by commenting out the http_archive above and uncommenting the following:
|
||||
# local_repository(
|
||||
# name = "org_tensorflow",
|
||||
# path = "/path/to/tensorflow",
|
||||
# name = "xla",
|
||||
# path = "/path/to/xla",
|
||||
# )
|
||||
|
||||
load("//third_party/ducc:workspace.bzl", ducc = "repo")
|
||||
ducc()
|
||||
|
||||
# Initialize TensorFlow's external dependencies.
|
||||
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
|
||||
tf_workspace3()
|
||||
load("@xla//:workspace4.bzl", "xla_workspace4")
|
||||
xla_workspace4()
|
||||
|
||||
load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
|
||||
tf_workspace2()
|
||||
load("@xla//:workspace3.bzl", "xla_workspace3")
|
||||
xla_workspace3()
|
||||
|
||||
load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
|
||||
tf_workspace1()
|
||||
load("@xla//:workspace2.bzl", "xla_workspace2")
|
||||
xla_workspace2()
|
||||
|
||||
load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
|
||||
tf_workspace0()
|
||||
load("@xla//:workspace1.bzl", "xla_workspace1")
|
||||
xla_workspace1()
|
||||
|
||||
load("@xla//:workspace0.bzl", "xla_workspace0")
|
||||
xla_workspace0()
|
||||
|
||||
|
||||
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
|
||||
flatbuffers()
|
@ -44,11 +44,11 @@ py_binary(
|
||||
"//jaxlib:README.md",
|
||||
"//jaxlib:setup.py",
|
||||
"//jaxlib:setup.cfg",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
"@xla//xla/python:xla_client",
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]) + select({
|
||||
":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"],
|
||||
":remote_tpu_enabled": ["@xla//xla/python/tpu_driver/client:py_tpu_client"],
|
||||
"//conditions:default": [],
|
||||
}) + if_cuda([
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
|
@ -103,14 +103,14 @@ def patch_copy_xla_extension_stubs(dst_dir):
|
||||
os.makedirs(xla_extension_dir)
|
||||
for stub_name in _XLA_EXTENSION_STUBS:
|
||||
stub_path = r.Rlocation(
|
||||
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)
|
||||
"xla/xla/python/xla_extension/" + stub_name)
|
||||
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
|
||||
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
|
||||
continue
|
||||
with open(stub_path) as f:
|
||||
src = f.read()
|
||||
src = src.replace(
|
||||
"from tensorflow.compiler.xla.python import xla_extension",
|
||||
"from xla.python import xla_extension",
|
||||
"from .. import xla_extension"
|
||||
)
|
||||
with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
|
||||
@ -118,14 +118,14 @@ def patch_copy_xla_extension_stubs(dst_dir):
|
||||
|
||||
|
||||
def patch_copy_tpu_client_py(dst_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
|
||||
with open(r.Rlocation("xla/xla/python/tpu_driver/client/tpu_client.py")) as f:
|
||||
src = f.read()
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
|
||||
src = src.replace("from xla.python import xla_extension as _xla",
|
||||
"from . import xla_extension as _xla")
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
|
||||
src = src.replace("from xla.python import xla_client",
|
||||
"from . import xla_client")
|
||||
src = src.replace(
|
||||
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
|
||||
"from xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
|
||||
"from . import tpu_client_extension as _tpu_client")
|
||||
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
|
||||
f.write(src)
|
||||
@ -143,7 +143,7 @@ def verify_mac_libraries_dont_reference_chkstack():
|
||||
return
|
||||
nm = subprocess.run(
|
||||
["nm", "-g",
|
||||
r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so")
|
||||
r.Rlocation("xla/xla/python/xla_extension.so")
|
||||
],
|
||||
capture_output=True, text=True,
|
||||
check=False)
|
||||
@ -250,8 +250,8 @@ def prepare_wheel(sources_path):
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
|
||||
if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"):
|
||||
copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")
|
||||
if exists("xla/xla/python/tpu_driver/client/tpu_client_extension.so"):
|
||||
copy_to_jaxlib("xla/xla/python/tpu_driver/client/tpu_client_extension.so")
|
||||
patch_copy_tpu_client_py(jaxlib_dir)
|
||||
|
||||
|
||||
|
@ -66,11 +66,11 @@ specify the paths to CUDA and CUDNN, which you must have installed. Here
|
||||
may need to use `python3` instead. By default, the wheel is written to the
|
||||
`dist/` subdirectory of the current directory.
|
||||
|
||||
### Building jaxlib from source with a modified TensorFlow repository.
|
||||
### Building jaxlib from source with a modified XLA repository.
|
||||
|
||||
JAX depends on XLA, whose source code is in the
|
||||
[Tensorflow GitHub repository](https://github.com/tensorflow/tensorflow).
|
||||
By default JAX uses a pinned copy of the TensorFlow repository, but we often
|
||||
[XLA GitHub repository](https://github.com/openxla/xla).
|
||||
By default JAX uses a pinned copy of the XLA repository, but we often
|
||||
want to use a locally-modified copy of XLA when working on JAX. There are two
|
||||
ways to do this:
|
||||
|
||||
@ -78,12 +78,12 @@ ways to do this:
|
||||
line flag to `build.py` as follows:
|
||||
|
||||
```
|
||||
python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
|
||||
python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
|
||||
```
|
||||
* modify the `WORKSPACE` file in the root of the JAX source tree to point to
|
||||
a different TensorFlow tree.
|
||||
a different XLA tree.
|
||||
|
||||
To contribute changes back to XLA, send PRs to the TensorFlow repository.
|
||||
To contribute changes back to XLA, send PRs to the XLA repository.
|
||||
|
||||
The version of XLA pinned by JAX is regularly updated, but is updated in
|
||||
particular before each `jaxlib` release.
|
||||
@ -141,7 +141,7 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
|
||||
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
|
||||
```
|
||||
|
||||
AMD's fork of the XLA (TensorFlow) repository may include fixes
|
||||
AMD's fork of the XLA repository may include fixes
|
||||
not present in the upstream repository. To use AMD's fork, you should clone
|
||||
their repository:
|
||||
```
|
||||
@ -152,7 +152,7 @@ To build jaxlib with ROCM support, you can run the following build command,
|
||||
suitably adjusted for your paths and ROCM version.
|
||||
```
|
||||
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
|
||||
--bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow-upstream
|
||||
--bazel_options=--override_repository=xla=/path/to/xla-upstream
|
||||
```
|
||||
|
||||
## Installing `jax`
|
||||
|
@ -121,7 +121,7 @@ no released `jax` version uses that API.
|
||||
`jaxlib` is split across two main repositories, namely the
|
||||
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
|
||||
and in the
|
||||
[XLA source tree, which lives inside the TensorFlow repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla).
|
||||
[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla).
|
||||
The JAX-specific pieces inside XLA are primarily in the
|
||||
[`xla/python` subdirectory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/python).
|
||||
|
||||
@ -164,7 +164,7 @@ compatibility, we have additional versioning that is independent of the `jaxlib`
|
||||
release version numbers.
|
||||
|
||||
We maintain an additional version number (`_version`) in
|
||||
[`xla_client.py` in the XLA repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py).
|
||||
[`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py).
|
||||
The idea is that this version number, is defined in `xla/python`
|
||||
together with the C++ parts of JAX, is also accessible to JAX Python as
|
||||
`jax._src.lib.xla_extension_version`, and must
|
||||
|
@ -12,28 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"@org_tensorflow//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_binary",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
tf_cc_binary(
|
||||
cc_binary(
|
||||
name = "main",
|
||||
srcs = ["main.cc"],
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/compiler/xla:literal",
|
||||
"@org_tensorflow//tensorflow/compiler/xla:literal_util",
|
||||
"@org_tensorflow//tensorflow/compiler/xla:shape_util",
|
||||
"@org_tensorflow//tensorflow/compiler/xla:status",
|
||||
"@org_tensorflow//tensorflow/compiler/xla:statusor",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
|
||||
"@org_tensorflow//tensorflow/core/platform:logging",
|
||||
"@org_tensorflow//tensorflow/core/platform:platform_port",
|
||||
"@xla//:literal",
|
||||
"@xla//:literal_util",
|
||||
"@xla//:shape_util",
|
||||
"@xla//:status",
|
||||
"@xla//:statusor",
|
||||
"@xla///pjrt:pjrt_client",
|
||||
"@xla///pjrt:tfrt_cpu_pjrt_client",
|
||||
"@xla///service:hlo_proto_cc",
|
||||
"@xla///tools:hlo_module_loader",
|
||||
"@tsl///platform:logging",
|
||||
"@tsl///platform:platform_port",
|
||||
],
|
||||
)
|
||||
|
@ -40,18 +40,18 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/literal_util.h"
|
||||
#include "xla/pjrt/pjrt_client.h"
|
||||
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
|
||||
#include "xla/status.h"
|
||||
#include "xla/statusor.h"
|
||||
#include "xla/tools/hlo_module_loader.h"
|
||||
#include "tsl/platform/init_main.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::port::InitMain("", &argc, &argv);
|
||||
tsl::port::InitMain("", &argc, &argv);
|
||||
|
||||
// Load HloModule from file.
|
||||
std::string hlo_filename = "/tmp/fn_hlo.txt";
|
||||
|
@ -68,7 +68,7 @@ symlink_files(
|
||||
|
||||
symlink_files(
|
||||
name = "xla_client",
|
||||
srcs = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
|
||||
srcs = ["@xla///python:xla_client"],
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
@ -76,8 +76,8 @@ symlink_files(
|
||||
symlink_files(
|
||||
name = "xla_extension",
|
||||
srcs = if_windows(
|
||||
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.pyd"],
|
||||
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.so"],
|
||||
["@xla///python:xla_extension.pyd"],
|
||||
["@xla///python:xla_extension.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
@ -140,7 +140,7 @@ pybind_extension(
|
||||
srcs = ["cpu_feature_guard.c"],
|
||||
module_name = "cpu_feature_guard",
|
||||
deps = [
|
||||
"@org_tensorflow//third_party/python_runtime:headers",
|
||||
"@xla//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -31,7 +31,7 @@ cc_library(
|
||||
srcs = ["lapack_kernels.cc"],
|
||||
hdrs = ["lapack_kernels.h"],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla///service:custom_call_status",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
],
|
||||
)
|
||||
@ -74,7 +74,7 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":ducc_fft_flatbuffers_cc",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla///service:custom_call_status",
|
||||
"@ducc",
|
||||
"@flatbuffers//:runtime_cc",
|
||||
],
|
||||
@ -106,7 +106,7 @@ cc_library(
|
||||
":ducc_fft_kernels",
|
||||
":lapack_kernels",
|
||||
":lapack_kernels_using_lapack",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"@xla///service:custom_call_target_registry",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
||||
// This file is not used by JAX itself, but exists to assist with running
|
||||
// JAX-generated HLO code from outside of JAX.
|
||||
|
||||
#include "jaxlib/cpu/lapack_kernels.h"
|
||||
#include "jaxlib/cpu/ducc_fft_kernels.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "jaxlib/cpu/lapack_kernels.h"
|
||||
#include "xla/service/custom_call_target_registry.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
@ -15,10 +15,10 @@ limitations under the License.
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "ducc/src/ducc0/fft/fft.h"
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "jaxlib/cpu/ducc_fft_generated.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "ducc/src/ducc0/fft/fft.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef JAXLIB_CPU_DUCC_FFT_KERNELS_H_
|
||||
#define JAXLIB_CPU_DUCC_FFT_KERNELS_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
|
||||
// by the pybind wrapper that links them to an existing SciPy lapack instance,
|
||||
|
@ -51,8 +51,8 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":cuda_vendor",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
|
||||
"@xla///stream_executor/cuda:cusolver_lib",
|
||||
"@xla///stream_executor/cuda:cusparse_lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
@ -72,9 +72,9 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@xla///service:custom_call_status",
|
||||
"@xla///stream_executor/cuda:cublas_lib",
|
||||
"@xla///stream_executor/cuda:cudart_stub",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
@ -102,7 +102,7 @@ pybind_extension(
|
||||
":cublas_kernels",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
|
||||
"@xla///stream_executor/cuda:cublas_lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@pybind11",
|
||||
@ -118,9 +118,9 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudnn_lib",
|
||||
"@xla///service:custom_call_status",
|
||||
"@xla///stream_executor/cuda:cudart_stub",
|
||||
"@xla///stream_executor/cuda:cudnn_lib",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
@ -157,8 +157,8 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
|
||||
"@xla///service:custom_call_status",
|
||||
"@xla///stream_executor/cuda:cusolver_lib",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
@ -180,8 +180,8 @@ pybind_extension(
|
||||
":cuda_vendor",
|
||||
":cusolver_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
|
||||
"@xla///stream_executor/cuda:cudart_stub",
|
||||
"@xla///stream_executor/cuda:cusolver_lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
@ -198,9 +198,9 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
|
||||
"@xla///service:custom_call_status",
|
||||
"@xla///stream_executor/cuda:cudart_stub",
|
||||
"@xla///stream_executor/cuda:cusparse_lib",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
@ -222,8 +222,8 @@ pybind_extension(
|
||||
":cuda_vendor",
|
||||
":cusparse_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
|
||||
"@xla///stream_executor/cuda:cudart_stub",
|
||||
"@xla///stream_executor/cuda:cusparse_lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
@ -249,7 +249,7 @@ cc_library(
|
||||
":cuda_lu_pivot_kernels_impl",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla///service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
@ -264,7 +264,7 @@ cuda_library(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla///service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
@ -284,7 +284,7 @@ pybind_extension(
|
||||
":cuda_lu_pivot_kernels_impl",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@xla///stream_executor/cuda:cudart_stub",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@pybind11",
|
||||
],
|
||||
@ -301,7 +301,7 @@ cc_library(
|
||||
":cuda_prng_kernels_impl",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla///service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
@ -316,7 +316,7 @@ cuda_library(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla///service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
@ -334,7 +334,7 @@ pybind_extension(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_prng_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@xla///stream_executor/cuda:cudart_stub",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@pybind11",
|
||||
],
|
||||
@ -351,7 +351,7 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
":cusolver_kernels",
|
||||
":cusparse_kernels",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"@xla///service:custom_call_target_registry",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/solver_kernels.h"
|
||||
#include "jaxlib/gpu/sparse_kernels.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "xla/service/custom_call_target_registry.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
|
@ -14,12 +14,11 @@
|
||||
|
||||
"""Bazel macros used by the JAX build."""
|
||||
|
||||
load("@org_tensorflow//tensorflow/tsl/platform/default:build_config.bzl", _pyx_library = "pyx_library")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", _if_windows = "if_windows", _pybind_extension = "pybind_extension")
|
||||
load("@tsl//tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
|
||||
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
|
||||
load("@org_tensorflow//tensorflow/core/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
|
||||
load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
|
||||
|
||||
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
|
||||
# lint tools.
|
||||
@ -27,7 +26,6 @@ cuda_library = _cuda_library
|
||||
rocm_library = _rocm_library
|
||||
pytype_strict_library = native.py_library
|
||||
pytype_test = native.py_test
|
||||
pyx_library = _pyx_library
|
||||
pybind_extension = _pybind_extension
|
||||
if_cuda_is_configured = _if_cuda_is_configured
|
||||
if_rocm_is_configured = _if_rocm_is_configured
|
||||
@ -66,8 +64,8 @@ def pytype_library(name, pytype_srcs = None, **kwargs):
|
||||
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
|
||||
lib_rule(name = name, **kwargs)
|
||||
|
||||
def py_extension(name, srcs, copts, deps):
|
||||
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)
|
||||
def py_extension(name, srcs, copts, deps, linkopts = []):
|
||||
pybind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
|
||||
|
||||
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []):
|
||||
"""Workaround DLL building issue.
|
||||
|
@ -120,7 +120,7 @@ symlink_inputs(
|
||||
name = "mhlo_dialect",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {"dialects": [
|
||||
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:MhloOpsPyFiles",
|
||||
"@xla//xla/mlir_hlo:MhloOpsPyFiles",
|
||||
]}},
|
||||
deps = [
|
||||
":core",
|
||||
|
@ -31,12 +31,25 @@ COPTS = [
|
||||
"-frtti",
|
||||
]
|
||||
|
||||
LINKOPTS = select({
|
||||
"@tsl//tsl:macos": [
|
||||
"-Wl,-rpath,@loader_path/",
|
||||
"-Wl,-rename_section,__TEXT,text_env,__TEXT,__text",
|
||||
],
|
||||
"@tsl//tsl:windows": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,-rpath,$$ORIGIN/",
|
||||
],
|
||||
})
|
||||
|
||||
|
||||
py_extension(
|
||||
name = "_mlir",
|
||||
srcs = [
|
||||
"@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI",
|
||||
@ -50,6 +63,7 @@ py_extension(
|
||||
"@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPISparseTensorHeaders",
|
||||
@ -64,6 +78,7 @@ py_extension(
|
||||
"@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPISparseTensorHeaders",
|
||||
@ -90,6 +105,7 @@ py_extension(
|
||||
name = "_site_initialize_0",
|
||||
srcs = ["_site_initialize_0.cc"],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
@ -106,15 +122,16 @@ py_extension(
|
||||
py_extension(
|
||||
name = "_mlirHlo",
|
||||
srcs = [
|
||||
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:bindings/python/MlirHloModule.cc",
|
||||
"@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@local_config_python//:headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:CAPIHeaders",
|
||||
"@xla//xla/mlir_hlo:CAPIHeaders",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
@ -129,6 +146,7 @@ py_extension(
|
||||
"@stablehlo//:stablehlo/integrations/python/ChloModule.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
@ -145,6 +163,7 @@ py_extension(
|
||||
"@stablehlo//:stablehlo/integrations/python/StablehloModule.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
@ -158,12 +177,12 @@ py_extension(
|
||||
cc_library(
|
||||
name = "jaxlib_mlir_capi_shared_library",
|
||||
srcs = select({
|
||||
"@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi.dll"],
|
||||
"@org_tensorflow//tensorflow:macos": [":libjaxlib_mlir_capi.dylib"],
|
||||
"@tsl//tsl:windows": [":jaxlib_mlir_capi.dll"],
|
||||
"@tsl//tsl:macos": [":libjaxlib_mlir_capi.dylib"],
|
||||
"//conditions:default": [":libjaxlib_mlir_capi.so"],
|
||||
}),
|
||||
deps = select({
|
||||
"@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi_dll"],
|
||||
"@tsl//tsl:windows": [":jaxlib_mlir_capi_dll"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
@ -174,7 +193,7 @@ cc_library(
|
||||
"@llvm-project//mlir:CAPISparseTensorObjects",
|
||||
"@llvm-project//mlir:CAPITransformsObjects",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:CAPIObjects",
|
||||
"@xla//xla/mlir_hlo:CAPIObjects",
|
||||
"@stablehlo//:chlo_capi_objects",
|
||||
"@stablehlo//:stablehlo_capi_objects",
|
||||
],
|
||||
|
@ -75,7 +75,7 @@ cc_library(
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_rocm//rocm:hipblas",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -114,7 +114,7 @@ cc_library(
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_rocm//rocm:hipsolver",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -154,7 +154,7 @@ cc_library(
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_rocm//rocm:hipsparse",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -197,7 +197,7 @@ cc_library(
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -210,7 +210,7 @@ rocm_library(
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -244,7 +244,7 @@ cc_library(
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -257,7 +257,7 @@ rocm_library(
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
|
15
third_party/BUILD.bazel
vendored
Normal file
15
third_party/BUILD.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
15
third_party/flatbuffers/BUILD.bazel
vendored
Normal file
15
third_party/flatbuffers/BUILD.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This empty BUILD file is required to make Bazel treat this directory as a package.
|
43
third_party/flatbuffers/BUILD.system
vendored
Normal file
43
third_party/flatbuffers/BUILD.system
vendored
Normal file
@ -0,0 +1,43 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
filegroup(
|
||||
name = "LICENSE.txt",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Public flatc library to compile flatbuffer files at runtime.
|
||||
cc_library(
|
||||
name = "flatbuffers",
|
||||
linkopts = ["-lflatbuffers"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Public flatc compiler library.
|
||||
cc_library(
|
||||
name = "flatc_library",
|
||||
linkopts = ["-lflatbuffers"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "lnflatc",
|
||||
outs = ["flatc.bin"],
|
||||
cmd = "ln -s $$(which flatc) $@",
|
||||
)
|
||||
|
||||
# Public flatc compiler.
|
||||
sh_binary(
|
||||
name = "flatc",
|
||||
srcs = ["flatc.bin"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_cc",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "runtime_py",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
661
third_party/flatbuffers/build_defs.bzl
vendored
Normal file
661
third_party/flatbuffers/build_defs.bzl
vendored
Normal file
@ -0,0 +1,661 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""BUILD rules for generating flatbuffer files."""
|
||||
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
|
||||
flatc_path = "@flatbuffers//:flatc"
|
||||
zip_files = "//tensorflow/lite/tools:zip_files"
|
||||
|
||||
DEFAULT_INCLUDE_PATHS = [
|
||||
"./",
|
||||
"$(GENDIR)",
|
||||
"$(BINDIR)",
|
||||
]
|
||||
|
||||
DEFAULT_FLATC_ARGS = [
|
||||
"--no-union-value-namespacing",
|
||||
"--gen-object-api",
|
||||
]
|
||||
|
||||
def flatbuffer_library_public(
|
||||
name,
|
||||
srcs,
|
||||
outs,
|
||||
language_flag,
|
||||
out_prefix = "",
|
||||
includes = [],
|
||||
include_paths = [],
|
||||
compatible_with = [],
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
reflection_name = "",
|
||||
reflection_visibility = None, # buildifier: disable=unused-variable
|
||||
output_to_bindir = False):
|
||||
"""Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler.
|
||||
|
||||
Outs:
|
||||
filegroup(name): all generated source files.
|
||||
Fileset([reflection_name]): (Optional) all generated reflection binaries.
|
||||
|
||||
Args:
|
||||
name: Rule name.
|
||||
srcs: Source .fbs files. Sent in order to the compiler.
|
||||
outs: Output files from flatc.
|
||||
language_flag: Target language flag. One of [-c, -j, -js].
|
||||
out_prefix: Prepend this path to the front of all generated files except on
|
||||
single source targets. Usually is a directory name.
|
||||
includes: Optional, list of filegroups of schemas that the srcs depend on.
|
||||
include_paths: Optional, list of paths the includes files can be found in.
|
||||
compatible_with: Optional, passed to genrule for environments this rule
|
||||
can be built for.
|
||||
flatc_args: Optional, list of additional arguments to pass to flatc.
|
||||
reflection_name: Optional, if set this will generate the flatbuffer
|
||||
reflection binaries for the schemas.
|
||||
reflection_visibility: The visibility of the generated reflection Fileset.
|
||||
output_to_bindir: Passed to genrule for output to bin directory.
|
||||
"""
|
||||
include_paths_cmd = ["-I %s" % (s) for s in include_paths]
|
||||
|
||||
# '$(@D)' when given a single source target will give the appropriate
|
||||
# directory. Appending 'out_prefix' is only necessary when given a build
|
||||
# target with multiple sources.
|
||||
output_directory = (
|
||||
("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)")
|
||||
)
|
||||
genrule_cmd = " ".join([
|
||||
"for f in $(SRCS); do",
|
||||
"$(location %s)" % (flatc_path),
|
||||
" ".join(flatc_args),
|
||||
" ".join(include_paths_cmd),
|
||||
language_flag,
|
||||
output_directory,
|
||||
"$$f;",
|
||||
"done",
|
||||
])
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
outs = outs,
|
||||
output_to_bindir = output_to_bindir,
|
||||
compatible_with = compatible_with,
|
||||
tools = includes + [flatc_path],
|
||||
cmd = genrule_cmd,
|
||||
message = "Generating flatbuffer files for %s:" % (name),
|
||||
)
|
||||
if reflection_name:
|
||||
reflection_genrule_cmd = " ".join([
|
||||
"for f in $(SRCS); do",
|
||||
"$(location %s)" % (flatc_path),
|
||||
"-b --schema",
|
||||
" ".join(flatc_args),
|
||||
" ".join(include_paths_cmd),
|
||||
language_flag,
|
||||
output_directory,
|
||||
"$$f;",
|
||||
"done",
|
||||
])
|
||||
reflection_outs = [
|
||||
(out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1])
|
||||
for s in srcs
|
||||
]
|
||||
native.genrule(
|
||||
name = "%s_srcs" % reflection_name,
|
||||
srcs = srcs,
|
||||
outs = reflection_outs,
|
||||
output_to_bindir = output_to_bindir,
|
||||
compatible_with = compatible_with,
|
||||
tools = includes + [flatc_path],
|
||||
cmd = reflection_genrule_cmd,
|
||||
message = "Generating flatbuffer reflection binary for %s:" % (name),
|
||||
)
|
||||
# TODO(b/114456773): Make bazel rules proper and supported by flatbuffer
|
||||
# Have to comment this since FilesetEntry is not supported in bazel
|
||||
# starlark.
|
||||
# native.Fileset(
|
||||
# name = reflection_name,
|
||||
# out = "%s_out" % reflection_name,
|
||||
# entries = [
|
||||
# native.FilesetEntry(files = reflection_outs),
|
||||
# ],
|
||||
# visibility = reflection_visibility,
|
||||
# compatible_with = compatible_with,
|
||||
# )
|
||||
|
||||
def flatbuffer_cc_library(
|
||||
name,
|
||||
srcs,
|
||||
srcs_filegroup_name = "",
|
||||
out_prefix = "",
|
||||
includes = [],
|
||||
include_paths = [],
|
||||
compatible_with = [],
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
visibility = None,
|
||||
srcs_filegroup_visibility = None,
|
||||
gen_reflections = False):
|
||||
'''A cc_library with the generated reader/writers for the given flatbuffer definitions.
|
||||
|
||||
Outs:
|
||||
filegroup([name]_srcs): all generated .h files.
|
||||
filegroup(srcs_filegroup_name if specified, or [name]_includes if not):
|
||||
Other flatbuffer_cc_library's can pass this in for their `includes`
|
||||
parameter, if they depend on the schemas in this library.
|
||||
Fileset([name]_reflection): (Optional) all generated reflection binaries.
|
||||
cc_library([name]): library with sources and flatbuffers deps.
|
||||
|
||||
Remarks:
|
||||
** Because the genrule used to call flatc does not have any trivial way of
|
||||
computing the output list of files transitively generated by includes and
|
||||
--gen-includes (the default) being defined for flatc, the --gen-includes
|
||||
flag will not work as expected. The way around this is to add a dependency
|
||||
to the flatbuffer_cc_library defined alongside the flatc included Fileset.
|
||||
For example you might define:
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "my_fbs",
|
||||
srcs = [ "schemas/foo.fbs" ],
|
||||
includes = [ "//third_party/bazz:bazz_fbs_includes" ],
|
||||
)
|
||||
|
||||
In which foo.fbs includes a few files from the Fileset defined at
|
||||
//third_party/bazz:bazz_fbs_includes. When compiling the library that
|
||||
includes foo_generated.h, and therefore has my_fbs as a dependency, it
|
||||
will fail to find any of the bazz *_generated.h files unless you also
|
||||
add bazz's flatbuffer_cc_library to your own dependency list, e.g.:
|
||||
|
||||
cc_library(
|
||||
name = "my_lib",
|
||||
deps = [
|
||||
":my_fbs",
|
||||
"//third_party/bazz:bazz_fbs"
|
||||
],
|
||||
)
|
||||
|
||||
Happy dependent Flatbuffering!
|
||||
|
||||
Args:
|
||||
name: Rule name.
|
||||
srcs: Source .fbs files. Sent in order to the compiler.
|
||||
srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this
|
||||
filegroup into the `includes` parameter of any other
|
||||
flatbuffer_cc_library that depends on this one's schemas.
|
||||
out_prefix: Prepend this path to the front of all generated files. Usually
|
||||
is a directory name.
|
||||
includes: Optional, list of filegroups of schemas that the srcs depend on.
|
||||
** SEE REMARKS BELOW **
|
||||
include_paths: Optional, list of paths the includes files can be found in.
|
||||
compatible_with: Optional, passed to genrule for environments this rule
|
||||
can be built for
|
||||
flatc_args: Optional list of additional arguments to pass to flatc
|
||||
(e.g. --gen-mutable).
|
||||
visibility: The visibility of the generated cc_library. By default, use the
|
||||
default visibility of the project.
|
||||
srcs_filegroup_visibility: The visibility of the generated srcs filegroup.
|
||||
By default, use the value of the visibility parameter above.
|
||||
gen_reflections: Optional, if true this will generate the flatbuffer
|
||||
reflection binaries for the schemas.
|
||||
'''
|
||||
output_headers = [
|
||||
(out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1])
|
||||
for s in srcs
|
||||
]
|
||||
reflection_name = "%s_reflection" % name if gen_reflections else ""
|
||||
|
||||
flatbuffer_library_public(
|
||||
name = "%s_srcs" % (name),
|
||||
srcs = srcs,
|
||||
outs = output_headers,
|
||||
language_flag = "-c",
|
||||
out_prefix = out_prefix,
|
||||
includes = includes,
|
||||
include_paths = include_paths,
|
||||
compatible_with = compatible_with,
|
||||
flatc_args = flatc_args,
|
||||
reflection_name = reflection_name,
|
||||
reflection_visibility = visibility,
|
||||
)
|
||||
native.cc_library(
|
||||
name = name,
|
||||
hdrs = output_headers,
|
||||
srcs = output_headers,
|
||||
features = [
|
||||
"-parse_headers",
|
||||
],
|
||||
deps = [
|
||||
"@flatbuffers//:runtime_cc",
|
||||
],
|
||||
includes = ["."],
|
||||
linkstatic = 1,
|
||||
visibility = visibility,
|
||||
compatible_with = compatible_with,
|
||||
)
|
||||
|
||||
# A filegroup for the `srcs`. That is, all the schema files for this
|
||||
# Flatbuffer set.
|
||||
native.filegroup(
|
||||
name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
|
||||
srcs = srcs,
|
||||
visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility,
|
||||
compatible_with = compatible_with,
|
||||
)
|
||||
|
||||
# Custom provider to track dependencies transitively.
|
||||
FlatbufferInfo = provider(
|
||||
fields = {
|
||||
"transitive_srcs": "flatbuffer schema definitions.",
|
||||
},
|
||||
) # buildifier: disable=provider-params
|
||||
|
||||
def _flatbuffer_schemas_aspect_impl(target, ctx):
|
||||
_ignore = [target] # @unused
|
||||
transitive_srcs = depset()
|
||||
if hasattr(ctx.rule.attr, "deps"):
|
||||
for dep in ctx.rule.attr.deps:
|
||||
if FlatbufferInfo in dep:
|
||||
transitive_srcs = depset(dep[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) # buildifier: disable=overly-nested-depset
|
||||
if hasattr(ctx.rule.attr, "srcs"):
|
||||
for src in ctx.rule.attr.srcs:
|
||||
if FlatbufferInfo in src:
|
||||
transitive_srcs = depset(src[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) # buildifier: disable=overly-nested-depset
|
||||
for f in src.files:
|
||||
if f.extension == "fbs":
|
||||
transitive_srcs = depset([f], transitive = [transitive_srcs]) # buildifier: disable=overly-nested-depset
|
||||
return [FlatbufferInfo(transitive_srcs = transitive_srcs)]
|
||||
|
||||
# An aspect that runs over all dependencies and transitively collects
|
||||
# flatbuffer schema files.
|
||||
_flatbuffer_schemas_aspect = aspect(
|
||||
attr_aspects = [
|
||||
"deps",
|
||||
"srcs",
|
||||
],
|
||||
implementation = _flatbuffer_schemas_aspect_impl,
|
||||
)
|
||||
|
||||
# Rule to invoke the flatbuffer compiler.
|
||||
def _gen_flatbuffer_srcs_impl(ctx):
|
||||
outputs = ctx.attr.outputs
|
||||
include_paths = ctx.attr.include_paths
|
||||
if ctx.attr.no_includes:
|
||||
no_includes_statement = ["--no-includes"]
|
||||
else:
|
||||
no_includes_statement = []
|
||||
|
||||
if ctx.attr.language_flag == "--python":
|
||||
onefile_statement = ["--gen-onefile"]
|
||||
else:
|
||||
onefile_statement = []
|
||||
|
||||
# Need to generate all files in a directory.
|
||||
if not outputs:
|
||||
outputs = [ctx.actions.declare_directory("{}_all".format(ctx.attr.name))]
|
||||
output_directory = outputs[0].path
|
||||
else:
|
||||
outputs = [ctx.actions.declare_file(output) for output in outputs]
|
||||
output_directory = outputs[0].dirname
|
||||
|
||||
deps = depset(ctx.files.srcs + ctx.files.deps, transitive = [
|
||||
dep[FlatbufferInfo].transitive_srcs
|
||||
for dep in ctx.attr.deps
|
||||
if FlatbufferInfo in dep
|
||||
])
|
||||
|
||||
include_paths_cmd_line = []
|
||||
for s in include_paths:
|
||||
include_paths_cmd_line.extend(["-I", s])
|
||||
|
||||
for src in ctx.files.srcs:
|
||||
ctx.actions.run(
|
||||
inputs = deps,
|
||||
outputs = outputs,
|
||||
executable = ctx.executable._flatc,
|
||||
arguments = [
|
||||
ctx.attr.language_flag,
|
||||
"-o",
|
||||
output_directory,
|
||||
# Allow for absolute imports and referencing of generated files.
|
||||
"-I",
|
||||
"./",
|
||||
"-I",
|
||||
ctx.genfiles_dir.path,
|
||||
"-I",
|
||||
ctx.bin_dir.path,
|
||||
] + no_includes_statement +
|
||||
onefile_statement +
|
||||
include_paths_cmd_line + [
|
||||
"--no-union-value-namespacing",
|
||||
"--gen-object-api",
|
||||
src.path,
|
||||
],
|
||||
progress_message = "Generating flatbuffer files for {}:".format(src),
|
||||
use_default_shell_env = True,
|
||||
)
|
||||
return [
|
||||
DefaultInfo(files = depset(outputs)),
|
||||
]
|
||||
|
||||
_gen_flatbuffer_srcs = rule(
|
||||
_gen_flatbuffer_srcs_impl,
|
||||
attrs = {
|
||||
"srcs": attr.label_list(
|
||||
allow_files = [".fbs"],
|
||||
mandatory = True,
|
||||
),
|
||||
"outputs": attr.string_list(
|
||||
default = [],
|
||||
mandatory = False,
|
||||
),
|
||||
"deps": attr.label_list(
|
||||
default = [],
|
||||
mandatory = False,
|
||||
aspects = [_flatbuffer_schemas_aspect],
|
||||
),
|
||||
"include_paths": attr.string_list(
|
||||
default = [],
|
||||
mandatory = False,
|
||||
),
|
||||
"language_flag": attr.string(
|
||||
mandatory = True,
|
||||
),
|
||||
"no_includes": attr.bool(
|
||||
default = False,
|
||||
mandatory = False,
|
||||
),
|
||||
"_flatc": attr.label(
|
||||
default = Label("@flatbuffers//:flatc"),
|
||||
executable = True,
|
||||
cfg = "exec",
|
||||
),
|
||||
},
|
||||
output_to_genfiles = True,
|
||||
)
|
||||
|
||||
def flatbuffer_py_strip_prefix_srcs(name, srcs = [], strip_prefix = ""):
|
||||
"""Strips path prefix.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: Source .py files. (required)
|
||||
strip_prefix: Path that needs to be stripped from the srcs filepaths. (required)
|
||||
"""
|
||||
for src in srcs:
|
||||
native.genrule(
|
||||
name = name + "_" + src.replace(".", "_").replace("/", "_"),
|
||||
srcs = [src],
|
||||
outs = [src.replace(strip_prefix, "")],
|
||||
cmd = "cp $< $@",
|
||||
)
|
||||
|
||||
def _concat_flatbuffer_py_srcs_impl(ctx):
|
||||
# Merge all generated python files. The files are concatenated and import
|
||||
# statements are removed. Finally we import the flatbuffer runtime library.
|
||||
# IMPORTANT: Our Windows shell does not support "find ... -exec" properly.
|
||||
# If you're changing the commandline below, please build wheels and run smoke
|
||||
# tests on all the three operating systems.
|
||||
command = "echo 'import flatbuffers\n' > %s; "
|
||||
command += "for f in $(find %s -name '*.py' | sort); do cat $f | sed '/import flatbuffers/d' >> %s; done "
|
||||
ctx.actions.run_shell(
|
||||
inputs = ctx.attr.deps[0].files,
|
||||
outputs = [ctx.outputs.out],
|
||||
command = command % (
|
||||
ctx.outputs.out.path,
|
||||
ctx.attr.deps[0].files.to_list()[0].path,
|
||||
ctx.outputs.out.path,
|
||||
),
|
||||
use_default_shell_env = True,
|
||||
)
|
||||
|
||||
_concat_flatbuffer_py_srcs = rule(
|
||||
_concat_flatbuffer_py_srcs_impl,
|
||||
attrs = {
|
||||
"deps": attr.label_list(mandatory = True),
|
||||
},
|
||||
output_to_genfiles = True,
|
||||
outputs = {"out": "%{name}.py"},
|
||||
)
|
||||
|
||||
def flatbuffer_py_library(
|
||||
name,
|
||||
srcs,
|
||||
deps = [],
|
||||
include_paths = []):
|
||||
"""A py_library with the generated reader/writers for the given schema.
|
||||
|
||||
This rule assumes that the schema files define non-conflicting names, so that
|
||||
they can be merged in a single file. This is e.g. the case if only a single
|
||||
namespace is used.
|
||||
The rule call the flatbuffer compiler for all schema files and merges the
|
||||
generated python files into a single file that is wrapped in a py_library.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source .fbs files. (required)
|
||||
deps: List of dependencies.
|
||||
include_paths: Optional, list of paths the includes files can be found in.
|
||||
"""
|
||||
all_srcs = "{}_srcs".format(name)
|
||||
_gen_flatbuffer_srcs(
|
||||
name = all_srcs,
|
||||
srcs = srcs,
|
||||
language_flag = "--python",
|
||||
deps = deps,
|
||||
include_paths = include_paths,
|
||||
)
|
||||
|
||||
# TODO(b/235550563): Remove the concatnation rule with 2.0.6 update.
|
||||
all_srcs_no_include = "{}_srcs_no_include".format(name)
|
||||
_gen_flatbuffer_srcs(
|
||||
name = all_srcs_no_include,
|
||||
srcs = srcs,
|
||||
language_flag = "--python",
|
||||
deps = deps,
|
||||
no_includes = True,
|
||||
include_paths = include_paths,
|
||||
)
|
||||
concat_py_srcs = "{}_generated".format(name)
|
||||
_concat_flatbuffer_py_srcs(
|
||||
name = concat_py_srcs,
|
||||
deps = [
|
||||
":{}".format(all_srcs_no_include),
|
||||
],
|
||||
)
|
||||
native.py_library(
|
||||
name = name,
|
||||
srcs = [
|
||||
":{}".format(concat_py_srcs),
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = deps + [
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
def flatbuffer_java_library(
|
||||
name,
|
||||
srcs,
|
||||
custom_package = "",
|
||||
package_prefix = "",
|
||||
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
visibility = None):
|
||||
"""A java library with the generated reader/writers for the given flatbuffer definitions.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source .fbs files including all includes. (required)
|
||||
custom_package: Package name of generated Java files. If not specified
|
||||
namespace in the schema files will be used. (optional)
|
||||
package_prefix: like custom_package, but prefixes to the existing
|
||||
namespace. (optional)
|
||||
include_paths: List of paths that includes files can be found in. (optional)
|
||||
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||
visibility: Visibility setting for the java_library rule. (optional)
|
||||
"""
|
||||
out_srcjar = "java_%s_all.srcjar" % name
|
||||
flatbuffer_java_srcjar(
|
||||
name = "%s_srcjar" % name,
|
||||
srcs = srcs,
|
||||
out = out_srcjar,
|
||||
custom_package = custom_package,
|
||||
flatc_args = flatc_args,
|
||||
include_paths = include_paths,
|
||||
package_prefix = package_prefix,
|
||||
)
|
||||
|
||||
native.filegroup(
|
||||
name = "%s.srcjar" % name,
|
||||
srcs = [out_srcjar],
|
||||
)
|
||||
|
||||
native.java_library(
|
||||
name = name,
|
||||
srcs = [out_srcjar],
|
||||
javacopts = ["-source 7 -target 7"],
|
||||
deps = [
|
||||
"@flatbuffers//:runtime_java",
|
||||
],
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
def flatbuffer_java_srcjar(
|
||||
name,
|
||||
srcs,
|
||||
out,
|
||||
custom_package = "",
|
||||
package_prefix = "",
|
||||
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||
flatc_args = DEFAULT_FLATC_ARGS):
|
||||
"""Generate flatbuffer Java source files.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source .fbs files including all includes. (required)
|
||||
out: Output file name. (required)
|
||||
custom_package: Package name of generated Java files. If not specified
|
||||
namespace in the schema files will be used. (optional)
|
||||
package_prefix: like custom_package, but prefixes to the existing
|
||||
namespace. (optional)
|
||||
include_paths: List of paths that includes files can be found in. (optional)
|
||||
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||
"""
|
||||
command_fmt = """set -e
|
||||
tmpdir=$(@D)
|
||||
schemas=$$tmpdir/schemas
|
||||
java_root=$$tmpdir/java
|
||||
rm -rf $$schemas
|
||||
rm -rf $$java_root
|
||||
mkdir -p $$schemas
|
||||
mkdir -p $$java_root
|
||||
|
||||
for src in $(SRCS); do
|
||||
dest=$$schemas/$$src
|
||||
rm -rf $$(dirname $$dest)
|
||||
mkdir -p $$(dirname $$dest)
|
||||
if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then
|
||||
cp -f $$src $$dest
|
||||
else
|
||||
if [ -z "{package_prefix}" ]; then
|
||||
sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest
|
||||
else
|
||||
sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
flatc_arg_I="-I $$tmpdir/schemas"
|
||||
for include_path in {include_paths}; do
|
||||
flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path"
|
||||
done
|
||||
|
||||
flatc_additional_args=
|
||||
for arg in {flatc_args}; do
|
||||
flatc_additional_args="$$flatc_additional_args $$arg"
|
||||
done
|
||||
|
||||
for src in $(SRCS); do
|
||||
$(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src
|
||||
done
|
||||
|
||||
$(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root
|
||||
"""
|
||||
genrule_cmd = command_fmt.format(
|
||||
package_name = native.package_name(),
|
||||
custom_package = custom_package,
|
||||
package_prefix = package_prefix,
|
||||
flatc_path = flatc_path,
|
||||
zip_files = zip_files,
|
||||
include_paths = " ".join(include_paths),
|
||||
flatc_args = " ".join(flatc_args),
|
||||
)
|
||||
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
outs = [out],
|
||||
tools = [flatc_path, zip_files],
|
||||
cmd = genrule_cmd,
|
||||
)
|
||||
|
||||
def flatbuffer_android_library(
|
||||
name,
|
||||
srcs,
|
||||
custom_package = "",
|
||||
package_prefix = "",
|
||||
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
visibility = None):
|
||||
"""An android_library with the generated reader/writers for the given flatbuffer definitions.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source .fbs files including all includes. (required)
|
||||
custom_package: Package name of generated Java files. If not specified
|
||||
namespace in the schema files will be used. (optional)
|
||||
package_prefix: like custom_package, but prefixes to the existing
|
||||
namespace. (optional)
|
||||
include_paths: List of paths that includes files can be found in. (optional)
|
||||
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||
visibility: Visibility setting for the android_library rule. (optional)
|
||||
"""
|
||||
out_srcjar = "android_%s_all.srcjar" % name
|
||||
flatbuffer_java_srcjar(
|
||||
name = "%s_srcjar" % name,
|
||||
srcs = srcs,
|
||||
out = out_srcjar,
|
||||
custom_package = custom_package,
|
||||
flatc_args = flatc_args,
|
||||
include_paths = include_paths,
|
||||
package_prefix = package_prefix,
|
||||
)
|
||||
|
||||
native.filegroup(
|
||||
name = "%s.srcjar" % name,
|
||||
srcs = [out_srcjar],
|
||||
)
|
||||
|
||||
# To support org.checkerframework.dataflow.qual.Pure.
|
||||
checkerframework_annotations = [
|
||||
"@org_checkerframework_qual",
|
||||
] if "--java-checkerframework" in flatc_args else []
|
||||
|
||||
android_library(
|
||||
name = name,
|
||||
srcs = [out_srcjar],
|
||||
javacopts = ["-source 7 -target 7"],
|
||||
visibility = visibility,
|
||||
deps = [
|
||||
"@flatbuffers//:runtime_android",
|
||||
] + checkerframework_annotations,
|
||||
)
|
192
third_party/flatbuffers/flatbuffers.BUILD
vendored
Normal file
192
third_party/flatbuffers/flatbuffers.BUILD
vendored
Normal file
@ -0,0 +1,192 @@
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
load(":build_defs.bzl", "flatbuffer_py_strip_prefix_srcs")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE.txt"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
config_setting(
|
||||
name = "platform_freebsd",
|
||||
values = {"cpu": "freebsd"},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "platform_openbsd",
|
||||
values = {"cpu": "openbsd"},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "windows",
|
||||
values = {"cpu": "x64_windows"},
|
||||
)
|
||||
|
||||
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
|
||||
|
||||
# Public flatc library to compile flatbuffer files at runtime.
|
||||
cc_library(
|
||||
name = "flatbuffers",
|
||||
hdrs = ["//:public_headers"],
|
||||
linkstatic = 1,
|
||||
strip_include_prefix = "/include",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//src:flatbuffers"],
|
||||
)
|
||||
|
||||
# Public C++ headers for the Flatbuffers library.
|
||||
filegroup(
|
||||
name = "public_headers",
|
||||
srcs = [
|
||||
"include/flatbuffers/allocator.h",
|
||||
"include/flatbuffers/array.h",
|
||||
"include/flatbuffers/base.h",
|
||||
"include/flatbuffers/bfbs_generator.h",
|
||||
"include/flatbuffers/buffer.h",
|
||||
"include/flatbuffers/buffer_ref.h",
|
||||
"include/flatbuffers/code_generators.h",
|
||||
"include/flatbuffers/default_allocator.h",
|
||||
"include/flatbuffers/detached_buffer.h",
|
||||
"include/flatbuffers/flatbuffer_builder.h",
|
||||
"include/flatbuffers/flatbuffers.h",
|
||||
"include/flatbuffers/flexbuffers.h",
|
||||
"include/flatbuffers/hash.h",
|
||||
"include/flatbuffers/idl.h",
|
||||
"include/flatbuffers/minireflect.h",
|
||||
"include/flatbuffers/reflection.h",
|
||||
"include/flatbuffers/reflection_generated.h",
|
||||
"include/flatbuffers/registry.h",
|
||||
"include/flatbuffers/stl_emulation.h",
|
||||
"include/flatbuffers/string.h",
|
||||
"include/flatbuffers/struct.h",
|
||||
"include/flatbuffers/table.h",
|
||||
"include/flatbuffers/util.h",
|
||||
"include/flatbuffers/vector.h",
|
||||
"include/flatbuffers/vector_downward.h",
|
||||
"include/flatbuffers/verifier.h",
|
||||
],
|
||||
visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
# Public flatc compiler library.
|
||||
cc_library(
|
||||
name = "flatc_library",
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@flatbuffers//src:flatc_library",
|
||||
],
|
||||
)
|
||||
|
||||
# Public flatc compiler.
|
||||
cc_binary(
|
||||
name = "flatc",
|
||||
linkopts = select({
|
||||
":platform_freebsd": [
|
||||
"-lm",
|
||||
],
|
||||
":windows": [],
|
||||
"//conditions:default": [
|
||||
"-lm",
|
||||
"-ldl",
|
||||
],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@flatbuffers//src:flatc",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "flatc_headers",
|
||||
srcs = [
|
||||
"include/flatbuffers/flatc.h",
|
||||
],
|
||||
visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
# Library used by flatbuffer_cc_library rules.
|
||||
cc_library(
|
||||
name = "runtime_cc",
|
||||
hdrs = [
|
||||
"include/flatbuffers/allocator.h",
|
||||
"include/flatbuffers/array.h",
|
||||
"include/flatbuffers/base.h",
|
||||
"include/flatbuffers/buffer.h",
|
||||
"include/flatbuffers/buffer_ref.h",
|
||||
"include/flatbuffers/default_allocator.h",
|
||||
"include/flatbuffers/detached_buffer.h",
|
||||
"include/flatbuffers/flatbuffer_builder.h",
|
||||
"include/flatbuffers/flatbuffers.h",
|
||||
"include/flatbuffers/flexbuffers.h",
|
||||
"include/flatbuffers/stl_emulation.h",
|
||||
"include/flatbuffers/string.h",
|
||||
"include/flatbuffers/struct.h",
|
||||
"include/flatbuffers/table.h",
|
||||
"include/flatbuffers/util.h",
|
||||
"include/flatbuffers/vector.h",
|
||||
"include/flatbuffers/vector_downward.h",
|
||||
"include/flatbuffers/verifier.h",
|
||||
],
|
||||
linkstatic = 1,
|
||||
strip_include_prefix = "/include",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
flatbuffer_py_strip_prefix_srcs(
|
||||
name = "flatbuffer_py_strip_prefix",
|
||||
srcs = [
|
||||
"python/flatbuffers/__init__.py",
|
||||
"python/flatbuffers/_version.py",
|
||||
"python/flatbuffers/builder.py",
|
||||
"python/flatbuffers/compat.py",
|
||||
"python/flatbuffers/encode.py",
|
||||
"python/flatbuffers/flexbuffers.py",
|
||||
"python/flatbuffers/number_types.py",
|
||||
"python/flatbuffers/packer.py",
|
||||
"python/flatbuffers/table.py",
|
||||
"python/flatbuffers/util.py",
|
||||
],
|
||||
strip_prefix = "python/flatbuffers/",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "runtime_py_srcs",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"_version.py",
|
||||
"builder.py",
|
||||
"compat.py",
|
||||
"encode.py",
|
||||
"flexbuffers.py",
|
||||
"number_types.py",
|
||||
"packer.py",
|
||||
"table.py",
|
||||
"util.py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "runtime_py",
|
||||
srcs = [":runtime_py_srcs"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "runtime_java_srcs",
|
||||
srcs = glob(["java/com/google/flatbuffers/**/*.java"]),
|
||||
)
|
||||
|
||||
java_library(
|
||||
name = "runtime_java",
|
||||
srcs = [":runtime_java_srcs"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "runtime_android",
|
||||
srcs = [":runtime_java_srcs"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
30
third_party/flatbuffers/workspace.bzl
vendored
Normal file
30
third_party/flatbuffers/workspace.bzl
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Loads the Flatbuffers library."""
|
||||
|
||||
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||
|
||||
def repo():
|
||||
tf_http_archive(
|
||||
name = "flatbuffers",
|
||||
strip_prefix = "flatbuffers-2.0.6",
|
||||
sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9",
|
||||
urls = tf_mirror_urls("https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz"),
|
||||
build_file = "//third_party/flatbuffers:flatbuffers.BUILD",
|
||||
system_build_file = "//third_party/flatbuffers:BUILD.system",
|
||||
link_files = {
|
||||
"//third_party/flatbuffers:build_defs.bzl": "build_defs.bzl",
|
||||
},
|
||||
)
|
166
third_party/repo.bzl
vendored
Normal file
166
third_party/repo.bzl
vendored
Normal file
@ -0,0 +1,166 @@
|
||||
# Copyright 2017 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for defining TensorFlow Bazel dependencies."""
|
||||
|
||||
def tf_mirror_urls(url):
|
||||
"""A helper for generating TF-mirror versions of URLs.
|
||||
|
||||
Given a URL, it returns a list of the TF-mirror cache version of that URL
|
||||
and the original URL, suitable for use in `urls` field of `tf_http_archive`.
|
||||
"""
|
||||
if not url.startswith("https://"):
|
||||
return [url]
|
||||
return [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/%s" % url[8:],
|
||||
url,
|
||||
]
|
||||
|
||||
def _get_env_var(ctx, name):
|
||||
if name in ctx.os.environ:
|
||||
return ctx.os.environ[name]
|
||||
else:
|
||||
return None
|
||||
|
||||
# Checks if we should use the system lib instead of the bundled one
|
||||
def _use_system_lib(ctx, name):
|
||||
syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS")
|
||||
if not syslibenv:
|
||||
return False
|
||||
return name in [n.strip() for n in syslibenv.split(",")]
|
||||
|
||||
def _get_link_dict(ctx, link_files, build_file):
|
||||
link_dict = {ctx.path(v): ctx.path(Label(k)) for k, v in link_files.items()}
|
||||
if build_file:
|
||||
# Use BUILD.bazel because it takes precedence over BUILD.
|
||||
link_dict[ctx.path("BUILD.bazel")] = ctx.path(Label(build_file))
|
||||
return link_dict
|
||||
|
||||
def _tf_http_archive_impl(ctx):
|
||||
# Construct all paths early on to prevent rule restart. We want the
|
||||
# attributes to be strings instead of labels because they refer to files
|
||||
# in the TensorFlow repository, not files in repos depending on TensorFlow.
|
||||
# See also https://github.com/bazelbuild/bazel/issues/10515.
|
||||
link_dict = _get_link_dict(ctx, ctx.attr.link_files, ctx.attr.build_file)
|
||||
|
||||
# For some reason, we need to "resolve" labels once before the
|
||||
# download_and_extract otherwise it'll invalidate and re-download the
|
||||
# archive each time.
|
||||
# https://github.com/bazelbuild/bazel/issues/10515
|
||||
patch_files = ctx.attr.patch_file
|
||||
for patch_file in patch_files:
|
||||
if patch_file:
|
||||
ctx.path(Label(patch_file))
|
||||
|
||||
if _use_system_lib(ctx, ctx.attr.name):
|
||||
link_dict.update(_get_link_dict(
|
||||
ctx = ctx,
|
||||
link_files = ctx.attr.system_link_files,
|
||||
build_file = ctx.attr.system_build_file,
|
||||
))
|
||||
else:
|
||||
ctx.download_and_extract(
|
||||
url = ctx.attr.urls,
|
||||
sha256 = ctx.attr.sha256,
|
||||
type = ctx.attr.type,
|
||||
stripPrefix = ctx.attr.strip_prefix,
|
||||
)
|
||||
if patch_files:
|
||||
for patch_file in patch_files:
|
||||
patch_file = ctx.path(Label(patch_file)) if patch_file else None
|
||||
if patch_file:
|
||||
ctx.patch(patch_file, strip = 1)
|
||||
|
||||
for dst, src in link_dict.items():
|
||||
ctx.delete(dst)
|
||||
ctx.symlink(src, dst)
|
||||
|
||||
_tf_http_archive = repository_rule(
|
||||
implementation = _tf_http_archive_impl,
|
||||
attrs = {
|
||||
"sha256": attr.string(mandatory = True),
|
||||
"urls": attr.string_list(mandatory = True),
|
||||
"strip_prefix": attr.string(),
|
||||
"type": attr.string(),
|
||||
"patch_file": attr.string_list(),
|
||||
"build_file": attr.string(),
|
||||
"system_build_file": attr.string(),
|
||||
"link_files": attr.string_dict(),
|
||||
"system_link_files": attr.string_dict(),
|
||||
},
|
||||
environ = ["TF_SYSTEM_LIBS"],
|
||||
)
|
||||
|
||||
def tf_http_archive(name, sha256, urls, **kwargs):
|
||||
"""Downloads and creates Bazel repos for dependencies.
|
||||
|
||||
This is a swappable replacement for both http_archive() and
|
||||
new_http_archive() that offers some additional features. It also helps
|
||||
ensure best practices are followed.
|
||||
|
||||
File arguments are relative to the TensorFlow repository by default. Dependent
|
||||
repositories that use this rule should refer to files either with absolute
|
||||
labels (e.g. '@foo//:bar') or from a label created in their repository (e.g.
|
||||
'str(Label("//:bar"))').
|
||||
|
||||
Args:
|
||||
name: name of the repository
|
||||
sha256: sha256 sum as a string
|
||||
urls: list of mirror URLs
|
||||
**kwargs: additional arguments to _tf_http_archive
|
||||
"""
|
||||
if len(urls) < 2:
|
||||
fail("tf_http_archive(urls) must have redundant URLs.")
|
||||
|
||||
if not any([mirror in urls[0] for mirror in (
|
||||
"mirror.tensorflow.org",
|
||||
"mirror.bazel.build",
|
||||
"storage.googleapis.com",
|
||||
)]):
|
||||
fail("The first entry of tf_http_archive(urls) must be a mirror " +
|
||||
"URL, preferrably mirror.tensorflow.org. Even if you don't have " +
|
||||
"permission to mirror the file, please put the correctly " +
|
||||
"formatted mirror URL there anyway, because someone will come " +
|
||||
"along shortly thereafter and mirror the file.")
|
||||
|
||||
if native.existing_rule(name):
|
||||
print("\n\033[1;33mWarning:\033[0m skipping import of repository '" +
|
||||
name + "' because it already exists.\n") # buildifier: disable=print
|
||||
return
|
||||
|
||||
_tf_http_archive(
|
||||
name = name,
|
||||
sha256 = sha256,
|
||||
urls = urls,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _tf_vendored_impl(repository_ctx):
|
||||
parent_path = repository_ctx.path(repository_ctx.attr.parent).dirname
|
||||
|
||||
# get_child doesn't allow slashes. Yes this is silly. bazel_skylib paths
|
||||
# doesn't work with path objects.
|
||||
relpath_parts = repository_ctx.attr.relpath.split("/")
|
||||
vendored_path = parent_path
|
||||
for part in relpath_parts:
|
||||
vendored_path = vendored_path.get_child(part)
|
||||
repository_ctx.symlink(vendored_path, ".")
|
||||
|
||||
tf_vendored = repository_rule(
|
||||
implementation = _tf_vendored_impl,
|
||||
attrs = {
|
||||
"parent": attr.label(default = "//:WORKSPACE"),
|
||||
"relpath": attr.string(),
|
||||
},
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user