From 172a831219aa7d3524c0c8b5779dc29597a05810 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 14 Feb 2023 21:24:27 +0000 Subject: [PATCH] Switch JAX to use the OpenXLA repository. --- .bazelrc | 18 +- WORKSPACE | 42 +- build/BUILD.bazel | 4 +- build/build_wheel.py | 18 +- docs/developer.md | 16 +- docs/jep/9419-jax-versioning.md | 4 +- examples/jax_cpp/BUILD | 29 +- examples/jax_cpp/main.cc | 20 +- jaxlib/BUILD | 8 +- jaxlib/cpu/BUILD | 6 +- jaxlib/cpu/cpu_kernels.cc | 2 +- jaxlib/cpu/ducc_fft_kernels.cc | 2 +- jaxlib/cpu/ducc_fft_kernels.h | 2 +- jaxlib/cpu/lapack_kernels.h | 2 +- jaxlib/cuda/BUILD | 50 +- jaxlib/gpu/blas_kernels.cc | 2 +- jaxlib/gpu/blas_kernels.h | 2 +- jaxlib/gpu/gpu_kernels.cc | 2 +- jaxlib/gpu/lu_pivot_kernels.cc | 2 +- jaxlib/gpu/lu_pivot_kernels.h | 2 +- jaxlib/gpu/prng_kernels.cc | 2 +- jaxlib/gpu/prng_kernels.h | 2 +- jaxlib/gpu/rnn_kernels.cc | 2 +- jaxlib/gpu/rnn_kernels.h | 2 +- jaxlib/gpu/solver_kernels.cc | 2 +- jaxlib/gpu/solver_kernels.h | 2 +- jaxlib/gpu/sparse_kernels.cc | 2 +- jaxlib/gpu/sparse_kernels.h | 2 +- jaxlib/jax.bzl | 11 +- jaxlib/mlir/BUILD.bazel | 2 +- jaxlib/mlir/_mlir_libs/BUILD.bazel | 31 +- jaxlib/rocm/BUILD.bazel | 14 +- third_party/BUILD | 15 + third_party/flatbuffers/BUILD.bazel | 15 + third_party/flatbuffers/BUILD.system | 43 ++ third_party/flatbuffers/build_defs.bzl | 661 ++++++++++++++++++++++ third_party/flatbuffers/flatbuffers.BUILD | 192 +++++++ third_party/flatbuffers/workspace.bzl | 30 + third_party/repo.bzl | 160 ++++++ 39 files changed, 1281 insertions(+), 142 deletions(-) create mode 100644 third_party/BUILD create mode 100644 third_party/flatbuffers/BUILD.bazel create mode 100644 third_party/flatbuffers/BUILD.system create mode 100644 third_party/flatbuffers/build_defs.bzl create mode 100644 third_party/flatbuffers/flatbuffers.BUILD create mode 100644 third_party/flatbuffers/workspace.bzl create mode 100644 third_party/repo.bzl diff --git a/.bazelrc b/.bazelrc index e0ac0e3ea..a59dd9ed2 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,6 +1,10 @@ ############################################################################ # All default build options below. +# Required by OpenXLA +# https://github.com/openxla/xla/issues/1323 +build --nocheck_visibility + # Sets the default Apple platform to macOS. build --apple_platform_type=macos build --macos_minimum_os=10.14 @@ -35,9 +39,9 @@ build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. # Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled, # these values are overridden. -build --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false -build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false -build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false +build --@xla//xla/python:enable_gpu=false +build --@xla//xla/python:enable_tpu=false +build --@xla//xla/python:enable_plugin_device=false ########################################################################### @@ -65,12 +69,12 @@ build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_52,sm_60,sm_70,compute_80" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda -build:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true +build:cuda --@xla//xla/python:enable_gpu=true build:cuda --define=xla_python_enable_gpu=true build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true -build:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true +build:rocm --@xla//xla/python:enable_gpu=true build:rocm --define=xla_python_enable_gpu=true build:rocm --repo_env TF_NEED_ROCM=1 build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" @@ -113,10 +117,10 @@ build:macos --config=posix # Suppress all warning messages. build:short_logs --output_filter=DONT_MATCH_ANYTHING -build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true +build:tpu --@xla//xla/python:enable_tpu=true build:tpu --define=with_tpu_support=true -build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true +build:plugin_device --@xla//xla/python:enable_plugin_device=true ######################################################################### # RBE config options below. diff --git a/WORKSPACE b/WORKSPACE index 45c6af915..da8723791 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,16 +1,16 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -# To update TensorFlow to a new revision, +# To update XLA to a new revision, # a) update URL and strip_prefix to the new git commit hash # b) get the sha256 hash of the commit by running: -# curl -L https://github.com/tensorflow/tensorflow/archive/.tar.gz | sha256sum +# curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. http_archive( - name = "org_tensorflow", - sha256 = "08fd0ab0b672510229ad2fff276a3634f205fc539fa16a5bdeeaaccd881ece27", - strip_prefix = "tensorflow-2aaeef25361311b21b9e81e992edff94bcb6bae3", + name = "xla", + sha256 = "9f39af4d81d2c8bd52b47f4ef37dfd6642c6950112e4d9d3d95ae19982c46eba", + strip_prefix = "xla-0f31407ee498e6dba242d03f8d382ebcfcc61790", urls = [ - "https://github.com/tensorflow/tensorflow/archive/2aaeef25361311b21b9e81e992edff94bcb6bae3.tar.gz", + "https://github.com/openxla/xla/archive/0f31407ee498e6dba242d03f8d382ebcfcc61790.tar.gz", ], ) @@ -19,26 +19,32 @@ http_archive( # local checkout by either: # a) overriding the TF repository on the build.py command line by passing a flag # like: -# python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow +# python build/build.py --bazel_options=--override_repository=xla=/path/to/xla # or # b) by commenting out the http_archive above and uncommenting the following: # local_repository( -# name = "org_tensorflow", -# path = "/path/to/tensorflow", +# name = "xla", +# path = "/path/to/xla", # ) load("//third_party/ducc:workspace.bzl", ducc = "repo") ducc() -# Initialize TensorFlow's external dependencies. -load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") -tf_workspace3() +load("@xla//:workspace4.bzl", "xla_workspace4") +xla_workspace4() -load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") -tf_workspace2() +load("@xla//:workspace3.bzl", "xla_workspace3") +xla_workspace3() -load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") -tf_workspace1() +load("@xla//:workspace2.bzl", "xla_workspace2") +xla_workspace2() -load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") -tf_workspace0() +load("@xla//:workspace1.bzl", "xla_workspace1") +xla_workspace1() + +load("@xla//:workspace0.bzl", "xla_workspace0") +xla_workspace0() + + +load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") +flatbuffers() \ No newline at end of file diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 889f5a730..114290dff 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -44,11 +44,11 @@ py_binary( "//jaxlib:README.md", "//jaxlib:setup.py", "//jaxlib:setup.cfg", - "@org_tensorflow//tensorflow/compiler/xla/python:xla_client", + "@xla//xla/python:xla_client", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]) + select({ - ":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"], + ":remote_tpu_enabled": ["@xla//xla/python/tpu_driver/client:py_tpu_client"], "//conditions:default": [], }) + if_cuda([ "//jaxlib/cuda:cuda_gpu_support", diff --git a/build/build_wheel.py b/build/build_wheel.py index 1743ab555..a3f386c77 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -103,14 +103,14 @@ def patch_copy_xla_extension_stubs(dst_dir): os.makedirs(xla_extension_dir) for stub_name in _XLA_EXTENSION_STUBS: stub_path = r.Rlocation( - "org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name) + "xla/xla/python/xla_extension/" + stub_name) stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path): continue with open(stub_path) as f: src = f.read() src = src.replace( - "from tensorflow.compiler.xla.python import xla_extension", + "from xla.python import xla_extension", "from .. import xla_extension" ) with open(os.path.join(xla_extension_dir, stub_name), "w") as f: @@ -118,14 +118,14 @@ def patch_copy_xla_extension_stubs(dst_dir): def patch_copy_tpu_client_py(dst_dir): - with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f: + with open(r.Rlocation("xla/xla/python/tpu_driver/client/tpu_client.py")) as f: src = f.read() - src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla", + src = src.replace("from xla.python import xla_extension as _xla", "from . import xla_extension as _xla") - src = src.replace("from tensorflow.compiler.xla.python import xla_client", + src = src.replace("from xla.python import xla_client", "from . import xla_client") src = src.replace( - "from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client", + "from xla.python.tpu_driver.client import tpu_client_extension as _tpu_client", "from . import tpu_client_extension as _tpu_client") with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f: f.write(src) @@ -143,7 +143,7 @@ def verify_mac_libraries_dont_reference_chkstack(): return nm = subprocess.run( ["nm", "-g", - r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so") + r.Rlocation("xla/xla/python/xla_extension.so") ], capture_output=True, text=True, check=False) @@ -250,8 +250,8 @@ def prepare_wheel(sources_path): copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir) patch_copy_xla_extension_stubs(jaxlib_dir) - if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"): - copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so") + if exists("xla/xla/python/tpu_driver/client/tpu_client_extension.so"): + copy_to_jaxlib("xla/xla/python/tpu_driver/client/tpu_client_extension.so") patch_copy_tpu_client_py(jaxlib_dir) diff --git a/docs/developer.md b/docs/developer.md index 673c81fbd..322e38e81 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -66,11 +66,11 @@ specify the paths to CUDA and CUDNN, which you must have installed. Here may need to use `python3` instead. By default, the wheel is written to the `dist/` subdirectory of the current directory. -### Building jaxlib from source with a modified TensorFlow repository. +### Building jaxlib from source with a modified XLA repository. JAX depends on XLA, whose source code is in the -[Tensorflow GitHub repository](https://github.com/tensorflow/tensorflow). -By default JAX uses a pinned copy of the TensorFlow repository, but we often +[XLA GitHub repository](https://github.com/openxla/xla). +By default JAX uses a pinned copy of the XLA repository, but we often want to use a locally-modified copy of XLA when working on JAX. There are two ways to do this: @@ -78,12 +78,12 @@ ways to do this: line flag to `build.py` as follows: ``` - python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow + python build/build.py --bazel_options=--override_repository=xla=/path/to/xla ``` * modify the `WORKSPACE` file in the root of the JAX source tree to point to - a different TensorFlow tree. + a different XLA tree. -To contribute changes back to XLA, send PRs to the TensorFlow repository. +To contribute changes back to XLA, send PRs to the XLA repository. The version of XLA pinned by JAX is regularly updated, but is updated in particular before each `jaxlib` release. @@ -141,7 +141,7 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \ rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs ``` -AMD's fork of the XLA (TensorFlow) repository may include fixes +AMD's fork of the XLA repository may include fixes not present in the upstream repository. To use AMD's fork, you should clone their repository: ``` @@ -152,7 +152,7 @@ To build jaxlib with ROCM support, you can run the following build command, suitably adjusted for your paths and ROCM version. ``` python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \ - --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow-upstream + --bazel_options=--override_repository=xla=/path/to/xla-upstream ``` ## Installing `jax` diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index 598c39f51..f0c65f92a 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -121,7 +121,7 @@ no released `jax` version uses that API. `jaxlib` is split across two main repositories, namely the [`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib) and in the -[XLA source tree, which lives inside the TensorFlow repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla). +[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla). The JAX-specific pieces inside XLA are primarily in the [`xla/python` subdirectory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/python). @@ -164,7 +164,7 @@ compatibility, we have additional versioning that is independent of the `jaxlib` release version numbers. We maintain an additional version number (`_version`) in -[`xla_client.py` in the XLA repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py). +[`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py). The idea is that this version number, is defined in `xla/python` together with the C++ parts of JAX, is also accessible to JAX Python as `jax._src.lib.xla_extension_version`, and must diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index 681e5f1ec..b84916440 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -12,28 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@org_tensorflow//tensorflow:tensorflow.bzl", - "tf_cc_binary", -) - licenses(["notice"]) -tf_cc_binary( +cc_binary( name = "main", srcs = ["main.cc"], tags = ["manual"], deps = [ - "@org_tensorflow//tensorflow/compiler/xla:literal", - "@org_tensorflow//tensorflow/compiler/xla:literal_util", - "@org_tensorflow//tensorflow/compiler/xla:shape_util", - "@org_tensorflow//tensorflow/compiler/xla:status", - "@org_tensorflow//tensorflow/compiler/xla:statusor", - "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client", - "@org_tensorflow//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", - "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", - "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader", - "@org_tensorflow//tensorflow/core/platform:logging", - "@org_tensorflow//tensorflow/core/platform:platform_port", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index c9af5ef6b..221134ab7 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -40,18 +40,18 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/hlo_module_loader.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/tools/hlo_module_loader.h" +#include "tsl/platform/init_main.h" +#include "tsl/platform/logging.h" int main(int argc, char** argv) { - tensorflow::port::InitMain("", &argc, &argv); + tsl::port::InitMain("", &argc, &argv); // Load HloModule from file. std::string hlo_filename = "/tmp/fn_hlo.txt"; diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5f4e80f96..ffe92525e 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -68,7 +68,7 @@ symlink_files( symlink_files( name = "xla_client", - srcs = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"], + srcs = ["@xla//xla/python:xla_client"], dst = ".", flatten = True, ) @@ -76,8 +76,8 @@ symlink_files( symlink_files( name = "xla_extension", srcs = if_windows( - ["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.pyd"], - ["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.so"], + ["@xla//xla/python:xla_extension.pyd"], + ["@xla//xla/python:xla_extension.so"], ), dst = ".", flatten = True, @@ -140,7 +140,7 @@ pybind_extension( srcs = ["cpu_feature_guard.c"], module_name = "cpu_feature_guard", deps = [ - "@org_tensorflow//third_party/python_runtime:headers", + "@xla//third_party/python_runtime:headers", ], ) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 0ba15d760..85af47c9b 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -31,7 +31,7 @@ cc_library( srcs = ["lapack_kernels.cc"], hdrs = ["lapack_kernels.h"], deps = [ - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -74,7 +74,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":ducc_fft_flatbuffers_cc", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", "@ducc", "@flatbuffers//:runtime_cc", ], @@ -106,7 +106,7 @@ cc_library( ":ducc_fft_kernels", ":lapack_kernels", ":lapack_kernels_using_lapack", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry", + "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, ) diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index d7ead0a81..938e38839 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -18,7 +18,7 @@ limitations under the License. #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/cpu/ducc_fft_kernels.h" -#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" +#include "xla/service/custom_call_target_registry.h" namespace jax { namespace { diff --git a/jaxlib/cpu/ducc_fft_kernels.cc b/jaxlib/cpu/ducc_fft_kernels.cc index 789de83eb..029b72ffe 100644 --- a/jaxlib/cpu/ducc_fft_kernels.cc +++ b/jaxlib/cpu/ducc_fft_kernels.cc @@ -17,7 +17,7 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" #include "jaxlib/cpu/ducc_fft_generated.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" #include "ducc/src/ducc0/fft/fft.h" namespace jax { diff --git a/jaxlib/cpu/ducc_fft_kernels.h b/jaxlib/cpu/ducc_fft_kernels.h index b0ababd76..3a925587c 100644 --- a/jaxlib/cpu/ducc_fft_kernels.h +++ b/jaxlib/cpu/ducc_fft_kernels.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef JAXLIB_CPU_DUCC_FFT_KERNELS_H_ #define JAXLIB_CPU_DUCC_FFT_KERNELS_H_ -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 12a72cb1b..4641b772c 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" // Underlying function pointers (e.g., Trsm::Fn) are initialized either // by the pybind wrapper that links them to an existing SciPy lapack instance, diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 43383c492..3943b603d 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -51,8 +51,8 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib", + "@xla//xla/stream_executor/cuda:cusolver_lib", + "@xla//xla/stream_executor/cuda:cusparse_lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -72,9 +72,9 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/cuda:cublas_lib", + "@xla//xla/stream_executor/cuda:cudart_stub", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -102,7 +102,7 @@ pybind_extension( ":cublas_kernels", ":cuda_vendor", "//jaxlib:kernel_pybind11_helpers", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib", + "@xla//xla/stream_executor/cuda:cublas_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@pybind11", @@ -118,9 +118,9 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudnn_lib", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/cuda:cudart_stub", + "@xla//xla/stream_executor/cuda:cudnn_lib", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -157,8 +157,8 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/cuda:cusolver_lib", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -180,8 +180,8 @@ pybind_extension( ":cuda_vendor", ":cusolver_kernels", "//jaxlib:kernel_pybind11_helpers", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib", + "@xla//xla/stream_executor/cuda:cudart_stub", + "@xla//xla/stream_executor/cuda:cusolver_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", @@ -198,9 +198,9 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/cuda:cudart_stub", + "@xla//xla/stream_executor/cuda:cusparse_lib", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -222,8 +222,8 @@ pybind_extension( ":cuda_vendor", ":cusparse_kernels", "//jaxlib:kernel_pybind11_helpers", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib", + "@xla//xla/stream_executor/cuda:cudart_stub", + "@xla//xla/stream_executor/cuda:cusparse_lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -249,7 +249,7 @@ cc_library( ":cuda_lu_pivot_kernels_impl", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -264,7 +264,7 @@ cuda_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -284,7 +284,7 @@ pybind_extension( ":cuda_lu_pivot_kernels_impl", ":cuda_vendor", "//jaxlib:kernel_pybind11_helpers", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub", + "@xla//xla/stream_executor/cuda:cudart_stub", "@local_config_cuda//cuda:cuda_headers", "@pybind11", ], @@ -301,7 +301,7 @@ cc_library( ":cuda_prng_kernels_impl", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -316,7 +316,7 @@ cuda_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -334,7 +334,7 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", "//jaxlib:kernel_pybind11_helpers", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub", + "@xla//xla/stream_executor/cuda:cudart_stub", "@local_config_cuda//cuda:cuda_headers", "@pybind11", ], @@ -351,7 +351,7 @@ cc_library( ":cuda_vendor", ":cusolver_kernels", ":cusparse_kernels", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry", + "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, ) diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index 551c9112c..93f4f3ceb 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -28,7 +28,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h index 8fc4daba1..724565ea7 100644 --- a/jaxlib/gpu/blas_kernels.h +++ b/jaxlib/gpu/blas_kernels.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 20be3fea8..e8891146e 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -22,7 +22,7 @@ limitations under the License. #include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" -#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" +#include "xla/service/custom_call_target_registry.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/lu_pivot_kernels.cc b/jaxlib/gpu/lu_pivot_kernels.cc index 407dd90d6..c705f6f08 100644 --- a/jaxlib/gpu/lu_pivot_kernels.cc +++ b/jaxlib/gpu/lu_pivot_kernels.cc @@ -20,7 +20,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/lu_pivot_kernels.h b/jaxlib/gpu/lu_pivot_kernels.h index 6eae513ae..b6ece773a 100644 --- a/jaxlib/gpu/lu_pivot_kernels.h +++ b/jaxlib/gpu/lu_pivot_kernels.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc index 6101f41de..00ee4e648 100644 --- a/jaxlib/gpu/prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -19,7 +19,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/kernel_helpers.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/prng_kernels.h b/jaxlib/gpu/prng_kernels.h index 08428acd4..c72966c33 100644 --- a/jaxlib/gpu/prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 03fc434f5..803af11c0 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -23,7 +23,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index 5aaf82d65..5de695ab9 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index b15d529ac..d2d56ed0b 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -27,7 +27,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index a0d603c17..10fcca31c 100644 --- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/handle_pool.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 6a5b8ced7..b0024989a 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -28,7 +28,7 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 8880119b7..aa7e8215b 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -25,7 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/handle_pool.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "xla/service/custom_call_status.h" namespace jax { diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 028284c98..6139e9c04 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -14,12 +14,12 @@ """Bazel macros used by the JAX build.""" -load("@org_tensorflow//tensorflow/tsl/platform/default:build_config.bzl", _pyx_library = "pyx_library") -load("@org_tensorflow//tensorflow:tensorflow.bzl", _if_windows = "if_windows", _pybind_extension = "pybind_extension") + +load("@tsl//tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library") -load("@org_tensorflow//tensorflow/core/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") +load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") # Explicitly re-exports names to avoid "unused variable" warnings from .bzl # lint tools. @@ -28,7 +28,6 @@ rocm_library = _rocm_library pytype_library = native.py_library pytype_strict_library = native.py_library pytype_test = native.py_test -pyx_library = _pyx_library pybind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured @@ -63,8 +62,8 @@ jax2tf_deps = [] def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): lib_rule(name = name, **kwargs) -def py_extension(name, srcs, copts, deps): - pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name) +def py_extension(name, srcs, copts, deps, linkopts = []): + pybind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name) def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []): """Workaround DLL building issue. diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index 4219f5806..1849d7acc 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -120,7 +120,7 @@ symlink_inputs( name = "mhlo_dialect", rule = py_library, symlinked_inputs = {"srcs": {"dialects": [ - "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:MhloOpsPyFiles", + "@xla//xla/mlir_hlo:MhloOpsPyFiles", ]}}, deps = [ ":core", diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 49777f63c..9bfb97424 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -31,12 +31,25 @@ COPTS = [ "-frtti", ] +LINKOPTS = select({ + "@tsl//tsl:macos": [ + "-Wl,-rpath,@loader_path/", + "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", + ], + "@tsl//tsl:windows": [], + "//conditions:default": [ + "-Wl,-rpath,$$ORIGIN/", + ], + }) + + py_extension( name = "_mlir", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", ], copts = COPTS, + linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI", @@ -50,6 +63,7 @@ py_extension( "@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp", ], copts = COPTS, + linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPISparseTensorHeaders", @@ -64,6 +78,7 @@ py_extension( "@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp", ], copts = COPTS, + linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPISparseTensorHeaders", @@ -90,6 +105,7 @@ py_extension( name = "_site_initialize_0", srcs = ["_site_initialize_0.cc"], copts = COPTS, + linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIIRHeaders", @@ -106,15 +122,16 @@ py_extension( py_extension( name = "_mlirHlo", srcs = [ - "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:bindings/python/MlirHloModule.cc", + "@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc", ], copts = COPTS, + linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", - "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:CAPIHeaders", + "@xla//xla/mlir_hlo:CAPIHeaders", "@pybind11", ], ) @@ -129,6 +146,7 @@ py_extension( "@stablehlo//:stablehlo/integrations/python/ChloModule.cpp", ], copts = COPTS, + linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIIRHeaders", @@ -145,6 +163,7 @@ py_extension( "@stablehlo//:stablehlo/integrations/python/StablehloModule.cpp", ], copts = COPTS, + linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIIRHeaders", @@ -158,12 +177,12 @@ py_extension( cc_library( name = "jaxlib_mlir_capi_shared_library", srcs = select({ - "@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi.dll"], - "@org_tensorflow//tensorflow:macos": [":libjaxlib_mlir_capi.dylib"], + "@tsl//tsl:windows": [":jaxlib_mlir_capi.dll"], + "@tsl//tsl:macos": [":libjaxlib_mlir_capi.dylib"], "//conditions:default": [":libjaxlib_mlir_capi.so"], }), deps = select({ - "@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi_dll"], + "@tsl//tsl:windows": [":jaxlib_mlir_capi_dll"], "//conditions:default": [], }), ) @@ -174,7 +193,7 @@ cc_library( "@llvm-project//mlir:CAPISparseTensorObjects", "@llvm-project//mlir:CAPITransformsObjects", "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", - "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:CAPIObjects", + "@xla//xla/mlir_hlo:CAPIObjects", "@stablehlo//:chlo_capi_objects", "@stablehlo//:stablehlo_capi_objects", ], diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 50eb4f258..57391299f 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -75,7 +75,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:rocm_headers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", ], ) @@ -114,7 +114,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", ], ) @@ -154,7 +154,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", ], ) @@ -197,7 +197,7 @@ cc_library( ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", ], ) @@ -210,7 +210,7 @@ rocm_library( ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", ], ) @@ -244,7 +244,7 @@ cc_library( ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", ], ) @@ -257,7 +257,7 @@ rocm_library( ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@xla//xla/service:custom_call_status", ], ) diff --git a/third_party/BUILD b/third_party/BUILD new file mode 100644 index 000000000..718567c61 --- /dev/null +++ b/third_party/BUILD @@ -0,0 +1,15 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel new file mode 100644 index 000000000..1b9f84966 --- /dev/null +++ b/third_party/flatbuffers/BUILD.bazel @@ -0,0 +1,15 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty BUILD file is required to make Bazel treat this directory as a package. \ No newline at end of file diff --git a/third_party/flatbuffers/BUILD.system b/third_party/flatbuffers/BUILD.system new file mode 100644 index 000000000..8fe4d7a59 --- /dev/null +++ b/third_party/flatbuffers/BUILD.system @@ -0,0 +1,43 @@ +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "LICENSE.txt", + visibility = ["//visibility:public"], +) + +# Public flatc library to compile flatbuffer files at runtime. +cc_library( + name = "flatbuffers", + linkopts = ["-lflatbuffers"], + visibility = ["//visibility:public"], +) + +# Public flatc compiler library. +cc_library( + name = "flatc_library", + linkopts = ["-lflatbuffers"], + visibility = ["//visibility:public"], +) + +genrule( + name = "lnflatc", + outs = ["flatc.bin"], + cmd = "ln -s $$(which flatc) $@", +) + +# Public flatc compiler. +sh_binary( + name = "flatc", + srcs = ["flatc.bin"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "runtime_cc", + visibility = ["//visibility:public"], +) + +py_library( + name = "runtime_py", + visibility = ["//visibility:public"], +) diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl new file mode 100644 index 000000000..85aa5170e --- /dev/null +++ b/third_party/flatbuffers/build_defs.bzl @@ -0,0 +1,661 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BUILD rules for generating flatbuffer files.""" + +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +flatc_path = "@flatbuffers//:flatc" +zip_files = "//tensorflow/lite/tools:zip_files" + +DEFAULT_INCLUDE_PATHS = [ + "./", + "$(GENDIR)", + "$(BINDIR)", +] + +DEFAULT_FLATC_ARGS = [ + "--no-union-value-namespacing", + "--gen-object-api", +] + +def flatbuffer_library_public( + name, + srcs, + outs, + language_flag, + out_prefix = "", + includes = [], + include_paths = [], + compatible_with = [], + flatc_args = DEFAULT_FLATC_ARGS, + reflection_name = "", + reflection_visibility = None, + output_to_bindir = False): + """Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler. + + Outs: + filegroup(name): all generated source files. + Fileset([reflection_name]): (Optional) all generated reflection binaries. + + Args: + name: Rule name. + srcs: Source .fbs files. Sent in order to the compiler. + outs: Output files from flatc. + language_flag: Target language flag. One of [-c, -j, -js]. + out_prefix: Prepend this path to the front of all generated files except on + single source targets. Usually is a directory name. + includes: Optional, list of filegroups of schemas that the srcs depend on. + include_paths: Optional, list of paths the includes files can be found in. + compatible_with: Optional, passed to genrule for environments this rule + can be built for. + flatc_args: Optional, list of additional arguments to pass to flatc. + reflection_name: Optional, if set this will generate the flatbuffer + reflection binaries for the schemas. + reflection_visibility: The visibility of the generated reflection Fileset. + output_to_bindir: Passed to genrule for output to bin directory. + """ + include_paths_cmd = ["-I %s" % (s) for s in include_paths] + + # '$(@D)' when given a single source target will give the appropriate + # directory. Appending 'out_prefix' is only necessary when given a build + # target with multiple sources. + output_directory = ( + ("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)") + ) + genrule_cmd = " ".join([ + "for f in $(SRCS); do", + "$(location %s)" % (flatc_path), + " ".join(flatc_args), + " ".join(include_paths_cmd), + language_flag, + output_directory, + "$$f;", + "done", + ]) + native.genrule( + name = name, + srcs = srcs, + outs = outs, + output_to_bindir = output_to_bindir, + compatible_with = compatible_with, + tools = includes + [flatc_path], + cmd = genrule_cmd, + message = "Generating flatbuffer files for %s:" % (name), + ) + if reflection_name: + reflection_genrule_cmd = " ".join([ + "for f in $(SRCS); do", + "$(location %s)" % (flatc_path), + "-b --schema", + " ".join(flatc_args), + " ".join(include_paths_cmd), + language_flag, + output_directory, + "$$f;", + "done", + ]) + reflection_outs = [ + (out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1]) + for s in srcs + ] + native.genrule( + name = "%s_srcs" % reflection_name, + srcs = srcs, + outs = reflection_outs, + output_to_bindir = output_to_bindir, + compatible_with = compatible_with, + tools = includes + [flatc_path], + cmd = reflection_genrule_cmd, + message = "Generating flatbuffer reflection binary for %s:" % (name), + ) + # TODO(b/114456773): Make bazel rules proper and supported by flatbuffer + # Have to comment this since FilesetEntry is not supported in bazel + # skylark. + # native.Fileset( + # name = reflection_name, + # out = "%s_out" % reflection_name, + # entries = [ + # native.FilesetEntry(files = reflection_outs), + # ], + # visibility = reflection_visibility, + # compatible_with = compatible_with, + # ) + +def flatbuffer_cc_library( + name, + srcs, + srcs_filegroup_name = "", + out_prefix = "", + includes = [], + include_paths = [], + compatible_with = [], + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None, + srcs_filegroup_visibility = None, + gen_reflections = False): + '''A cc_library with the generated reader/writers for the given flatbuffer definitions. + + Outs: + filegroup([name]_srcs): all generated .h files. + filegroup(srcs_filegroup_name if specified, or [name]_includes if not): + Other flatbuffer_cc_library's can pass this in for their `includes` + parameter, if they depend on the schemas in this library. + Fileset([name]_reflection): (Optional) all generated reflection binaries. + cc_library([name]): library with sources and flatbuffers deps. + + Remarks: + ** Because the genrule used to call flatc does not have any trivial way of + computing the output list of files transitively generated by includes and + --gen-includes (the default) being defined for flatc, the --gen-includes + flag will not work as expected. The way around this is to add a dependency + to the flatbuffer_cc_library defined alongside the flatc included Fileset. + For example you might define: + + flatbuffer_cc_library( + name = "my_fbs", + srcs = [ "schemas/foo.fbs" ], + includes = [ "//third_party/bazz:bazz_fbs_includes" ], + ) + + In which foo.fbs includes a few files from the Fileset defined at + //third_party/bazz:bazz_fbs_includes. When compiling the library that + includes foo_generated.h, and therefore has my_fbs as a dependency, it + will fail to find any of the bazz *_generated.h files unless you also + add bazz's flatbuffer_cc_library to your own dependency list, e.g.: + + cc_library( + name = "my_lib", + deps = [ + ":my_fbs", + "//third_party/bazz:bazz_fbs" + ], + ) + + Happy dependent Flatbuffering! + + Args: + name: Rule name. + srcs: Source .fbs files. Sent in order to the compiler. + srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this + filegroup into the `includes` parameter of any other + flatbuffer_cc_library that depends on this one's schemas. + out_prefix: Prepend this path to the front of all generated files. Usually + is a directory name. + includes: Optional, list of filegroups of schemas that the srcs depend on. + ** SEE REMARKS BELOW ** + include_paths: Optional, list of paths the includes files can be found in. + compatible_with: Optional, passed to genrule for environments this rule + can be built for + flatc_args: Optional list of additional arguments to pass to flatc + (e.g. --gen-mutable). + visibility: The visibility of the generated cc_library. By default, use the + default visibility of the project. + srcs_filegroup_visibility: The visibility of the generated srcs filegroup. + By default, use the value of the visibility parameter above. + gen_reflections: Optional, if true this will generate the flatbuffer + reflection binaries for the schemas. + ''' + output_headers = [ + (out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1]) + for s in srcs + ] + reflection_name = "%s_reflection" % name if gen_reflections else "" + + flatbuffer_library_public( + name = "%s_srcs" % (name), + srcs = srcs, + outs = output_headers, + language_flag = "-c", + out_prefix = out_prefix, + includes = includes, + include_paths = include_paths, + compatible_with = compatible_with, + flatc_args = flatc_args, + reflection_name = reflection_name, + reflection_visibility = visibility, + ) + native.cc_library( + name = name, + hdrs = output_headers, + srcs = output_headers, + features = [ + "-parse_headers", + ], + deps = [ + "@flatbuffers//:runtime_cc", + ], + includes = ["."], + linkstatic = 1, + visibility = visibility, + compatible_with = compatible_with, + ) + + # A filegroup for the `srcs`. That is, all the schema files for this + # Flatbuffer set. + native.filegroup( + name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name), + srcs = srcs, + visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility, + compatible_with = compatible_with, + ) + +# Custom provider to track dependencies transitively. +FlatbufferInfo = provider( + fields = { + "transitive_srcs": "flatbuffer schema definitions.", + }, +) + +def _flatbuffer_schemas_aspect_impl(target, ctx): + _ignore = [target] + transitive_srcs = depset() + if hasattr(ctx.rule.attr, "deps"): + for dep in ctx.rule.attr.deps: + if FlatbufferInfo in dep: + transitive_srcs = depset(dep[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) + if hasattr(ctx.rule.attr, "srcs"): + for src in ctx.rule.attr.srcs: + if FlatbufferInfo in src: + transitive_srcs = depset(src[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) + for f in src.files: + if f.extension == "fbs": + transitive_srcs = depset([f], transitive = [transitive_srcs]) + return [FlatbufferInfo(transitive_srcs = transitive_srcs)] + +# An aspect that runs over all dependencies and transitively collects +# flatbuffer schema files. +_flatbuffer_schemas_aspect = aspect( + attr_aspects = [ + "deps", + "srcs", + ], + implementation = _flatbuffer_schemas_aspect_impl, +) + +# Rule to invoke the flatbuffer compiler. +def _gen_flatbuffer_srcs_impl(ctx): + outputs = ctx.attr.outputs + include_paths = ctx.attr.include_paths + if ctx.attr.no_includes: + no_includes_statement = ["--no-includes"] + else: + no_includes_statement = [] + + if ctx.attr.language_flag == "--python": + onefile_statement = ["--gen-onefile"] + else: + onefile_statement = [] + + # Need to generate all files in a directory. + if not outputs: + outputs = [ctx.actions.declare_directory("{}_all".format(ctx.attr.name))] + output_directory = outputs[0].path + else: + outputs = [ctx.actions.declare_file(output) for output in outputs] + output_directory = outputs[0].dirname + + deps = depset(ctx.files.srcs + ctx.files.deps, transitive = [ + dep[FlatbufferInfo].transitive_srcs + for dep in ctx.attr.deps + if FlatbufferInfo in dep + ]) + + include_paths_cmd_line = [] + for s in include_paths: + include_paths_cmd_line.extend(["-I", s]) + + for src in ctx.files.srcs: + ctx.actions.run( + inputs = deps, + outputs = outputs, + executable = ctx.executable._flatc, + arguments = [ + ctx.attr.language_flag, + "-o", + output_directory, + # Allow for absolute imports and referencing of generated files. + "-I", + "./", + "-I", + ctx.genfiles_dir.path, + "-I", + ctx.bin_dir.path, + ] + no_includes_statement + + onefile_statement + + include_paths_cmd_line + [ + "--no-union-value-namespacing", + "--gen-object-api", + src.path, + ], + progress_message = "Generating flatbuffer files for {}:".format(src), + use_default_shell_env = True, + ) + return [ + DefaultInfo(files = depset(outputs)), + ] + +_gen_flatbuffer_srcs = rule( + _gen_flatbuffer_srcs_impl, + attrs = { + "srcs": attr.label_list( + allow_files = [".fbs"], + mandatory = True, + ), + "outputs": attr.string_list( + default = [], + mandatory = False, + ), + "deps": attr.label_list( + default = [], + mandatory = False, + aspects = [_flatbuffer_schemas_aspect], + ), + "include_paths": attr.string_list( + default = [], + mandatory = False, + ), + "language_flag": attr.string( + mandatory = True, + ), + "no_includes": attr.bool( + default = False, + mandatory = False, + ), + "_flatc": attr.label( + default = Label("@flatbuffers//:flatc"), + executable = True, + cfg = "exec", + ), + }, + output_to_genfiles = True, +) + +def flatbuffer_py_strip_prefix_srcs(name, srcs = [], strip_prefix = ""): + """Strips path prefix. + + Args: + name: Rule name. (required) + srcs: Source .py files. (required) + strip_prefix: Path that needs to be stripped from the srcs filepaths. (required) + """ + for src in srcs: + native.genrule( + name = name + "_" + src.replace(".", "_").replace("/", "_"), + srcs = [src], + outs = [src.replace(strip_prefix, "")], + cmd = "cp $< $@", + ) + +def _concat_flatbuffer_py_srcs_impl(ctx): + # Merge all generated python files. The files are concatenated and import + # statements are removed. Finally we import the flatbuffer runtime library. + # IMPORTANT: Our Windows shell does not support "find ... -exec" properly. + # If you're changing the commandline below, please build wheels and run smoke + # tests on all the three operating systems. + command = "echo 'import flatbuffers\n' > %s; " + command += "for f in $(find %s -name '*.py' | sort); do cat $f | sed '/import flatbuffers/d' >> %s; done " + ctx.actions.run_shell( + inputs = ctx.attr.deps[0].files, + outputs = [ctx.outputs.out], + command = command % ( + ctx.outputs.out.path, + ctx.attr.deps[0].files.to_list()[0].path, + ctx.outputs.out.path, + ), + use_default_shell_env = True, + ) + +_concat_flatbuffer_py_srcs = rule( + _concat_flatbuffer_py_srcs_impl, + attrs = { + "deps": attr.label_list(mandatory = True), + }, + output_to_genfiles = True, + outputs = {"out": "%{name}.py"}, +) + +def flatbuffer_py_library( + name, + srcs, + deps = [], + include_paths = []): + """A py_library with the generated reader/writers for the given schema. + + This rule assumes that the schema files define non-conflicting names, so that + they can be merged in a single file. This is e.g. the case if only a single + namespace is used. + The rule call the flatbuffer compiler for all schema files and merges the + generated python files into a single file that is wrapped in a py_library. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files. (required) + deps: List of dependencies. + include_paths: Optional, list of paths the includes files can be found in. + """ + all_srcs = "{}_srcs".format(name) + _gen_flatbuffer_srcs( + name = all_srcs, + srcs = srcs, + language_flag = "--python", + deps = deps, + include_paths = include_paths, + ) + + # TODO(b/235550563): Remove the concatnation rule with 2.0.6 update. + all_srcs_no_include = "{}_srcs_no_include".format(name) + _gen_flatbuffer_srcs( + name = all_srcs_no_include, + srcs = srcs, + language_flag = "--python", + deps = deps, + no_includes = True, + include_paths = include_paths, + ) + concat_py_srcs = "{}_generated".format(name) + _concat_flatbuffer_py_srcs( + name = concat_py_srcs, + deps = [ + ":{}".format(all_srcs_no_include), + ], + ) + native.py_library( + name = name, + srcs = [ + ":{}".format(concat_py_srcs), + ], + srcs_version = "PY3", + deps = deps + [ + "@flatbuffers//:runtime_py", + ], + ) + +def flatbuffer_java_library( + name, + srcs, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None): + """A java library with the generated reader/writers for the given flatbuffer definitions. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + custom_package: Package name of generated Java files. If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. (optional) + visibility: Visibility setting for the java_library rule. (optional) + """ + out_srcjar = "java_%s_all.srcjar" % name + flatbuffer_java_srcjar( + name = "%s_srcjar" % name, + srcs = srcs, + out = out_srcjar, + custom_package = custom_package, + flatc_args = flatc_args, + include_paths = include_paths, + package_prefix = package_prefix, + ) + + native.filegroup( + name = "%s.srcjar" % name, + srcs = [out_srcjar], + ) + + native.java_library( + name = name, + srcs = [out_srcjar], + javacopts = ["-source 7 -target 7"], + deps = [ + "@flatbuffers//:runtime_java", + ], + visibility = visibility, + ) + +def flatbuffer_java_srcjar( + name, + srcs, + out, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS): + """Generate flatbuffer Java source files. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + out: Output file name. (required) + custom_package: Package name of generated Java files. If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. (optional) + """ + command_fmt = """set -e + tmpdir=$(@D) + schemas=$$tmpdir/schemas + java_root=$$tmpdir/java + rm -rf $$schemas + rm -rf $$java_root + mkdir -p $$schemas + mkdir -p $$java_root + + for src in $(SRCS); do + dest=$$schemas/$$src + rm -rf $$(dirname $$dest) + mkdir -p $$(dirname $$dest) + if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then + cp -f $$src $$dest + else + if [ -z "{package_prefix}" ]; then + sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest + else + sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest + fi + fi + done + + flatc_arg_I="-I $$tmpdir/schemas" + for include_path in {include_paths}; do + flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path" + done + + flatc_additional_args= + for arg in {flatc_args}; do + flatc_additional_args="$$flatc_additional_args $$arg" + done + + for src in $(SRCS); do + $(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src + done + + $(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root + """ + genrule_cmd = command_fmt.format( + package_name = native.package_name(), + custom_package = custom_package, + package_prefix = package_prefix, + flatc_path = flatc_path, + zip_files = zip_files, + include_paths = " ".join(include_paths), + flatc_args = " ".join(flatc_args), + ) + + native.genrule( + name = name, + srcs = srcs, + outs = [out], + tools = [flatc_path, zip_files], + cmd = genrule_cmd, + ) + +def flatbuffer_android_library( + name, + srcs, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None): + """An android_library with the generated reader/writers for the given flatbuffer definitions. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + custom_package: Package name of generated Java files. If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. (optional) + visibility: Visibility setting for the android_library rule. (optional) + """ + out_srcjar = "android_%s_all.srcjar" % name + flatbuffer_java_srcjar( + name = "%s_srcjar" % name, + srcs = srcs, + out = out_srcjar, + custom_package = custom_package, + flatc_args = flatc_args, + include_paths = include_paths, + package_prefix = package_prefix, + ) + + native.filegroup( + name = "%s.srcjar" % name, + srcs = [out_srcjar], + ) + + # To support org.checkerframework.dataflow.qual.Pure. + checkerframework_annotations = [ + "@org_checkerframework_qual", + ] if "--java-checkerframework" in flatc_args else [] + + android_library( + name = name, + srcs = [out_srcjar], + javacopts = ["-source 7 -target 7"], + visibility = visibility, + deps = [ + "@flatbuffers//:runtime_android", + ] + checkerframework_annotations, + ) diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD new file mode 100644 index 000000000..f32f1a5f3 --- /dev/null +++ b/third_party/flatbuffers/flatbuffers.BUILD @@ -0,0 +1,192 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load(":build_defs.bzl", "flatbuffer_py_strip_prefix_srcs") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE.txt"]) + +licenses(["notice"]) + +config_setting( + name = "platform_freebsd", + values = {"cpu": "freebsd"}, +) + +config_setting( + name = "platform_openbsd", + values = {"cpu": "openbsd"}, +) + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, +) + +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") + +# Public flatc library to compile flatbuffer files at runtime. +cc_library( + name = "flatbuffers", + hdrs = ["//:public_headers"], + linkstatic = 1, + strip_include_prefix = "/include", + visibility = ["//visibility:public"], + deps = ["//src:flatbuffers"], +) + +# Public C++ headers for the Flatbuffers library. +filegroup( + name = "public_headers", + srcs = [ + "include/flatbuffers/allocator.h", + "include/flatbuffers/array.h", + "include/flatbuffers/base.h", + "include/flatbuffers/bfbs_generator.h", + "include/flatbuffers/buffer.h", + "include/flatbuffers/buffer_ref.h", + "include/flatbuffers/code_generators.h", + "include/flatbuffers/default_allocator.h", + "include/flatbuffers/detached_buffer.h", + "include/flatbuffers/flatbuffer_builder.h", + "include/flatbuffers/flatbuffers.h", + "include/flatbuffers/flexbuffers.h", + "include/flatbuffers/hash.h", + "include/flatbuffers/idl.h", + "include/flatbuffers/minireflect.h", + "include/flatbuffers/reflection.h", + "include/flatbuffers/reflection_generated.h", + "include/flatbuffers/registry.h", + "include/flatbuffers/stl_emulation.h", + "include/flatbuffers/string.h", + "include/flatbuffers/struct.h", + "include/flatbuffers/table.h", + "include/flatbuffers/util.h", + "include/flatbuffers/vector.h", + "include/flatbuffers/vector_downward.h", + "include/flatbuffers/verifier.h", + ], + visibility = ["//:__subpackages__"], +) + +# Public flatc compiler library. +cc_library( + name = "flatc_library", + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "@flatbuffers//src:flatc_library", + ], +) + +# Public flatc compiler. +cc_binary( + name = "flatc", + linkopts = select({ + ":platform_freebsd": [ + "-lm", + ], + ":windows": [], + "//conditions:default": [ + "-lm", + "-ldl", + ], + }), + visibility = ["//visibility:public"], + deps = [ + "@flatbuffers//src:flatc", + ], +) + +filegroup( + name = "flatc_headers", + srcs = [ + "include/flatbuffers/flatc.h", + ], + visibility = ["//:__subpackages__"], +) + +# Library used by flatbuffer_cc_library rules. +cc_library( + name = "runtime_cc", + hdrs = [ + "include/flatbuffers/allocator.h", + "include/flatbuffers/array.h", + "include/flatbuffers/base.h", + "include/flatbuffers/buffer.h", + "include/flatbuffers/buffer_ref.h", + "include/flatbuffers/default_allocator.h", + "include/flatbuffers/detached_buffer.h", + "include/flatbuffers/flatbuffer_builder.h", + "include/flatbuffers/flatbuffers.h", + "include/flatbuffers/flexbuffers.h", + "include/flatbuffers/stl_emulation.h", + "include/flatbuffers/string.h", + "include/flatbuffers/struct.h", + "include/flatbuffers/table.h", + "include/flatbuffers/util.h", + "include/flatbuffers/vector.h", + "include/flatbuffers/vector_downward.h", + "include/flatbuffers/verifier.h", + ], + linkstatic = 1, + strip_include_prefix = "/include", + visibility = ["//visibility:public"], +) + +flatbuffer_py_strip_prefix_srcs( + name = "flatbuffer_py_strip_prefix", + srcs = [ + "python/flatbuffers/__init__.py", + "python/flatbuffers/_version.py", + "python/flatbuffers/builder.py", + "python/flatbuffers/compat.py", + "python/flatbuffers/encode.py", + "python/flatbuffers/flexbuffers.py", + "python/flatbuffers/number_types.py", + "python/flatbuffers/packer.py", + "python/flatbuffers/table.py", + "python/flatbuffers/util.py", + ], + strip_prefix = "python/flatbuffers/", +) + +filegroup( + name = "runtime_py_srcs", + srcs = [ + "__init__.py", + "_version.py", + "builder.py", + "compat.py", + "encode.py", + "flexbuffers.py", + "number_types.py", + "packer.py", + "table.py", + "util.py", + ], +) + +py_library( + name = "runtime_py", + srcs = [":runtime_py_srcs"], + visibility = ["//visibility:public"], +) + +filegroup( + name = "runtime_java_srcs", + srcs = glob(["java/com/google/flatbuffers/**/*.java"]), +) + +java_library( + name = "runtime_java", + srcs = [":runtime_java_srcs"], + visibility = ["//visibility:public"], +) + +android_library( + name = "runtime_android", + srcs = [":runtime_java_srcs"], + visibility = ["//visibility:public"], +) diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl new file mode 100644 index 000000000..9f28c8323 --- /dev/null +++ b/third_party/flatbuffers/workspace.bzl @@ -0,0 +1,30 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loads the Flatbuffers library.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + tf_http_archive( + name = "flatbuffers", + strip_prefix = "flatbuffers-2.0.6", + sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9", + urls = tf_mirror_urls("https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz"), + build_file = "//third_party/flatbuffers:flatbuffers.BUILD", + system_build_file = "//third_party/flatbuffers:BUILD.system", + link_files = { + "//third_party/flatbuffers:build_defs.bzl": "build_defs.bzl", + }, + ) diff --git a/third_party/repo.bzl b/third_party/repo.bzl new file mode 100644 index 000000000..28b159ad5 --- /dev/null +++ b/third_party/repo.bzl @@ -0,0 +1,160 @@ +# Copyright 2017 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for defining TensorFlow Bazel dependencies.""" + +def tf_mirror_urls(url): + """A helper for generating TF-mirror versions of URLs. + + Given a URL, it returns a list of the TF-mirror cache version of that URL + and the original URL, suitable for use in `urls` field of `tf_http_archive`. + """ + if not url.startswith("https://"): + return [url] + return [ + "https://storage.googleapis.com/mirror.tensorflow.org/%s" % url[8:], + url, + ] + +def _get_env_var(ctx, name): + if name in ctx.os.environ: + return ctx.os.environ[name] + else: + return None + +# Checks if we should use the system lib instead of the bundled one +def _use_system_lib(ctx, name): + syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS") + if not syslibenv: + return False + return name in [n.strip() for n in syslibenv.split(",")] + +def _get_link_dict(ctx, link_files, build_file): + link_dict = {ctx.path(v): ctx.path(Label(k)) for k, v in link_files.items()} + if build_file: + # Use BUILD.bazel because it takes precedence over BUILD. + link_dict[ctx.path("BUILD.bazel")] = ctx.path(Label(build_file)) + return link_dict + +def _tf_http_archive_impl(ctx): + # Construct all paths early on to prevent rule restart. We want the + # attributes to be strings instead of labels because they refer to files + # in the TensorFlow repository, not files in repos depending on TensorFlow. + # See also https://github.com/bazelbuild/bazel/issues/10515. + link_dict = _get_link_dict(ctx, ctx.attr.link_files, ctx.attr.build_file) + + # For some reason, we need to "resolve" labels once before the + # download_and_extract otherwise it'll invalidate and re-download the + # archive each time. + # https://github.com/bazelbuild/bazel/issues/10515 + patch_files = ctx.attr.patch_file + for patch_file in patch_files: + if patch_file: + ctx.path(Label(patch_file)) + + if _use_system_lib(ctx, ctx.attr.name): + link_dict.update(_get_link_dict( + ctx = ctx, + link_files = ctx.attr.system_link_files, + build_file = ctx.attr.system_build_file, + )) + else: + ctx.download_and_extract( + url = ctx.attr.urls, + sha256 = ctx.attr.sha256, + type = ctx.attr.type, + stripPrefix = ctx.attr.strip_prefix, + ) + if patch_files: + for patch_file in patch_files: + patch_file = ctx.path(Label(patch_file)) if patch_file else None + if patch_file: + ctx.patch(patch_file, strip = 1) + + for dst, src in link_dict.items(): + ctx.delete(dst) + ctx.symlink(src, dst) + +_tf_http_archive = repository_rule( + implementation = _tf_http_archive_impl, + attrs = { + "sha256": attr.string(mandatory = True), + "urls": attr.string_list(mandatory = True), + "strip_prefix": attr.string(), + "type": attr.string(), + "patch_file": attr.string_list(), + "build_file": attr.string(), + "system_build_file": attr.string(), + "link_files": attr.string_dict(), + "system_link_files": attr.string_dict(), + }, + environ = ["TF_SYSTEM_LIBS"], +) + +def tf_http_archive(name, sha256, urls, **kwargs): + """Downloads and creates Bazel repos for dependencies. + + This is a swappable replacement for both http_archive() and + new_http_archive() that offers some additional features. It also helps + ensure best practices are followed. + + File arguments are relative to the TensorFlow repository by default. Dependent + repositories that use this rule should refer to files either with absolute + labels (e.g. '@foo//:bar') or from a label created in their repository (e.g. + 'str(Label("//:bar"))'). + """ + if len(urls) < 2: + fail("tf_http_archive(urls) must have redundant URLs.") + + if not any([mirror in urls[0] for mirror in ( + "mirror.tensorflow.org", + "mirror.bazel.build", + "storage.googleapis.com", + )]): + fail("The first entry of tf_http_archive(urls) must be a mirror " + + "URL, preferrably mirror.tensorflow.org. Even if you don't have " + + "permission to mirror the file, please put the correctly " + + "formatted mirror URL there anyway, because someone will come " + + "along shortly thereafter and mirror the file.") + + if native.existing_rule(name): + print("\n\033[1;33mWarning:\033[0m skipping import of repository '" + + name + "' because it already exists.\n") + return + + _tf_http_archive( + name = name, + sha256 = sha256, + urls = urls, + **kwargs + ) + +def _tf_vendored_impl(repository_ctx): + parent_path = repository_ctx.path(repository_ctx.attr.parent).dirname + + # get_child doesn't allow slashes. Yes this is silly. bazel_skylib paths + # doesn't work with path objects. + relpath_parts = repository_ctx.attr.relpath.split("/") + vendored_path = parent_path + for part in relpath_parts: + vendored_path = vendored_path.get_child(part) + repository_ctx.symlink(vendored_path, ".") + +tf_vendored = repository_rule( + implementation = _tf_vendored_impl, + attrs = { + "parent": attr.label(default = "//:WORKSPACE"), + "relpath": attr.string(), + }, +)