Merge pull request #14475 from hawkinsp:openxla

PiperOrigin-RevId: 516316330
This commit is contained in:
jax authors 2023-03-13 14:04:41 -07:00
commit 42ef649e65
39 changed files with 1288 additions and 144 deletions

View File

@ -1,6 +1,10 @@
############################################################################
# All default build options below.
# Required by OpenXLA
# https://github.com/openxla/xla/issues/1323
build --nocheck_visibility
# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
build --macos_minimum_os=10.14
@ -35,9 +39,9 @@ build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
# Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled,
# these values are overridden.
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false
build --@xla//xla/python:enable_gpu=false
build --@xla//xla/python:enable_tpu=false
build --@xla//xla/python:enable_plugin_device=false
###########################################################################
@ -65,12 +69,12 @@ build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_52,sm_60,sm_70,compute_80"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda
build:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true
build:cuda --@xla//xla/python:enable_gpu=true
build:cuda --define=xla_python_enable_gpu=true
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true
build:rocm --@xla//xla/python:enable_gpu=true
build:rocm --define=xla_python_enable_gpu=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
@ -113,10 +117,10 @@ build:macos --config=posix
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true
build:tpu --@xla//xla/python:enable_tpu=true
build:tpu --define=with_tpu_support=true
build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true
build:plugin_device --@xla//xla/python:enable_plugin_device=true
#########################################################################
# RBE config options below.

View File

