# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX is Autograd and XLA

load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
2022-07-01 15:06:54 -07:00
|
|
|
load(
|
|
|
|
"//jaxlib:jax.bzl",
|
|
|
|
"jax_extra_deps",
|
|
|
|
"jax_internal_packages",
|
|
|
|
"jax_test_util_visibility",
|
2023-03-10 08:40:28 -08:00
|
|
|
"jax_visibility",
|
2022-08-05 07:48:40 -07:00
|
|
|
"py_deps",
|
2022-07-01 15:06:54 -07:00
|
|
|
"py_library_providing_imports_info",
|
|
|
|
"pytype_library",
|
2023-03-09 05:26:58 -08:00
|
|
|
"pytype_strict_library",
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
# Targets in this package are visible only to the :internal package group
# unless they override `visibility` themselves.
package(default_visibility = [":internal"])

licenses(["notice"])

# Boolean build setting, True by default; matched by the
# :enable_jaxlib_build config_setting.
bool_flag(
    name = "build_jaxlib",
    build_setting_default = True,
)

# Matches configurations in which the :build_jaxlib flag is True.
config_setting(
    name = "enable_jaxlib_build",
    flag_values = {
        ":build_jaxlib": "True",
    },
)

# Files that other packages may reference directly by label.
exports_files([
    "LICENSE",
    "version.py",
])

# Packages that have access to JAX-internal implementation details.
package_group(
    name = "internal",
    packages = [
        # Every package in this repository...
        "//...",
        # ...plus any extra packages listed by //jaxlib:jax.bzl.
    ] + jax_internal_packages,
)

# JAX-private test utilities.
py_library(
    # This build target is required in order to use private test utilities in jax._src.test_util,
    # and its visibility is intentionally restricted to discourage its use outside JAX itself.
    # JAX does provide some public test utilities (see jax/test_util.py);
    # these are available in jax.test_util via the standard :jax target.
    name = "test_util",
    # Boolean attribute: only tests and other testonly targets may depend on this.
    testonly = True,
    srcs = [
        "_src/test_util.py",
    ],
    visibility = [
        ":internal",
    ] + jax_test_util_visibility,
    deps = [
        ":jax",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

# Everything under _src/internal_test_util/. :jax excludes these files from
# its own srcs; this target is testonly and visible only to :internal.
py_library(
    name = "internal_test_util",
    testonly = True,
    srcs = glob(["_src/internal_test_util/**/*.py"]),
    visibility = [":internal"],
    deps = [
        ":jax",
    ] + py_deps("numpy"),
)

# The standard, publicly visible JAX library. It bundles the listed _src
# modules, the globbed subpackages and a few experimental modules, and
# depends on the finer-grained library targets defined below.
# Test files (*_test.py) and _src/internal_test_util/ are excluded.
py_library_providing_imports_info(
    name = "jax",
    srcs = [
        "_src/__init__.py",
        "_src/ad_checkpoint.py",
        "_src/api.py",
        "_src/array.py",
        "_src/callback.py",
        "_src/checkify.py",
        "_src/custom_batching.py",
        "_src/custom_derivatives.py",
        "_src/custom_transpose.py",
        "_src/debugging.py",
        "_src/dispatch.py",
        "_src/dlpack.py",
        "_src/flatten_util.py",
        "_src/interpreters/__init__.py",
        "_src/interpreters/ad.py",
        "_src/interpreters/batching.py",
        "_src/interpreters/pxla.py",
        "_src/maps.py",
        "_src/pjit.py",
        "_src/prng.py",
        "_src/public_test_util.py",
        "_src/random.py",
        "_src/stages.py",
    ] + glob(
        [
            "*.py",
            "_src/debugger/**/*.py",
            "_src/image/**/*.py",
            "_src/lax/**/*.py",
            "_src/nn/**/*.py",
            "_src/numpy/**/*.py",
            "_src/ops/**/*.py",
            "_src/scipy/**/*.py",
            "_src/state/**/*.py",
            "_src/third_party/**/*.py",
            "image/**/*.py",
            "interpreters/**/*.py",
            "lax/**/*.py",
            "lib/**/*.py",
            "nn/**/*.py",
            "numpy/**/*.py",
            "ops/**/*.py",
            "scipy/**/*.py",
            "third_party/**/*.py",
        ],
        exclude = [
            "*_test.py",
            "**/*_test.py",
            "_src/internal_test_util/**",
        ],
    ) + [
        # until new parallelism APIs are moved out of experimental
        "experimental/maps.py",
        "experimental/pjit.py",
        "experimental/multihost_utils.py",
        "experimental/shard_map.py",
        # until checkify is moved out of experimental
        "experimental/checkify.py",
        # to avoid circular dependencies
        "experimental/compilation_cache/compilation_cache.py",
        "experimental/compilation_cache/gfile_cache.py",
        "experimental/compilation_cache/cache_interface.py",
    ],
    lib_rule = pytype_library,
    pytype_srcs = glob(
        ["_src/**/*.pyi"],
        exclude = [
            # Supplied by the :basearray target's own pytype_srcs instead.
            "_src/basearray.pyi",
        ],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":abstract_arrays",
        ":ad_util",
        ":api_util",
        ":basearray",
        ":cloud_tpu_init",
        ":config",
        ":core",
        ":custom_api_util",
        ":deprecations",
        ":effects",
        ":environment_info",
        ":lazy_loader",
        ":mesh",
        ":mlir",
        ":monitoring",
        ":op_shardings",
        ":partial_eval",
        ":partition_spec",
        ":path",
        ":pretty_printer",
        ":profiler",
        ":sharding",
        ":sharding_impls",
        ":sharding_specs",
        ":source_info_util",
        ":traceback_util",
        ":tree_util",
        ":typing",
        ":util",
        ":version",
        ":xla",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + jax_extra_deps,
)

# Standalone target for _src/abstract_arrays.py.
pytype_strict_library(
    name = "abstract_arrays",
    srcs = ["_src/abstract_arrays.py"],
    deps = [
        ":ad_util",
        ":core",
        ":traceback_util",
    ] + py_deps("numpy"),
)

# Standalone target for _src/ad_util.py (no third-party deps).
pytype_strict_library(
    name = "ad_util",
    srcs = ["_src/ad_util.py"],
    deps = [
        ":core",
        ":traceback_util",
        ":tree_util",
        ":typing",
        ":util",
    ],
)

# Standalone target for _src/api_util.py.
pytype_strict_library(
    name = "api_util",
    srcs = ["_src/api_util.py"],
    deps = [
        ":abstract_arrays",
        ":config",
        ":core",
        ":traceback_util",
        ":tree_util",
        ":util",
    ] + py_deps("numpy"),
)

# Standalone target for _src/basearray.py together with its .pyi stub.
pytype_strict_library(
    name = "basearray",
    srcs = ["_src/basearray.py"],
    # :jax's pytype_srcs glob explicitly excludes this stub, so it lives here.
    pytype_srcs = ["_src/basearray.pyi"],
    deps = [
        ":sharding",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/cloud_tpu_init.py; declares no deps.
pytype_strict_library(
    name = "cloud_tpu_init",
    srcs = ["_src/cloud_tpu_init.py"],
)

# Standalone target for _src/config.py; its only dep is //jax/_src/lib.
pytype_strict_library(
    name = "config",
    srcs = ["_src/config.py"],
    deps = [
        "//jax/_src/lib",
    ],
)

# core.py is built in a single target together with dtypes.py, errors.py
# and linear_util.py.
pytype_strict_library(
    name = "core",
    srcs = [
        "_src/core.py",
        "_src/dtypes.py",
        "_src/errors.py",
        "_src/linear_util.py",
    ],
    deps = [
        ":config",
        ":effects",
        ":pretty_printer",
        ":source_info_util",
        ":traceback_util",
        ":tree_util",
        ":typing",
        ":util",
        "//jax/_src/lib",
    ] + py_deps("ml_dtypes") + py_deps("numpy"),
)

# Standalone target for _src/custom_api_util.py; declares no deps.
pytype_strict_library(
    name = "custom_api_util",
    srcs = ["_src/custom_api_util.py"],
)

# Standalone target for _src/deprecations.py; declares no deps.
pytype_strict_library(
    name = "deprecations",
    srcs = ["_src/deprecations.py"],
)

# Standalone target for _src/effects.py; declares no deps.
pytype_strict_library(
    name = "effects",
    srcs = ["_src/effects.py"],
)

# Standalone target for _src/environment_info.py.
pytype_strict_library(
    name = "environment_info",
    srcs = ["_src/environment_info.py"],
    deps = [
        ":version",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/iree.py. Note: uses pytype_library, not
# pytype_strict_library like its neighbours.
pytype_library(
    name = "iree",
    srcs = ["_src/iree.py"],
    deps = [
        ":config",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/lax_reference.py. Visible to :internal plus
# whatever jax_visibility("lax_reference") from //jaxlib:jax.bzl adds.
pytype_library(
    name = "lax_reference",
    srcs = ["_src/lax_reference.py"],
    visibility = [":internal"] + jax_visibility("lax_reference"),
    deps = [
        ":core",
        ":util",
    ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum"),
)

# Standalone target for _src/lazy_loader.py; declares no deps.
pytype_strict_library(
    name = "lazy_loader",
    srcs = ["_src/lazy_loader.py"],
)

# Standalone target for _src/mesh.py.
pytype_strict_library(
    name = "mesh",
    srcs = ["_src/mesh.py"],
    deps = [
        ":config",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/interpreters/mlir.py.
pytype_strict_library(
    name = "mlir",
    srcs = ["_src/interpreters/mlir.py"],
    deps = [
        ":ad_util",
        ":config",
        ":core",
        ":effects",
        ":op_shardings",
        ":partial_eval",
        ":sharding_impls",
        ":source_info_util",
        ":util",
        ":xla",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/monitoring.py; declares no deps.
pytype_strict_library(
    name = "monitoring",
    srcs = ["_src/monitoring.py"],
)

# Standalone target for _src/op_shardings.py.
pytype_strict_library(
    name = "op_shardings",
    srcs = ["_src/op_shardings.py"],
    deps = [
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/interpreters/partial_eval.py.
pytype_strict_library(
    name = "partial_eval",
    srcs = ["_src/interpreters/partial_eval.py"],
    deps = [
        ":ad_util",
        ":api_util",
        ":config",
        ":core",
        ":effects",
        ":profiler",
        ":source_info_util",
        ":state_types",
        ":tree_util",
        ":util",
    ] + py_deps("numpy"),
)

# Standalone target for _src/partition_spec.py; declares no deps.
pytype_strict_library(
    name = "partition_spec",
    srcs = ["_src/partition_spec.py"],
)

# Standalone target for _src/path.py; its only dep is the external epath package.
pytype_strict_library(
    name = "path",
    srcs = ["_src/path.py"],
    deps = py_deps("epath"),
)

# Standalone target for _src/pretty_printer.py.
pytype_strict_library(
    name = "pretty_printer",
    srcs = ["_src/pretty_printer.py"],
    deps = [":config"] + py_deps("colorama"),
)

# Standalone target for _src/profiler.py.
pytype_strict_library(
    name = "profiler",
    srcs = ["_src/profiler.py"],
    deps = [
        ":traceback_util",
        ":xla_bridge",
        "//jax/_src/lib",
    ],
)

# Standalone target for _src/sharding.py.
pytype_strict_library(
    name = "sharding",
    srcs = ["_src/sharding.py"],
    deps = [
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ],
)

# Standalone target for _src/sharding_impls.py.
pytype_strict_library(
    name = "sharding_impls",
    srcs = ["_src/sharding_impls.py"],
    deps = [
        ":mesh",
        ":op_shardings",
        ":partition_spec",
        ":sharding",
        ":sharding_specs",
        ":tree_util",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/sharding_specs.py.
pytype_strict_library(
    name = "sharding_specs",
    srcs = ["_src/sharding_specs.py"],
    deps = [
        ":op_shardings",
        ":util",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Standalone target for _src/source_info_util.py. Visible to :internal plus
# whatever jax_visibility("source_info_util") from //jaxlib:jax.bzl adds.
pytype_strict_library(
    name = "source_info_util",
    srcs = ["_src/source_info_util.py"],
    visibility = [":internal"] + jax_visibility("source_info_util"),
    deps = [
        ":traceback_util",
        ":version",
        "//jax/_src/lib",
    ],
)

# Standalone target for the _src/state package's __init__.py and types.py.
# NOTE(review): :jax's srcs glob also matches "_src/state/**/*.py", so these
# two files are listed in both targets -- confirm the overlap is intended.
pytype_strict_library(
    name = "state_types",
    srcs = [
        "_src/state/__init__.py",
        "_src/state/types.py",
    ],
    deps = [
        ":core",
        ":effects",
        ":pretty_printer",
        ":util",
    ],
)

# Standalone target for _src/tree_util.py. Visible to :internal plus
# whatever jax_visibility("tree_util") from //jaxlib:jax.bzl adds.
pytype_strict_library(
    name = "tree_util",
    srcs = ["_src/tree_util.py"],
    visibility = [":internal"] + jax_visibility("tree_util"),
    deps = [
        ":traceback_util",
        ":util",
        "//jax/_src/lib",
    ],
)

# Standalone target for _src/traceback_util.py. Visible to :internal plus
# whatever jax_visibility("traceback_util") from //jaxlib:jax.bzl adds.
pytype_strict_library(
    name = "traceback_util",
    srcs = ["_src/traceback_util.py"],
    visibility = [":internal"] + jax_visibility("traceback_util"),
    deps = [
        ":config",
        ":util",
        "//jax/_src/lib",
    ],
)

# Standalone target for _src/typing.py.
pytype_strict_library(
    name = "typing",
    srcs = [
        "_src/typing.py",
    ],
    deps = [":basearray"] + py_deps("numpy"),
)

# Standalone target for _src/util.py.
pytype_strict_library(
    name = "util",
    srcs = ["_src/util.py"],
    deps = [
        ":config",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Library wrapper for version.py (which is also exported via exports_files).
pytype_strict_library(
    name = "version",
    srcs = ["version.py"],
)

# Standalone target for _src/interpreters/xla.py.
pytype_strict_library(
    name = "xla",
    srcs = ["_src/interpreters/xla.py"],
    deps = [
        ":abstract_arrays",
        ":config",
        ":core",
        ":sharding_impls",
        ":source_info_util",
        ":typing",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# TODO(phawkins): break up this SCC.
# The _src/clusters/ modules, _src/distributed.py and _src/xla_bridge.py are
# built as one target. Visible to :internal plus whatever
# jax_visibility("xla_bridge") from //jaxlib:jax.bzl adds.
pytype_strict_library(
    name = "xla_bridge",
    srcs = [
        "_src/clusters/__init__.py",
        "_src/clusters/cloud_tpu_cluster.py",
        "_src/clusters/cluster.py",
        "_src/clusters/ompi_cluster.py",
        "_src/clusters/slurm_cluster.py",
        "_src/distributed.py",
        "_src/xla_bridge.py",
    ],
    visibility = [":internal"] + jax_visibility("xla_bridge"),
    deps = [
        ":cloud_tpu_init",
        ":config",
        ":iree",
        ":traceback_util",
        ":util",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

# Public JAX libraries below this point.

# All top-level modules of experimental/ and example_libraries/.
py_library_providing_imports_info(
    name = "experimental",
    srcs = glob(
        [
            "experimental/*.py",
            "example_libraries/*.py",
        ],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":jax",
    ] + py_deps("absl/logging") + py_deps("numpy"),
)

# Public target for example_libraries/stax.py alone (the file is also
# matched by :experimental's glob).
pytype_library(
    name = "stax",
    srcs = [
        "example_libraries/stax.py",
    ],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

# Public target for experimental/sparse/. Its test_util.py is excluded here
# and lives in the testonly :sparse_test_util target instead.
pytype_library(
    name = "experimental_sparse",
    srcs = glob(
        [
            "experimental/sparse/*.py",
        ],
        exclude = ["experimental/sparse/test_util.py"],
    ),
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

# Test helpers for experimental/sparse; testonly and visible only to :internal.
pytype_library(
    name = "sparse_test_util",
    # Boolean attribute: only tests and other testonly targets may depend on this.
    testonly = True,
    srcs = [
        "experimental/sparse/test_util.py",
    ],
    visibility = [":internal"],
    deps = [
        ":experimental_sparse",
        ":jax",
        ":test_util",
    ] + py_deps("numpy"),
)

# Public target for example_libraries/optimizers.py alone (the file is also
# matched by :experimental's glob).
pytype_library(
    name = "optimizers",
    srcs = [
        "example_libraries/optimizers.py",
    ],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

# Public target for experimental/ode.py alone.
pytype_library(
    name = "ode",
    srcs = ["experimental/ode.py"],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

# TODO(apaszke): Remove this target
# experimental/maps.py is also part of :jax's srcs.
pytype_library(
    name = "maps",
    srcs = ["experimental/maps.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":jax",
        ":mesh",
    ],
)

# TODO(apaszke): Remove this target
# experimental/pjit.py is also part of :jax's srcs.
pytype_library(
    name = "pjit",
    srcs = ["experimental/pjit.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":experimental",
        ":jax",
    ],
)

# Public target for experimental/jet.py alone.
pytype_library(
    name = "jet",
    srcs = ["experimental/jet.py"],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

# Public target for experimental/host_callback.py alone.
pytype_library(
    name = "experimental_host_callback",
    srcs = ["experimental/host_callback.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":jax",
    ],
)

# Public target for the compilation cache modules. Both files are also
# listed in :jax's srcs ("to avoid circular dependencies").
pytype_library(
    name = "compilation_cache",
    srcs = [
        "experimental/compilation_cache/compilation_cache.py",
        "experimental/compilation_cache/gfile_cache.py",
    ],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

# Public target for experimental/mesh_utils.py alone.
pytype_library(
    name = "mesh_utils",
    srcs = ["experimental/mesh_utils.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":experimental",
        ":jax",
    ],
)

# Public target for experimental/rnn.py alone.
pytype_library(
    name = "rnn",
    srcs = ["experimental/rnn.py"],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)