rocm_jax/WORKSPACE
Jean-Baptiste Lespiau bdd65453b4
Add more features to the C++ jax.jit. (#4169)
This mainly follows https://github.com/google/jax/pull/4089 by adding:

- support for disable_jit from C++
- support for jax._cpp_jit on methods.
- support for applying @jax.jit to top-level functions, by delaying the retrieval of the device and backend.
- concurrency support.

I am not aware of any missing feature (but I suspect there are still some behavioral differences, due to the differences between xla_computation and _xla_callable).

See:

- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
 cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 failures and 92 passes, but many of the failures share a common cause).
2020-09-01 10:34:47 +03:00

76 lines
2.4 KiB
Python

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# rules_closure, pinned to commit 308b05b2 (2019-06-13). http_archive tries
# the URLs in order: the TensorFlow GCS mirror first, then GitHub.
http_archive(
    name = "io_bazel_rules_closure",
    urls = [
        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
        "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
    ],
    strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
    sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
)
# bazel-skylib 0.9.0 (https://github.com/bazelbuild/bazel-skylib/releases).
# http_archive tries the URLs in order: the TensorFlow GCS mirror first, then
# GitHub. Both are HTTPS, matching the rules_closure mirror URL above; the
# sha256 additionally pins the archive contents.
http_archive(
    name = "bazel_skylib",
    sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
    urls = [
        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz",
        "https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz",
    ],
)
# TensorFlow source revision that JAX builds against.
# To update TensorFlow to a new revision:
# a) set TENSORFLOW_COMMIT to the new git commit hash. It is used for both
#    the download URL and strip_prefix, so the two cannot get out of sync.
# b) get the sha256 hash of the commit by running:
#    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
#    and update the sha256 below with the result.
TENSORFLOW_COMMIT = "3c75664e72c40fc202fd986903cea39bd526f63d"

http_archive(
    name = "org_tensorflow",
    sha256 = "3fb86bfd01986be94fa94c74c29ddade8e8981cb56ed8d449cf63a21664a0d8c",
    # GitHub archive tarballs unpack into a "tensorflow-<commit>" directory.
    strip_prefix = "tensorflow-" + TENSORFLOW_COMMIT,
    urls = [
        "https://github.com/tensorflow/tensorflow/archive/" + TENSORFLOW_COMMIT + ".tar.gz",
    ],
)
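# For development, you can build against a local TensorFlow checkout instead:
# comment out the org_tensorflow http_archive above, uncomment the
# local_repository below, and point `path` at your checkout.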
# local_repository(
# name = "org_tensorflow",
# path = "tensorflow",
# )
# Run TensorFlow's workspace setup macros from its workspace.bzl.
# NOTE(review): the later loads from external repos (e.g. @com_github_grpc_grpc)
# appear to rely on repositories declared here; keep this before them.
load(
    "@org_tensorflow//tensorflow:workspace.bzl",
    "tf_bind",
    "tf_workspace",
)

tf_workspace(
    tf_repo_name = "org_tensorflow",
    path_prefix = "",
)

tf_bind()
# TensorFlow depends on @com_github_grpc_grpc, which requires its own
# grpc_deps() setup call.
load(
    "@com_github_grpc_grpc//bazel:grpc_deps.bzl",
    "grpc_deps",
)

grpc_deps()
# Dependency setup for @build_bazel_rules_apple.
load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies")

apple_rules_dependencies()
# Dependency setup for @build_bazel_apple_support.
load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies")

apple_support_dependencies()
# Define the @bazel_version repository via upb's helper rule.
load(
    "@upb//bazel:repository_defs.bzl",
    "bazel_version_repository",
)

bazel_version_repository(
    name = "bazel_version",
)