@ -1,16 +1,16 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# To update TensorFlow to a new revision,
# To update XLA to a new revision,
# a) update URL and strip_prefix to the new git commit hash
# b) get the sha256 hash of the commit by running:
# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "08fd0ab0b672510229ad2fff276a3634f205fc539fa16a5bdeeaaccd881ece27",
strip_prefix = "tensorflow-2aaeef25361311b21b9e81e992edff94bcb6bae3",
name = "xla",
sha256 = "9f39af4d81d2c8bd52b47f4ef37dfd6642c6950112e4d9d3d95ae19982c46eba",
strip_prefix = "xla-0f31407ee498e6dba242d03f8d382ebcfcc61790",
urls = [
"https://github.com/tensorflow/tensorflow/archive/2aaeef25361311b21b9e81e992edff94bcb6bae3.tar.gz",
"https://github.com/openxla/xla/archive/0f31407ee498e6dba242d03f8d382ebcfcc61790.tar.gz",
],
)
@ -19,26 +19,32 @@ http_archive(
# local checkout by either:
# a) overriding the TF repository on the build.py command line by passing a flag
# like:
# python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
# python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
# or
# b) by commenting out the http_archive above and uncommenting the following:
# local_repository(
# name = "org_tensorflow",
# path = "/path/to/tensorflow",
# name = "xla",
# path = "/path/to/xla",
# )
load("//third_party/ducc:workspace.bzl", ducc = "repo")
ducc()
# Initialize TensorFlow's external dependencies.
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
tf_workspace3()
load("@xla//:workspace4.bzl", "xla_workspace4")
xla_workspace4()
load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
tf_workspace2()
load("@xla//:workspace3.bzl", "xla_workspace3")
xla_workspace3()
load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
tf_workspace1()
load("@xla//:workspace2.bzl", "xla_workspace2")
xla_workspace2()
load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
tf_workspace0()
load("@xla//:workspace1.bzl", "xla_workspace1")
xla_workspace1()
load("@xla//:workspace0.bzl", "xla_workspace0")
xla_workspace0()
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

View File

@ -44,11 +44,11 @@ py_binary(
"//jaxlib:README.md",
"//jaxlib:setup.py",
"//jaxlib:setup.cfg",
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
"@xla//xla/python:xla_client",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + select({
":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"],
":remote_tpu_enabled": ["@xla//xla/python/tpu_driver/client:py_tpu_client"],
"//conditions:default": [],
}) + if_cuda([
"//jaxlib/cuda:cuda_gpu_support",

View File

@ -103,14 +103,14 @@ def patch_copy_xla_extension_stubs(dst_dir):
os.makedirs(xla_extension_dir)
for stub_name in _XLA_EXTENSION_STUBS:
stub_path = r.Rlocation(
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)
"xla/xla/python/xla_extension/" + stub_name)
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
continue
with open(stub_path) as f:
src = f.read()
src = src.replace(
"from tensorflow.compiler.xla.python import xla_extension",
"from xla.python import xla_extension",
"from .. import xla_extension"
)
with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
@ -118,14 +118,14 @@ def patch_copy_xla_extension_stubs(dst_dir):
def patch_copy_tpu_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
with open(r.Rlocation("xla/xla/python/tpu_driver/client/tpu_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
src = src.replace("from xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
src = src.replace("from xla.python import xla_client",
"from . import xla_client")
src = src.replace(
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
"from xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
"from . import tpu_client_extension as _tpu_client")
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
f.write(src)
@ -143,7 +143,7 @@ def verify_mac_libraries_dont_reference_chkstack():
return
nm = subprocess.run(
["nm", "-g",
r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so")
r.Rlocation("xla/xla/python/xla_extension.so")
],
capture_output=True, text=True,
check=False)
@ -250,8 +250,8 @@ def prepare_wheel(sources_path):
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
patch_copy_xla_extension_stubs(jaxlib_dir)
if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"):
copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")
if exists("xla/xla/python/tpu_driver/client/tpu_client_extension.so"):
copy_to_jaxlib("xla/xla/python/tpu_driver/client/tpu_client_extension.so")
patch_copy_tpu_client_py(jaxlib_dir)

View File

@ -66,11 +66,11 @@ specify the paths to CUDA and CUDNN, which you must have installed. Here
may need to use `python3` instead. By default, the wheel is written to the
`dist/` subdirectory of the current directory.
### Building jaxlib from source with a modified TensorFlow repository.
### Building jaxlib from source with a modified XLA repository.
JAX depends on XLA, whose source code is in the
[Tensorflow GitHub repository](https://github.com/tensorflow/tensorflow).
By default JAX uses a pinned copy of the TensorFlow repository, but we often
[XLA GitHub repository](https://github.com/openxla/xla).
By default JAX uses a pinned copy of the XLA repository, but we often
want to use a locally-modified copy of XLA when working on JAX. There are two
ways to do this:
@ -78,12 +78,12 @@ ways to do this:
line flag to `build.py` as follows:
```
python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
```
* modify the `WORKSPACE` file in the root of the JAX source tree to point to
a different TensorFlow tree.
a different XLA tree.
To contribute changes back to XLA, send PRs to the TensorFlow repository.
To contribute changes back to XLA, send PRs to the XLA repository.
The version of XLA pinned by JAX is regularly updated, but is updated in
particular before each `jaxlib` release.
@ -141,7 +141,7 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
```
AMD's fork of the XLA (TensorFlow) repository may include fixes
AMD's fork of the XLA repository may include fixes
not present in the upstream repository. To use AMD's fork, you should clone
their repository:
```
@ -152,7 +152,7 @@ To build jaxlib with ROCM support, you can run the following build command,
suitably adjusted for your paths and ROCM version.
```
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
--bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow-upstream
--bazel_options=--override_repository=xla=/path/to/xla-upstream
```
## Installing `jax`

View File

@ -121,7 +121,7 @@ no released `jax` version uses that API.
`jaxlib` is split across two main repositories, namely the
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
and in the
[XLA source tree, which lives inside the TensorFlow repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla).
[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla).
The JAX-specific pieces inside XLA are primarily in the
[`xla/python` subdirectory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/python).
@ -164,7 +164,7 @@ compatibility, we have additional versioning that is independent of the `jaxlib`
release version numbers.
We maintain an additional version number (`_version`) in
[`xla_client.py` in the XLA repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py).
[`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py).
The idea is that this version number, is defined in `xla/python`
together with the C++ parts of JAX, is also accessible to JAX Python as
`jax._src.lib.xla_extension_version`, and must

View File

@ -12,28 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_cc_binary",
)
licenses(["notice"])
tf_cc_binary(
cc_binary(
name = "main",
srcs = ["main.cc"],
tags = ["manual"],
deps = [
"@org_tensorflow//tensorflow/compiler/xla:literal",
"@org_tensorflow//tensorflow/compiler/xla:literal_util",
"@org_tensorflow//tensorflow/compiler/xla:shape_util",
"@org_tensorflow//tensorflow/compiler/xla:status",
"@org_tensorflow//tensorflow/compiler/xla:statusor",
"@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
"@org_tensorflow//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
"@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
"@org_tensorflow//tensorflow/core/platform:logging",
"@org_tensorflow//tensorflow/core/platform:platform_port",
"@xla//:literal",
"@xla//:literal_util",
"@xla//:shape_util",
"@xla//:status",
"@xla//:statusor",
"@xla///pjrt:pjrt_client",
"@xla///pjrt:tfrt_cpu_pjrt_client",
"@xla///service:hlo_proto_cc",
"@xla///tools:hlo_module_loader",
"@tsl///platform:logging",
"@tsl///platform:platform_port",
],
)

View File

@ -40,18 +40,18 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/tools/hlo_module_loader.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
int main(int argc, char** argv) {
tensorflow::port::InitMain("", &argc, &argv);
tsl::port::InitMain("", &argc, &argv);
// Load HloModule from file.
std::string hlo_filename = "/tmp/fn_hlo.txt";

View File

@ -68,7 +68,7 @@ symlink_files(
symlink_files(
name = "xla_client",
srcs = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
srcs = ["@xla///python:xla_client"],
dst = ".",
flatten = True,
)
@ -76,8 +76,8 @@ symlink_files(
symlink_files(
name = "xla_extension",
srcs = if_windows(
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.pyd"],
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.so"],
["@xla///python:xla_extension.pyd"],
["@xla///python:xla_extension.so"],
),
dst = ".",
flatten = True,
@ -140,7 +140,7 @@ pybind_extension(
srcs = ["cpu_feature_guard.c"],
module_name = "cpu_feature_guard",
deps = [
"@org_tensorflow//third_party/python_runtime:headers",
"@xla//third_party/python_runtime:headers",
],
)

View File

@ -31,7 +31,7 @@ cc_library(
srcs = ["lapack_kernels.cc"],
hdrs = ["lapack_kernels.h"],
deps = [
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla///service:custom_call_status",
"@com_google_absl//absl/base:dynamic_annotations",
],
)
@ -74,7 +74,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":ducc_fft_flatbuffers_cc",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla///service:custom_call_status",
"@ducc",
"@flatbuffers//:runtime_cc",
],
@ -106,7 +106,7 @@ cc_library(
":ducc_fft_kernels",
":lapack_kernels",
":lapack_kernels_using_lapack",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
"@xla///service:custom_call_target_registry",
],
alwayslink = 1,
)

View File

@ -16,9 +16,9 @@ limitations under the License.
// This file is not used by JAX itself, but exists to assist with running
// JAX-generated HLO code from outside of JAX.
#include "jaxlib/cpu/lapack_kernels.h"
#include "jaxlib/cpu/ducc_fft_kernels.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "jaxlib/cpu/lapack_kernels.h"
#include "xla/service/custom_call_target_registry.h"
namespace jax {
namespace {

View File

@ -15,10 +15,10 @@ limitations under the License.
#include <complex>
#include "ducc/src/ducc0/fft/fft.h"
#include "flatbuffers/flatbuffers.h"
#include "jaxlib/cpu/ducc_fft_generated.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "ducc/src/ducc0/fft/fft.h"
#include "xla/service/custom_call_status.h"
namespace jax {

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef JAXLIB_CPU_DUCC_FFT_KERNELS_H_
#define JAXLIB_CPU_DUCC_FFT_KERNELS_H_
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <complex>
#include <cstdint>
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
// by the pybind wrapper that links them to an existing SciPy lapack instance,

View File

@ -51,8 +51,8 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":cuda_vendor",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
"@xla///stream_executor/cuda:cusolver_lib",
"@xla///stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -72,9 +72,9 @@ cc_library(
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@xla///service:custom_call_status",
"@xla///stream_executor/cuda:cublas_lib",
"@xla///stream_executor/cuda:cudart_stub",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
@ -102,7 +102,7 @@ pybind_extension(
":cublas_kernels",
":cuda_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
"@xla///stream_executor/cuda:cublas_lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@pybind11",
@ -118,9 +118,9 @@ cc_library(
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudnn_lib",
"@xla///service:custom_call_status",
"@xla///stream_executor/cuda:cudart_stub",
"@xla///stream_executor/cuda:cudnn_lib",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
@ -157,8 +157,8 @@ cc_library(
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
"@xla///service:custom_call_status",
"@xla///stream_executor/cuda:cusolver_lib",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
@ -180,8 +180,8 @@ pybind_extension(
":cuda_vendor",
":cusolver_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
"@xla///stream_executor/cuda:cudart_stub",
"@xla///stream_executor/cuda:cusolver_lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
@ -198,9 +198,9 @@ cc_library(
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
"@xla///service:custom_call_status",
"@xla///stream_executor/cuda:cudart_stub",
"@xla///stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
@ -222,8 +222,8 @@ pybind_extension(
":cuda_vendor",
":cusparse_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
"@xla///stream_executor/cuda:cudart_stub",
"@xla///stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
@ -249,7 +249,7 @@ cc_library(
":cuda_lu_pivot_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla///service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -264,7 +264,7 @@ cuda_library(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla///service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -284,7 +284,7 @@ pybind_extension(
":cuda_lu_pivot_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@xla///stream_executor/cuda:cudart_stub",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
@ -301,7 +301,7 @@ cc_library(
":cuda_prng_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla///service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -316,7 +316,7 @@ cuda_library(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla///service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -334,7 +334,7 @@ pybind_extension(
":cuda_gpu_kernel_helpers",
":cuda_prng_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@xla///stream_executor/cuda:cudart_stub",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
@ -351,7 +351,7 @@ cc_library(
":cuda_vendor",
":cusolver_kernels",
":cusparse_kernels",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
"@xla///service:custom_call_target_registry",
],
alwayslink = 1,
)

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <cstddef>
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "jaxlib/gpu/solver_kernels.h"
#include "jaxlib/gpu/sparse_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "xla/service/custom_call_target_registry.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <string>
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <string>
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/handle_pool.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/handle_pool.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "xla/service/custom_call_status.h"
namespace jax {

View File

@ -14,12 +14,11 @@
"""Bazel macros used by the JAX build."""
load("@org_tensorflow//tensorflow/tsl/platform/default:build_config.bzl", _pyx_library = "pyx_library")
load("@org_tensorflow//tensorflow:tensorflow.bzl", _if_windows = "if_windows", _pybind_extension = "pybind_extension")
load("@tsl//tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@org_tensorflow//tensorflow/core/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
# lint tools.
@ -27,7 +26,6 @@ cuda_library = _cuda_library
rocm_library = _rocm_library
pytype_strict_library = native.py_library
pytype_test = native.py_test
pyx_library = _pyx_library
pybind_extension = _pybind_extension
if_cuda_is_configured = _if_cuda_is_configured
if_rocm_is_configured = _if_rocm_is_configured
@ -66,8 +64,8 @@ def pytype_library(name, pytype_srcs = None, **kwargs):
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
lib_rule(name = name, **kwargs)
def py_extension(name, srcs, copts, deps):
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)
def py_extension(name, srcs, copts, deps, linkopts = []):
pybind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []):
"""Workaround DLL building issue.

View File

@ -120,7 +120,7 @@ symlink_inputs(
name = "mhlo_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:MhloOpsPyFiles",
"@xla//xla/mlir_hlo:MhloOpsPyFiles",
]}},
deps = [
":core",

View File

@ -31,12 +31,25 @@ COPTS = [
"-frtti",
]
LINKOPTS = select({
"@tsl//tsl:macos": [
"-Wl,-rpath,@loader_path/",
"-Wl,-rename_section,__TEXT,text_env,__TEXT,__text",
],
"@tsl//tsl:windows": [],
"//conditions:default": [
"-Wl,-rpath,$$ORIGIN/",
],
})
py_extension(
name = "_mlir",
srcs = [
"@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI",
@ -50,6 +63,7 @@ py_extension(
"@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
@ -64,6 +78,7 @@ py_extension(
"@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
@ -90,6 +105,7 @@ py_extension(
name = "_site_initialize_0",
srcs = ["_site_initialize_0.cc"],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
@ -106,15 +122,16 @@ py_extension(
py_extension(
name = "_mlirHlo",
srcs = [
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:bindings/python/MlirHloModule.cc",
"@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@local_config_python//:headers",
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:CAPIHeaders",
"@xla//xla/mlir_hlo:CAPIHeaders",
"@pybind11",
],
)
@ -129,6 +146,7 @@ py_extension(
"@stablehlo//:stablehlo/integrations/python/ChloModule.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
@ -145,6 +163,7 @@ py_extension(
"@stablehlo//:stablehlo/integrations/python/StablehloModule.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
@ -158,12 +177,12 @@ py_extension(
cc_library(
name = "jaxlib_mlir_capi_shared_library",
srcs = select({
"@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi.dll"],
"@org_tensorflow//tensorflow:macos": [":libjaxlib_mlir_capi.dylib"],
"@tsl//tsl:windows": [":jaxlib_mlir_capi.dll"],
"@tsl//tsl:macos": [":libjaxlib_mlir_capi.dylib"],
"//conditions:default": [":libjaxlib_mlir_capi.so"],
}),
deps = select({
"@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi_dll"],
"@tsl//tsl:windows": [":jaxlib_mlir_capi_dll"],
"//conditions:default": [],
}),
)
@ -174,7 +193,7 @@ cc_library(
"@llvm-project//mlir:CAPISparseTensorObjects",
"@llvm-project//mlir:CAPITransformsObjects",
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:CAPIObjects",
"@xla//xla/mlir_hlo:CAPIObjects",
"@stablehlo//:chlo_capi_objects",
"@stablehlo//:stablehlo_capi_objects",
],

View File

@ -75,7 +75,7 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
],
)
@ -114,7 +114,7 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipsolver",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
],
)
@ -154,7 +154,7 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipsparse",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
],
)
@ -197,7 +197,7 @@ cc_library(
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
],
)
@ -210,7 +210,7 @@ rocm_library(
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
],
)
@ -244,7 +244,7 @@ cc_library(
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
],
)
@ -257,7 +257,7 @@ rocm_library(
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
],
)

15
third_party/BUILD.bazel vendored Normal file
View File

@ -0,0 +1,15 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0

15
third_party/flatbuffers/BUILD.bazel vendored Normal file
View File

@ -0,0 +1,15 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This empty BUILD file is required to make Bazel treat this directory as a package.

43
third_party/flatbuffers/BUILD.system vendored Normal file
View File

@ -0,0 +1,43 @@
licenses(["notice"]) # Apache 2.0
filegroup(
name = "LICENSE.txt",
visibility = ["//visibility:public"],
)
# Public flatc library to compile flatbuffer files at runtime.
cc_library(
name = "flatbuffers",
linkopts = ["-lflatbuffers"],
visibility = ["//visibility:public"],
)
# Public flatc compiler library.
cc_library(
name = "flatc_library",
linkopts = ["-lflatbuffers"],
visibility = ["//visibility:public"],
)
genrule(
name = "lnflatc",
outs = ["flatc.bin"],
cmd = "ln -s $$(which flatc) $@",
)
# Public flatc compiler.
sh_binary(
name = "flatc",
srcs = ["flatc.bin"],
visibility = ["//visibility:public"],
)
cc_library(
name = "runtime_cc",
visibility = ["//visibility:public"],
)
py_library(
name = "runtime_py",
visibility = ["//visibility:public"],
)

661
third_party/flatbuffers/build_defs.bzl vendored Normal file
View File

@ -0,0 +1,661 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BUILD rules for generating flatbuffer files."""
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
flatc_path = "@flatbuffers//:flatc"
zip_files = "//tensorflow/lite/tools:zip_files"
DEFAULT_INCLUDE_PATHS = [
"./",
"$(GENDIR)",
"$(BINDIR)",
]
DEFAULT_FLATC_ARGS = [
"--no-union-value-namespacing",
"--gen-object-api",
]
def flatbuffer_library_public(
name,
srcs,
outs,
language_flag,
out_prefix = "",
includes = [],
include_paths = [],
compatible_with = [],
flatc_args = DEFAULT_FLATC_ARGS,
reflection_name = "",
reflection_visibility = None, # buildifier: disable=unused-variable
output_to_bindir = False):
"""Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler.
Outs:
filegroup(name): all generated source files.
Fileset([reflection_name]): (Optional) all generated reflection binaries.
Args:
name: Rule name.
srcs: Source .fbs files. Sent in order to the compiler.
outs: Output files from flatc.
language_flag: Target language flag. One of [-c, -j, -js].
out_prefix: Prepend this path to the front of all generated files except on
single source targets. Usually is a directory name.
includes: Optional, list of filegroups of schemas that the srcs depend on.
include_paths: Optional, list of paths the includes files can be found in.
compatible_with: Optional, passed to genrule for environments this rule
can be built for.
flatc_args: Optional, list of additional arguments to pass to flatc.
reflection_name: Optional, if set this will generate the flatbuffer
reflection binaries for the schemas.
reflection_visibility: The visibility of the generated reflection Fileset.
output_to_bindir: Passed to genrule for output to bin directory.
"""
include_paths_cmd = ["-I %s" % (s) for s in include_paths]
# '$(@D)' when given a single source target will give the appropriate
# directory. Appending 'out_prefix' is only necessary when given a build
# target with multiple sources.
output_directory = (
("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)")
)
genrule_cmd = " ".join([
"for f in $(SRCS); do",
"$(location %s)" % (flatc_path),
" ".join(flatc_args),
" ".join(include_paths_cmd),
language_flag,
output_directory,
"$$f;",
"done",
])
native.genrule(
name = name,
srcs = srcs,
outs = outs,
output_to_bindir = output_to_bindir,
compatible_with = compatible_with,
tools = includes + [flatc_path],
cmd = genrule_cmd,
message = "Generating flatbuffer files for %s:" % (name),
)
if reflection_name:
reflection_genrule_cmd = " ".join([
"for f in $(SRCS); do",
"$(location %s)" % (flatc_path),
"-b --schema",
" ".join(flatc_args),
" ".join(include_paths_cmd),
language_flag,
output_directory,
"$$f;",
"done",
])
reflection_outs = [
(out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1])
for s in srcs
]
native.genrule(
name = "%s_srcs" % reflection_name,
srcs = srcs,
outs = reflection_outs,
output_to_bindir = output_to_bindir,
compatible_with = compatible_with,
tools = includes + [flatc_path],
cmd = reflection_genrule_cmd,
message = "Generating flatbuffer reflection binary for %s:" % (name),
)
# TODO(b/114456773): Make bazel rules proper and supported by flatbuffer
# Have to comment this since FilesetEntry is not supported in bazel
# starlark.
# native.Fileset(
# name = reflection_name,
# out = "%s_out" % reflection_name,
# entries = [
# native.FilesetEntry(files = reflection_outs),
# ],
# visibility = reflection_visibility,
# compatible_with = compatible_with,
# )
def flatbuffer_cc_library(
name,
srcs,
srcs_filegroup_name = "",
out_prefix = "",
includes = [],
include_paths = [],
compatible_with = [],
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None,
srcs_filegroup_visibility = None,
gen_reflections = False):
'''A cc_library with the generated reader/writers for the given flatbuffer definitions.
Outs:
filegroup([name]_srcs): all generated .h files.
filegroup(srcs_filegroup_name if specified, or [name]_includes if not):
Other flatbuffer_cc_library's can pass this in for their `includes`
parameter, if they depend on the schemas in this library.
Fileset([name]_reflection): (Optional) all generated reflection binaries.
cc_library([name]): library with sources and flatbuffers deps.
Remarks:
** Because the genrule used to call flatc does not have any trivial way of
computing the output list of files transitively generated by includes and
--gen-includes (the default) being defined for flatc, the --gen-includes
flag will not work as expected. The way around this is to add a dependency
to the flatbuffer_cc_library defined alongside the flatc included Fileset.
For example you might define:
flatbuffer_cc_library(
name = "my_fbs",
srcs = [ "schemas/foo.fbs" ],
includes = [ "//third_party/bazz:bazz_fbs_includes" ],
)
In which foo.fbs includes a few files from the Fileset defined at
//third_party/bazz:bazz_fbs_includes. When compiling the library that
includes foo_generated.h, and therefore has my_fbs as a dependency, it
will fail to find any of the bazz *_generated.h files unless you also
add bazz's flatbuffer_cc_library to your own dependency list, e.g.:
cc_library(
name = "my_lib",
deps = [
":my_fbs",
"//third_party/bazz:bazz_fbs"
],
)
Happy dependent Flatbuffering!
Args:
name: Rule name.
srcs: Source .fbs files. Sent in order to the compiler.
srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this
filegroup into the `includes` parameter of any other
flatbuffer_cc_library that depends on this one's schemas.
out_prefix: Prepend this path to the front of all generated files. Usually
is a directory name.
includes: Optional, list of filegroups of schemas that the srcs depend on.
** SEE REMARKS BELOW **
include_paths: Optional, list of paths the includes files can be found in.
compatible_with: Optional, passed to genrule for environments this rule
can be built for
flatc_args: Optional list of additional arguments to pass to flatc
(e.g. --gen-mutable).
visibility: The visibility of the generated cc_library. By default, use the
default visibility of the project.
srcs_filegroup_visibility: The visibility of the generated srcs filegroup.
By default, use the value of the visibility parameter above.
gen_reflections: Optional, if true this will generate the flatbuffer
reflection binaries for the schemas.
'''
output_headers = [
(out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1])
for s in srcs
]
reflection_name = "%s_reflection" % name if gen_reflections else ""
flatbuffer_library_public(
name = "%s_srcs" % (name),
srcs = srcs,
outs = output_headers,
language_flag = "-c",
out_prefix = out_prefix,
includes = includes,
include_paths = include_paths,
compatible_with = compatible_with,
flatc_args = flatc_args,
reflection_name = reflection_name,
reflection_visibility = visibility,
)
native.cc_library(
name = name,
hdrs = output_headers,
srcs = output_headers,
features = [
"-parse_headers",
],
deps = [
"@flatbuffers//:runtime_cc",
],
includes = ["."],
linkstatic = 1,
visibility = visibility,
compatible_with = compatible_with,
)
# A filegroup for the `srcs`. That is, all the schema files for this
# Flatbuffer set.
native.filegroup(
name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
srcs = srcs,
visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility,
compatible_with = compatible_with,
)
# Custom provider to track dependencies transitively.
FlatbufferInfo = provider(
fields = {
"transitive_srcs": "flatbuffer schema definitions.",
},
) # buildifier: disable=provider-params
def _flatbuffer_schemas_aspect_impl(target, ctx):
_ignore = [target] # @unused
transitive_srcs = depset()
if hasattr(ctx.rule.attr, "deps"):
for dep in ctx.rule.attr.deps:
if FlatbufferInfo in dep:
transitive_srcs = depset(dep[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) # buildifier: disable=overly-nested-depset
if hasattr(ctx.rule.attr, "srcs"):
for src in ctx.rule.attr.srcs:
if FlatbufferInfo in src:
transitive_srcs = depset(src[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) # buildifier: disable=overly-nested-depset
for f in src.files:
if f.extension == "fbs":
transitive_srcs = depset([f], transitive = [transitive_srcs]) # buildifier: disable=overly-nested-depset
return [FlatbufferInfo(transitive_srcs = transitive_srcs)]
# An aspect that runs over all dependencies and transitively collects
# flatbuffer schema files.
_flatbuffer_schemas_aspect = aspect(
attr_aspects = [
"deps",
"srcs",
],
implementation = _flatbuffer_schemas_aspect_impl,
)
# Rule to invoke the flatbuffer compiler.
def _gen_flatbuffer_srcs_impl(ctx):
outputs = ctx.attr.outputs
include_paths = ctx.attr.include_paths
if ctx.attr.no_includes:
no_includes_statement = ["--no-includes"]
else:
no_includes_statement = []
if ctx.attr.language_flag == "--python":
onefile_statement = ["--gen-onefile"]
else:
onefile_statement = []
# Need to generate all files in a directory.
if not outputs:
outputs = [ctx.actions.declare_directory("{}_all".format(ctx.attr.name))]
output_directory = outputs[0].path
else:
outputs = [ctx.actions.declare_file(output) for output in outputs]
output_directory = outputs[0].dirname
deps = depset(ctx.files.srcs + ctx.files.deps, transitive = [
dep[FlatbufferInfo].transitive_srcs
for dep in ctx.attr.deps
if FlatbufferInfo in dep
])
include_paths_cmd_line = []
for s in include_paths:
include_paths_cmd_line.extend(["-I", s])
for src in ctx.files.srcs:
ctx.actions.run(
inputs = deps,
outputs = outputs,
executable = ctx.executable._flatc,
arguments = [
ctx.attr.language_flag,
"-o",
output_directory,
# Allow for absolute imports and referencing of generated files.
"-I",
"./",
"-I",
ctx.genfiles_dir.path,
"-I",
ctx.bin_dir.path,
] + no_includes_statement +
onefile_statement +
include_paths_cmd_line + [
"--no-union-value-namespacing",
"--gen-object-api",
src.path,
],
progress_message = "Generating flatbuffer files for {}:".format(src),
use_default_shell_env = True,
)
return [
DefaultInfo(files = depset(outputs)),
]
_gen_flatbuffer_srcs = rule(
_gen_flatbuffer_srcs_impl,
attrs = {
"srcs": attr.label_list(
allow_files = [".fbs"],
mandatory = True,
),
"outputs": attr.string_list(
default = [],
mandatory = False,
),
"deps": attr.label_list(
default = [],
mandatory = False,
aspects = [_flatbuffer_schemas_aspect],
),
"include_paths": attr.string_list(
default = [],
mandatory = False,
),
"language_flag": attr.string(
mandatory = True,
),
"no_includes": attr.bool(
default = False,
mandatory = False,
),
"_flatc": attr.label(
default = Label("@flatbuffers//:flatc"),
executable = True,
cfg = "exec",
),
},
output_to_genfiles = True,
)
def flatbuffer_py_strip_prefix_srcs(name, srcs = [], strip_prefix = ""):
"""Strips path prefix.
Args:
name: Rule name. (required)
srcs: Source .py files. (required)
strip_prefix: Path that needs to be stripped from the srcs filepaths. (required)
"""
for src in srcs:
native.genrule(
name = name + "_" + src.replace(".", "_").replace("/", "_"),
srcs = [src],
outs = [src.replace(strip_prefix, "")],
cmd = "cp $< $@",
)
def _concat_flatbuffer_py_srcs_impl(ctx):
# Merge all generated python files. The files are concatenated and import
# statements are removed. Finally we import the flatbuffer runtime library.
# IMPORTANT: Our Windows shell does not support "find ... -exec" properly.
# If you're changing the commandline below, please build wheels and run smoke
# tests on all the three operating systems.
command = "echo 'import flatbuffers\n' > %s; "
command += "for f in $(find %s -name '*.py' | sort); do cat $f | sed '/import flatbuffers/d' >> %s; done "
ctx.actions.run_shell(
inputs = ctx.attr.deps[0].files,
outputs = [ctx.outputs.out],
command = command % (
ctx.outputs.out.path,
ctx.attr.deps[0].files.to_list()[0].path,
ctx.outputs.out.path,
),
use_default_shell_env = True,
)
_concat_flatbuffer_py_srcs = rule(
_concat_flatbuffer_py_srcs_impl,
attrs = {
"deps": attr.label_list(mandatory = True),
},
output_to_genfiles = True,
outputs = {"out": "%{name}.py"},
)
def flatbuffer_py_library(
name,
srcs,
deps = [],
include_paths = []):
"""A py_library with the generated reader/writers for the given schema.
This rule assumes that the schema files define non-conflicting names, so that
they can be merged in a single file. This is e.g. the case if only a single
namespace is used.
The rule call the flatbuffer compiler for all schema files and merges the
generated python files into a single file that is wrapped in a py_library.
Args:
name: Rule name. (required)
srcs: List of source .fbs files. (required)
deps: List of dependencies.
include_paths: Optional, list of paths the includes files can be found in.
"""
all_srcs = "{}_srcs".format(name)
_gen_flatbuffer_srcs(
name = all_srcs,
srcs = srcs,
language_flag = "--python",
deps = deps,
include_paths = include_paths,
)
# TODO(b/235550563): Remove the concatnation rule with 2.0.6 update.
all_srcs_no_include = "{}_srcs_no_include".format(name)
_gen_flatbuffer_srcs(
name = all_srcs_no_include,
srcs = srcs,
language_flag = "--python",
deps = deps,
no_includes = True,
include_paths = include_paths,
)
concat_py_srcs = "{}_generated".format(name)
_concat_flatbuffer_py_srcs(
name = concat_py_srcs,
deps = [
":{}".format(all_srcs_no_include),
],
)
native.py_library(
name = name,
srcs = [
":{}".format(concat_py_srcs),
],
srcs_version = "PY3",
deps = deps + [
"@flatbuffers//:runtime_py",
],
)
def flatbuffer_java_library(
name,
srcs,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None):
"""A java library with the generated reader/writers for the given flatbuffer definitions.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
custom_package: Package name of generated Java files. If not specified
namespace in the schema files will be used. (optional)
package_prefix: like custom_package, but prefixes to the existing
namespace. (optional)
include_paths: List of paths that includes files can be found in. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
visibility: Visibility setting for the java_library rule. (optional)
"""
out_srcjar = "java_%s_all.srcjar" % name
flatbuffer_java_srcjar(
name = "%s_srcjar" % name,
srcs = srcs,
out = out_srcjar,
custom_package = custom_package,
flatc_args = flatc_args,
include_paths = include_paths,
package_prefix = package_prefix,
)
native.filegroup(
name = "%s.srcjar" % name,
srcs = [out_srcjar],
)
native.java_library(
name = name,
srcs = [out_srcjar],
javacopts = ["-source 7 -target 7"],
deps = [
"@flatbuffers//:runtime_java",
],
visibility = visibility,
)
def flatbuffer_java_srcjar(
name,
srcs,
out,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS):
"""Generate flatbuffer Java source files.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
out: Output file name. (required)
custom_package: Package name of generated Java files. If not specified
namespace in the schema files will be used. (optional)
package_prefix: like custom_package, but prefixes to the existing
namespace. (optional)
include_paths: List of paths that includes files can be found in. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
"""
command_fmt = """set -e
tmpdir=$(@D)
schemas=$$tmpdir/schemas
java_root=$$tmpdir/java
rm -rf $$schemas
rm -rf $$java_root
mkdir -p $$schemas
mkdir -p $$java_root
for src in $(SRCS); do
dest=$$schemas/$$src
rm -rf $$(dirname $$dest)
mkdir -p $$(dirname $$dest)
if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then
cp -f $$src $$dest
else
if [ -z "{package_prefix}" ]; then
sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest
else
sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest
fi
fi
done
flatc_arg_I="-I $$tmpdir/schemas"
for include_path in {include_paths}; do
flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path"
done
flatc_additional_args=
for arg in {flatc_args}; do
flatc_additional_args="$$flatc_additional_args $$arg"
done
for src in $(SRCS); do
$(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src
done
$(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root
"""
genrule_cmd = command_fmt.format(
package_name = native.package_name(),
custom_package = custom_package,
package_prefix = package_prefix,
flatc_path = flatc_path,
zip_files = zip_files,
include_paths = " ".join(include_paths),
flatc_args = " ".join(flatc_args),
)
native.genrule(
name = name,
srcs = srcs,
outs = [out],
tools = [flatc_path, zip_files],
cmd = genrule_cmd,
)
def flatbuffer_android_library(
name,
srcs,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None):
"""An android_library with the generated reader/writers for the given flatbuffer definitions.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
custom_package: Package name of generated Java files. If not specified
namespace in the schema files will be used. (optional)
package_prefix: like custom_package, but prefixes to the existing
namespace. (optional)
include_paths: List of paths that includes files can be found in. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
visibility: Visibility setting for the android_library rule. (optional)
"""
out_srcjar = "android_%s_all.srcjar" % name
flatbuffer_java_srcjar(
name = "%s_srcjar" % name,
srcs = srcs,
out = out_srcjar,
custom_package = custom_package,
flatc_args = flatc_args,
include_paths = include_paths,
package_prefix = package_prefix,
)
native.filegroup(
name = "%s.srcjar" % name,
srcs = [out_srcjar],
)
# To support org.checkerframework.dataflow.qual.Pure.
checkerframework_annotations = [
"@org_checkerframework_qual",
] if "--java-checkerframework" in flatc_args else []
android_library(
name = name,
srcs = [out_srcjar],
javacopts = ["-source 7 -target 7"],
visibility = visibility,
deps = [
"@flatbuffers//:runtime_android",
] + checkerframework_annotations,
)

View File

@ -0,0 +1,192 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load(":build_defs.bzl", "flatbuffer_py_strip_prefix_srcs")
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE.txt"])
licenses(["notice"])
config_setting(
name = "platform_freebsd",
values = {"cpu": "freebsd"},
)
config_setting(
name = "platform_openbsd",
values = {"cpu": "openbsd"},
)
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
)
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
# Public flatc library to compile flatbuffer files at runtime.
cc_library(
name = "flatbuffers",
hdrs = ["//:public_headers"],
linkstatic = 1,
strip_include_prefix = "/include",
visibility = ["//visibility:public"],
deps = ["//src:flatbuffers"],
)
# Public C++ headers for the Flatbuffers library.
filegroup(
name = "public_headers",
srcs = [
"include/flatbuffers/allocator.h",
"include/flatbuffers/array.h",
"include/flatbuffers/base.h",
"include/flatbuffers/bfbs_generator.h",
"include/flatbuffers/buffer.h",
"include/flatbuffers/buffer_ref.h",
"include/flatbuffers/code_generators.h",
"include/flatbuffers/default_allocator.h",
"include/flatbuffers/detached_buffer.h",
"include/flatbuffers/flatbuffer_builder.h",
"include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flexbuffers.h",
"include/flatbuffers/hash.h",
"include/flatbuffers/idl.h",
"include/flatbuffers/minireflect.h",
"include/flatbuffers/reflection.h",
"include/flatbuffers/reflection_generated.h",
"include/flatbuffers/registry.h",
"include/flatbuffers/stl_emulation.h",
"include/flatbuffers/string.h",
"include/flatbuffers/struct.h",
"include/flatbuffers/table.h",
"include/flatbuffers/util.h",
"include/flatbuffers/vector.h",
"include/flatbuffers/vector_downward.h",
"include/flatbuffers/verifier.h",
],
visibility = ["//:__subpackages__"],
)
# Public flatc compiler library.
cc_library(
name = "flatc_library",
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"@flatbuffers//src:flatc_library",
],
)
# Public flatc compiler.
cc_binary(
name = "flatc",
linkopts = select({
":platform_freebsd": [
"-lm",
],
":windows": [],
"//conditions:default": [
"-lm",
"-ldl",
],
}),
visibility = ["//visibility:public"],
deps = [
"@flatbuffers//src:flatc",
],
)
filegroup(
name = "flatc_headers",
srcs = [
"include/flatbuffers/flatc.h",
],
visibility = ["//:__subpackages__"],
)
# Library used by flatbuffer_cc_library rules.
cc_library(
name = "runtime_cc",
hdrs = [
"include/flatbuffers/allocator.h",
"include/flatbuffers/array.h",
"include/flatbuffers/base.h",
"include/flatbuffers/buffer.h",
"include/flatbuffers/buffer_ref.h",
"include/flatbuffers/default_allocator.h",
"include/flatbuffers/detached_buffer.h",
"include/flatbuffers/flatbuffer_builder.h",
"include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flexbuffers.h",
"include/flatbuffers/stl_emulation.h",
"include/flatbuffers/string.h",
"include/flatbuffers/struct.h",
"include/flatbuffers/table.h",
"include/flatbuffers/util.h",
"include/flatbuffers/vector.h",
"include/flatbuffers/vector_downward.h",
"include/flatbuffers/verifier.h",
],
linkstatic = 1,
strip_include_prefix = "/include",
visibility = ["//visibility:public"],
)
flatbuffer_py_strip_prefix_srcs(
name = "flatbuffer_py_strip_prefix",
srcs = [
"python/flatbuffers/__init__.py",
"python/flatbuffers/_version.py",
"python/flatbuffers/builder.py",
"python/flatbuffers/compat.py",
"python/flatbuffers/encode.py",
"python/flatbuffers/flexbuffers.py",
"python/flatbuffers/number_types.py",
"python/flatbuffers/packer.py",
"python/flatbuffers/table.py",
"python/flatbuffers/util.py",
],
strip_prefix = "python/flatbuffers/",
)
filegroup(
name = "runtime_py_srcs",
srcs = [
"__init__.py",
"_version.py",
"builder.py",
"compat.py",
"encode.py",
"flexbuffers.py",
"number_types.py",
"packer.py",
"table.py",
"util.py",
],
)
py_library(
name = "runtime_py",
srcs = [":runtime_py_srcs"],
visibility = ["//visibility:public"],
)
filegroup(
name = "runtime_java_srcs",
srcs = glob(["java/com/google/flatbuffers/**/*.java"]),
)
java_library(
name = "runtime_java",
srcs = [":runtime_java_srcs"],
visibility = ["//visibility:public"],
)
android_library(
name = "runtime_android",
srcs = [":runtime_java_srcs"],
visibility = ["//visibility:public"],
)

30
third_party/flatbuffers/workspace.bzl vendored Normal file
View File

@ -0,0 +1,30 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads the Flatbuffers library."""
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
def repo():
tf_http_archive(
name = "flatbuffers",
strip_prefix = "flatbuffers-2.0.6",
sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9",
urls = tf_mirror_urls("https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz"),
build_file = "//third_party/flatbuffers:flatbuffers.BUILD",
system_build_file = "//third_party/flatbuffers:BUILD.system",
link_files = {
"//third_party/flatbuffers:build_defs.bzl": "build_defs.bzl",
},
)

166
third_party/repo.bzl vendored Normal file
View File

@ -0,0 +1,166 @@
# Copyright 2017 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for defining TensorFlow Bazel dependencies."""
def tf_mirror_urls(url):
"""A helper for generating TF-mirror versions of URLs.
Given a URL, it returns a list of the TF-mirror cache version of that URL
and the original URL, suitable for use in `urls` field of `tf_http_archive`.
"""
if not url.startswith("https://"):
return [url]
return [
"https://storage.googleapis.com/mirror.tensorflow.org/%s" % url[8:],
url,
]
def _get_env_var(ctx, name):
if name in ctx.os.environ:
return ctx.os.environ[name]
else:
return None
# Checks if we should use the system lib instead of the bundled one
def _use_system_lib(ctx, name):
syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS")
if not syslibenv:
return False
return name in [n.strip() for n in syslibenv.split(",")]
def _get_link_dict(ctx, link_files, build_file):
link_dict = {ctx.path(v): ctx.path(Label(k)) for k, v in link_files.items()}
if build_file:
# Use BUILD.bazel because it takes precedence over BUILD.
link_dict[ctx.path("BUILD.bazel")] = ctx.path(Label(build_file))
return link_dict
def _tf_http_archive_impl(ctx):
# Construct all paths early on to prevent rule restart. We want the
# attributes to be strings instead of labels because they refer to files
# in the TensorFlow repository, not files in repos depending on TensorFlow.
# See also https://github.com/bazelbuild/bazel/issues/10515.
link_dict = _get_link_dict(ctx, ctx.attr.link_files, ctx.attr.build_file)
# For some reason, we need to "resolve" labels once before the
# download_and_extract otherwise it'll invalidate and re-download the
# archive each time.
# https://github.com/bazelbuild/bazel/issues/10515
patch_files = ctx.attr.patch_file
for patch_file in patch_files:
if patch_file:
ctx.path(Label(patch_file))
if _use_system_lib(ctx, ctx.attr.name):
link_dict.update(_get_link_dict(
ctx = ctx,
link_files = ctx.attr.system_link_files,
build_file = ctx.attr.system_build_file,
))
else:
ctx.download_and_extract(
url = ctx.attr.urls,
sha256 = ctx.attr.sha256,
type = ctx.attr.type,
stripPrefix = ctx.attr.strip_prefix,
)
if patch_files:
for patch_file in patch_files:
patch_file = ctx.path(Label(patch_file)) if patch_file else None
if patch_file:
ctx.patch(patch_file, strip = 1)
for dst, src in link_dict.items():
ctx.delete(dst)
ctx.symlink(src, dst)
_tf_http_archive = repository_rule(
implementation = _tf_http_archive_impl,
attrs = {
"sha256": attr.string(mandatory = True),
"urls": attr.string_list(mandatory = True),
"strip_prefix": attr.string(),
"type": attr.string(),
"patch_file": attr.string_list(),
"build_file": attr.string(),
"system_build_file": attr.string(),
"link_files": attr.string_dict(),
"system_link_files": attr.string_dict(),
},
environ = ["TF_SYSTEM_LIBS"],
)
def tf_http_archive(name, sha256, urls, **kwargs):
"""Downloads and creates Bazel repos for dependencies.
This is a swappable replacement for both http_archive() and
new_http_archive() that offers some additional features. It also helps
ensure best practices are followed.
File arguments are relative to the TensorFlow repository by default. Dependent
repositories that use this rule should refer to files either with absolute
labels (e.g. '@foo//:bar') or from a label created in their repository (e.g.
'str(Label("//:bar"))').
Args:
name: name of the repository
sha256: sha256 sum as a string
urls: list of mirror URLs
**kwargs: additional arguments to _tf_http_archive
"""
if len(urls) < 2:
fail("tf_http_archive(urls) must have redundant URLs.")
if not any([mirror in urls[0] for mirror in (
"mirror.tensorflow.org",
"mirror.bazel.build",
"storage.googleapis.com",
)]):
fail("The first entry of tf_http_archive(urls) must be a mirror " +
"URL, preferrably mirror.tensorflow.org. Even if you don't have " +
"permission to mirror the file, please put the correctly " +
"formatted mirror URL there anyway, because someone will come " +
"along shortly thereafter and mirror the file.")
if native.existing_rule(name):
print("\n\033[1;33mWarning:\033[0m skipping import of repository '" +
name + "' because it already exists.\n") # buildifier: disable=print
return
_tf_http_archive(
name = name,
sha256 = sha256,
urls = urls,
**kwargs
)
def _tf_vendored_impl(repository_ctx):
parent_path = repository_ctx.path(repository_ctx.attr.parent).dirname
# get_child doesn't allow slashes. Yes this is silly. bazel_skylib paths
# doesn't work with path objects.
relpath_parts = repository_ctx.attr.relpath.split("/")
vendored_path = parent_path
for part in relpath_parts:
vendored_path = vendored_path.get_child(part)
repository_ctx.symlink(vendored_path, ".")
tf_vendored = repository_rule(
implementation = _tf_vendored_impl,
attrs = {
"parent": attr.label(default = "//:WORKSPACE"),
"relpath": attr.string(),
},
)