2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2022-07-01 15:06:54 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
# JAX is Autograd and XLA
|
|
|
|
|
2022-07-06 06:13:20 -07:00
|
|
|
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
2024-03-27 10:26:38 -07:00
|
|
|
load("@rules_python//python:defs.bzl", "py_library")
|
2022-07-01 15:06:54 -07:00
|
|
|
load(
|
|
|
|
"//jaxlib:jax.bzl",
|
2023-12-05 08:04:41 -08:00
|
|
|
"if_building_jaxlib",
|
2025-01-06 10:56:57 -08:00
|
|
|
"jax_export_file_visibility",
|
2023-08-24 14:40:10 -07:00
|
|
|
"jax_extend_internal_users",
|
2022-07-01 15:06:54 -07:00
|
|
|
"jax_extra_deps",
|
2023-11-17 02:04:49 -08:00
|
|
|
"jax_internal_export_back_compat_test_util_visibility",
|
2022-07-01 15:06:54 -07:00
|
|
|
"jax_internal_packages",
|
2023-11-09 13:57:30 -08:00
|
|
|
"jax_internal_test_harnesses_visibility",
|
2022-07-01 15:06:54 -07:00
|
|
|
"jax_test_util_visibility",
|
2023-03-10 08:40:28 -08:00
|
|
|
"jax_visibility",
|
2024-05-17 04:29:43 -07:00
|
|
|
"mosaic_gpu_internal_users",
|
2023-07-20 18:28:18 -07:00
|
|
|
"mosaic_internal_users",
|
2023-08-01 16:42:26 -07:00
|
|
|
"pallas_gpu_internal_users",
|
|
|
|
"pallas_tpu_internal_users",
|
2022-08-05 07:48:40 -07:00
|
|
|
"py_deps",
|
2022-07-01 15:06:54 -07:00
|
|
|
"py_library_providing_imports_info",
|
|
|
|
"pytype_library",
|
2023-03-09 05:26:58 -08:00
|
|
|
"pytype_strict_library",
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2023-04-19 13:26:24 -07:00
|
|
|
package(
|
|
|
|
default_applicable_licenses = [],
|
|
|
|
default_visibility = [":internal"],
|
|
|
|
)
|
2022-07-01 15:06:54 -07:00
|
|
|
|
2023-02-01 21:25:46 +00:00
|
|
|
licenses(["notice"])
|
|
|
|
|
2023-12-05 08:04:41 -08:00
|
|
|
# If this flag is true, jaxlib should be built by bazel. If false, then we do not build jaxlib and
|
|
|
|
# assume it has been installed, e.g., by `pip`.
|
2022-07-03 15:04:37 -04:00
|
|
|
bool_flag(
|
|
|
|
name = "build_jaxlib",
|
|
|
|
build_setting_default = True,
|
|
|
|
)
|
|
|
|
|
|
|
|
config_setting(
|
|
|
|
name = "enable_jaxlib_build",
|
|
|
|
flag_values = {
|
|
|
|
":build_jaxlib": "True",
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
exports_files([
|
|
|
|
"LICENSE",
|
|
|
|
"version.py",
|
|
|
|
])
|
|
|
|
|
2025-01-06 10:56:57 -08:00
|
|
|
exports_files(
|
|
|
|
["_src/export/serialization.fbs"],
|
|
|
|
visibility = jax_export_file_visibility,
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Packages that have access to JAX-internal implementation details.
|
|
|
|
package_group(
|
|
|
|
name = "internal",
|
|
|
|
packages = [
|
|
|
|
"//...",
|
|
|
|
] + jax_internal_packages,
|
|
|
|
)
|
|
|
|
|
2023-08-24 14:40:10 -07:00
|
|
|
package_group(
|
|
|
|
name = "jax_extend_users",
|
2023-12-04 17:48:17 -08:00
|
|
|
includes = [":internal"],
|
2023-08-24 14:40:10 -07:00
|
|
|
packages = [
|
|
|
|
# Intentionally avoid jax dependencies on jax.extend.
|
|
|
|
# See https://jax.readthedocs.io/en/latest/jep/15856-jex.html
|
2024-08-26 09:10:26 -07:00
|
|
|
"//tests/...",
|
2023-08-24 14:40:10 -07:00
|
|
|
] + jax_extend_internal_users,
|
|
|
|
)
|
|
|
|
|
2023-07-20 18:28:18 -07:00
|
|
|
package_group(
|
|
|
|
name = "mosaic_users",
|
2024-08-26 09:10:26 -07:00
|
|
|
includes = [":internal"],
|
|
|
|
packages = mosaic_internal_users,
|
2023-07-20 18:28:18 -07:00
|
|
|
)
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
package_group(
|
|
|
|
name = "pallas_gpu_users",
|
2024-08-26 09:10:26 -07:00
|
|
|
includes = [":internal"],
|
|
|
|
packages = pallas_gpu_internal_users,
|
2023-08-01 16:42:26 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
package_group(
|
|
|
|
name = "pallas_tpu_users",
|
2024-08-26 09:10:26 -07:00
|
|
|
includes = [":internal"],
|
|
|
|
packages = pallas_tpu_internal_users,
|
2023-08-01 16:42:26 -07:00
|
|
|
)
|
|
|
|
|
2024-04-18 04:03:03 -07:00
|
|
|
package_group(
|
|
|
|
name = "mosaic_gpu_users",
|
2024-08-26 09:10:26 -07:00
|
|
|
includes = [":internal"],
|
|
|
|
packages = mosaic_gpu_internal_users,
|
2024-04-18 04:03:03 -07:00
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# JAX-private test utilities.
|
|
|
|
py_library(
|
|
|
|
# This build target is required in order to use private test utilities in jax._src.test_util,
|
|
|
|
# and its visibility is intentionally restricted to discourage its use outside JAX itself.
|
|
|
|
# JAX does provide some public test utilities (see jax/test_util.py);
|
|
|
|
# these are available in jax.test_util via the standard :jax target.
|
|
|
|
name = "test_util",
|
|
|
|
testonly = 1,
|
|
|
|
srcs = [
|
|
|
|
"_src/test_util.py",
|
2024-12-19 21:10:21 -05:00
|
|
|
"_src/test_warning_util.py",
|
2022-07-01 15:06:54 -07:00
|
|
|
],
|
|
|
|
visibility = [
|
|
|
|
":internal",
|
|
|
|
] + jax_test_util_visibility,
|
|
|
|
deps = [
|
|
|
|
":jax",
|
2022-08-05 07:48:40 -07:00
|
|
|
] + py_deps("absl/testing") + py_deps("numpy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2023-11-17 02:04:49 -08:00
|
|
|
# TODO(necula): break the internal_test_util into smaller build targets.
|
2023-02-16 15:29:12 -08:00
|
|
|
py_library(
|
|
|
|
name = "internal_test_util",
|
|
|
|
testonly = 1,
|
2023-12-18 09:49:11 -08:00
|
|
|
srcs = [
|
|
|
|
"_src/internal_test_util/deprecation_module.py",
|
|
|
|
"_src/internal_test_util/lax_test_util.py",
|
|
|
|
] + glob(
|
2023-11-17 02:04:49 -08:00
|
|
|
[
|
2023-12-18 09:49:11 -08:00
|
|
|
"_src/internal_test_util/lazy_loader_module/*.py",
|
|
|
|
],
|
2023-11-09 13:57:30 -08:00
|
|
|
),
|
2023-02-16 15:29:12 -08:00
|
|
|
visibility = [":internal"],
|
2023-03-07 08:49:05 -08:00
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
] + py_deps("numpy"),
|
2023-02-16 15:29:12 -08:00
|
|
|
)
|
|
|
|
|
2023-11-09 13:57:30 -08:00
|
|
|
py_library(
|
|
|
|
name = "internal_test_harnesses",
|
|
|
|
testonly = 1,
|
|
|
|
srcs = ["_src/internal_test_util/test_harnesses.py"],
|
|
|
|
visibility = [":internal"] + jax_internal_test_harnesses_visibility,
|
|
|
|
deps = [
|
2024-10-04 13:55:36 -07:00
|
|
|
":ad_util",
|
|
|
|
":config",
|
2023-11-09 13:57:30 -08:00
|
|
|
":jax",
|
2024-10-04 13:55:36 -07:00
|
|
|
":test_util",
|
|
|
|
"//jax/_src/lib",
|
2023-11-09 13:57:30 -08:00
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-11-17 02:04:49 -08:00
|
|
|
py_library(
|
|
|
|
name = "internal_export_back_compat_test_util",
|
|
|
|
testonly = 1,
|
|
|
|
srcs = ["_src/internal_test_util/export_back_compat_test_util.py"],
|
|
|
|
visibility = [
|
|
|
|
":internal",
|
|
|
|
] + jax_internal_export_back_compat_test_util_visibility,
|
|
|
|
deps = [
|
|
|
|
":jax",
|
2025-01-08 18:33:23 -08:00
|
|
|
":test_util",
|
2023-11-17 02:04:49 -08:00
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-12-18 09:49:11 -08:00
|
|
|
py_library(
|
|
|
|
name = "internal_export_back_compat_test_data",
|
|
|
|
testonly = 1,
|
2024-05-02 05:37:41 -07:00
|
|
|
srcs = glob([
|
|
|
|
"_src/internal_test_util/export_back_compat_test_data/*.py",
|
|
|
|
"_src/internal_test_util/export_back_compat_test_data/pallas/*.py",
|
|
|
|
]),
|
2023-12-18 09:49:11 -08:00
|
|
|
visibility = [
|
|
|
|
":internal",
|
|
|
|
],
|
|
|
|
deps = py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
py_library_providing_imports_info(
|
|
|
|
name = "jax",
|
2023-03-10 11:11:28 -08:00
|
|
|
srcs = [
|
2023-03-29 08:23:18 -07:00
|
|
|
"_src/__init__.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/ad_checkpoint.py",
|
|
|
|
"_src/api.py",
|
|
|
|
"_src/array.py",
|
2024-06-24 11:19:59 -07:00
|
|
|
"_src/blocked_sampler.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/callback.py",
|
|
|
|
"_src/checkify.py",
|
|
|
|
"_src/custom_batching.py",
|
2025-01-23 11:38:06 -08:00
|
|
|
"_src/custom_dce.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/custom_derivatives.py",
|
2024-07-08 04:30:11 -07:00
|
|
|
"_src/custom_partitioning.py",
|
2024-12-05 11:32:43 -08:00
|
|
|
"_src/custom_partitioning_sharding_rule.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/custom_transpose.py",
|
|
|
|
"_src/debugging.py",
|
|
|
|
"_src/dispatch.py",
|
|
|
|
"_src/dlpack.py",
|
2024-03-14 15:53:33 -07:00
|
|
|
"_src/earray.py",
|
2024-12-20 11:26:04 +00:00
|
|
|
"_src/ffi.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/flatten_util.py",
|
2023-03-30 06:12:10 -07:00
|
|
|
"_src/interpreters/__init__.py",
|
|
|
|
"_src/interpreters/ad.py",
|
|
|
|
"_src/interpreters/batching.py",
|
|
|
|
"_src/interpreters/pxla.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/pjit.py",
|
|
|
|
"_src/prng.py",
|
|
|
|
"_src/public_test_util.py",
|
|
|
|
"_src/random.py",
|
2023-12-19 16:30:48 -08:00
|
|
|
"_src/shard_alike.py",
|
2024-05-08 05:44:08 -07:00
|
|
|
"_src/sourcemap.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/stages.py",
|
2024-02-12 13:07:59 -08:00
|
|
|
"_src/tree.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
] + glob(
|
2022-07-01 15:06:54 -07:00
|
|
|
[
|
|
|
|
"*.py",
|
2024-07-08 06:15:20 -07:00
|
|
|
"_src/cudnn/**/*.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/debugger/**/*.py",
|
2023-08-25 17:36:15 -07:00
|
|
|
"_src/extend/**/*.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/image/**/*.py",
|
2024-06-04 22:02:36 -07:00
|
|
|
"_src/export/**/*.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/lax/**/*.py",
|
|
|
|
"_src/nn/**/*.py",
|
|
|
|
"_src/numpy/**/*.py",
|
|
|
|
"_src/ops/**/*.py",
|
|
|
|
"_src/scipy/**/*.py",
|
|
|
|
"_src/state/**/*.py",
|
|
|
|
"_src/third_party/**/*.py",
|
2023-12-11 12:03:48 -08:00
|
|
|
"experimental/key_reuse/**/*.py",
|
2024-11-27 13:29:27 -08:00
|
|
|
"experimental/roofline/**/*.py",
|
2022-07-01 15:06:54 -07:00
|
|
|
"image/**/*.py",
|
|
|
|
"interpreters/**/*.py",
|
|
|
|
"lax/**/*.py",
|
|
|
|
"lib/**/*.py",
|
|
|
|
"nn/**/*.py",
|
|
|
|
"numpy/**/*.py",
|
|
|
|
"ops/**/*.py",
|
|
|
|
"scipy/**/*.py",
|
|
|
|
"third_party/**/*.py",
|
|
|
|
],
|
|
|
|
exclude = [
|
|
|
|
"*_test.py",
|
|
|
|
"**/*_test.py",
|
2023-02-16 15:29:12 -08:00
|
|
|
"_src/internal_test_util/**",
|
2022-07-01 15:06:54 -07:00
|
|
|
],
|
|
|
|
) + [
|
2024-01-25 22:20:36 -08:00
|
|
|
"experimental/attrs.py",
|
2022-07-01 15:06:54 -07:00
|
|
|
"experimental/pjit.py",
|
|
|
|
"experimental/multihost_utils.py",
|
2022-11-04 15:29:10 -07:00
|
|
|
"experimental/shard_map.py",
|
2022-07-01 15:06:54 -07:00
|
|
|
# until checkify is moved out of experimental
|
|
|
|
"experimental/checkify.py",
|
|
|
|
"experimental/compilation_cache/compilation_cache.py",
|
|
|
|
],
|
|
|
|
lib_rule = pytype_library,
|
2023-03-13 11:12:50 -07:00
|
|
|
pytype_srcs = glob(
|
2023-08-22 11:50:09 -07:00
|
|
|
[
|
|
|
|
"numpy/*.pyi",
|
|
|
|
"_src/**/*.pyi",
|
|
|
|
],
|
2023-03-13 11:12:50 -07:00
|
|
|
exclude = [
|
|
|
|
"_src/basearray.pyi",
|
|
|
|
],
|
|
|
|
),
|
2022-07-01 15:06:54 -07:00
|
|
|
visibility = ["//visibility:public"],
|
2023-03-09 05:26:58 -08:00
|
|
|
deps = [
|
2023-03-30 06:12:10 -07:00
|
|
|
":abstract_arrays",
|
|
|
|
":ad_util",
|
|
|
|
":api_util",
|
2023-03-13 11:12:50 -07:00
|
|
|
":basearray",
|
2023-03-09 13:09:20 -08:00
|
|
|
":cloud_tpu_init",
|
2023-04-17 08:35:26 -07:00
|
|
|
":compilation_cache_internal",
|
2023-08-15 06:38:56 -07:00
|
|
|
":compiler",
|
2024-05-17 15:58:25 -07:00
|
|
|
":compute_on",
|
2023-03-09 05:26:58 -08:00
|
|
|
":config",
|
2023-03-29 15:06:30 -07:00
|
|
|
":core",
|
2023-03-29 08:23:18 -07:00
|
|
|
":custom_api_util",
|
2023-03-10 12:25:25 -08:00
|
|
|
":deprecations",
|
2023-05-17 07:58:19 -07:00
|
|
|
":dtypes",
|
2023-03-10 12:25:25 -08:00
|
|
|
":effects",
|
|
|
|
":environment_info",
|
2024-09-03 14:30:37 -07:00
|
|
|
":internal_mesh_utils",
|
2023-08-08 10:08:19 -07:00
|
|
|
":jaxpr_util",
|
2023-11-15 08:48:17 -08:00
|
|
|
":layout",
|
2023-03-09 13:09:20 -08:00
|
|
|
":lazy_loader",
|
2023-03-13 08:31:16 -07:00
|
|
|
":mesh",
|
2023-03-30 14:05:24 -07:00
|
|
|
":mlir",
|
2023-03-09 13:09:20 -08:00
|
|
|
":monitoring",
|
2023-04-06 08:31:47 -07:00
|
|
|
":op_shardings",
|
2023-03-30 06:12:10 -07:00
|
|
|
":partial_eval",
|
2023-04-06 11:42:45 -07:00
|
|
|
":partition_spec",
|
2023-03-09 13:09:20 -08:00
|
|
|
":path",
|
2023-06-09 13:34:09 -07:00
|
|
|
":pickle_util",
|
2023-03-09 08:50:49 -08:00
|
|
|
":pretty_printer",
|
2023-03-10 11:38:08 -08:00
|
|
|
":profiler",
|
2023-03-13 08:49:39 -07:00
|
|
|
":sharding",
|
2023-04-10 10:15:08 -07:00
|
|
|
":sharding_impls",
|
2023-04-06 09:48:14 -07:00
|
|
|
":sharding_specs",
|
2023-03-10 08:40:28 -08:00
|
|
|
":source_info_util",
|
2023-03-09 10:33:11 -08:00
|
|
|
":traceback_util",
|
2023-03-10 14:51:08 -08:00
|
|
|
":tree_util",
|
2023-03-23 10:29:11 -07:00
|
|
|
":typing",
|
2023-03-09 05:26:58 -08:00
|
|
|
":util",
|
|
|
|
":version",
|
2023-03-30 14:05:24 -07:00
|
|
|
":xla",
|
2023-03-10 11:11:28 -08:00
|
|
|
":xla_bridge",
|
2024-09-06 16:44:19 -07:00
|
|
|
":xla_metadata",
|
2023-03-09 05:26:58 -08:00
|
|
|
"//jax/_src/lib",
|
2024-06-09 08:58:54 -07:00
|
|
|
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
|
2023-03-30 06:12:10 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
pytype_strict_library(
|
|
|
|
name = "abstract_arrays",
|
|
|
|
srcs = ["_src/abstract_arrays.py"],
|
|
|
|
deps = [
|
|
|
|
":ad_util",
|
|
|
|
":core",
|
2023-05-17 07:58:19 -07:00
|
|
|
":dtypes",
|
2023-03-30 06:12:10 -07:00
|
|
|
":traceback_util",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_strict_library(
|
|
|
|
name = "ad_util",
|
|
|
|
srcs = ["_src/ad_util.py"],
|
|
|
|
deps = [
|
|
|
|
":core",
|
|
|
|
":traceback_util",
|
|
|
|
":tree_util",
|
|
|
|
":typing",
|
|
|
|
":util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_strict_library(
|
|
|
|
name = "api_util",
|
|
|
|
srcs = ["_src/api_util.py"],
|
|
|
|
deps = [
|
|
|
|
":abstract_arrays",
|
|
|
|
":config",
|
|
|
|
":core",
|
2023-05-17 07:58:19 -07:00
|
|
|
":dtypes",
|
2024-12-18 18:23:33 -08:00
|
|
|
":state_types",
|
2023-03-30 06:12:10 -07:00
|
|
|
":traceback_util",
|
|
|
|
":tree_util",
|
|
|
|
":util",
|
|
|
|
] + py_deps("numpy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-13 11:12:50 -07:00
|
|
|
name = "basearray",
|
|
|
|
srcs = ["_src/basearray.py"],
|
|
|
|
pytype_srcs = ["_src/basearray.pyi"],
|
|
|
|
deps = [
|
2025-01-16 10:50:30 -08:00
|
|
|
":partition_spec",
|
2023-03-13 11:12:50 -07:00
|
|
|
":sharding",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 13:09:20 -08:00
|
|
|
name = "cloud_tpu_init",
|
|
|
|
srcs = ["_src/cloud_tpu_init.py"],
|
2023-12-15 20:35:58 +00:00
|
|
|
deps = [
|
2024-05-28 13:42:18 -07:00
|
|
|
":config",
|
2023-12-15 20:35:58 +00:00
|
|
|
":hardware_utils",
|
2024-05-24 13:55:34 -07:00
|
|
|
":version",
|
2024-01-12 17:12:16 -08:00
|
|
|
],
|
2023-03-09 13:09:20 -08:00
|
|
|
)
|
|
|
|
|
2023-04-17 08:35:26 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "compilation_cache_internal",
|
|
|
|
srcs = ["_src/compilation_cache.py"],
|
|
|
|
visibility = [":internal"] + jax_visibility("compilation_cache"),
|
|
|
|
deps = [
|
2023-07-27 23:00:26 -07:00
|
|
|
":cache_key",
|
2023-04-17 08:35:26 -07:00
|
|
|
":compilation_cache_interface",
|
2023-04-19 13:26:24 -07:00
|
|
|
":config",
|
2024-05-30 17:59:05 +04:00
|
|
|
":lru_cache",
|
2024-01-08 14:02:54 -08:00
|
|
|
":monitoring",
|
2023-04-17 08:35:26 -07:00
|
|
|
":path",
|
|
|
|
"//jax/_src/lib",
|
2023-04-20 06:16:12 -07:00
|
|
|
] + py_deps("numpy") + py_deps("zstandard"),
|
2023-04-17 08:35:26 -07:00
|
|
|
)
|
|
|
|
|
2023-07-27 23:00:26 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "cache_key",
|
|
|
|
srcs = ["_src/cache_key.py"],
|
|
|
|
visibility = [":internal"] + jax_visibility("compilation_cache"),
|
|
|
|
deps = [
|
|
|
|
":config",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-04-17 08:35:26 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "compilation_cache_interface",
|
|
|
|
srcs = ["_src/compilation_cache_interface.py"],
|
2024-01-26 13:52:47 +00:00
|
|
|
deps = [
|
|
|
|
":path",
|
|
|
|
":util",
|
|
|
|
],
|
2023-04-17 08:35:26 -07:00
|
|
|
)
|
|
|
|
|
2024-05-30 17:59:05 +04:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "lru_cache",
|
|
|
|
srcs = ["_src/lru_cache.py"],
|
|
|
|
deps = [
|
|
|
|
":compilation_cache_interface",
|
2024-07-03 13:20:55 +00:00
|
|
|
":path",
|
2024-05-30 17:59:05 +04:00
|
|
|
] + py_deps("filelock"),
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 05:26:58 -08:00
|
|
|
name = "config",
|
|
|
|
srcs = ["_src/config.py"],
|
|
|
|
deps = [
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
":logging_config",
|
2023-03-09 05:26:58 -08:00
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "logging_config",
|
|
|
|
srcs = ["_src/logging_config.py"],
|
|
|
|
)
|
|
|
|
|
2023-08-15 06:38:56 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "compiler",
|
|
|
|
srcs = ["_src/compiler.py"],
|
|
|
|
deps = [
|
2024-11-26 04:05:35 -08:00
|
|
|
":cache_key",
|
2023-08-15 06:38:56 -07:00
|
|
|
":compilation_cache_internal",
|
|
|
|
":config",
|
2023-12-18 21:24:59 -08:00
|
|
|
":mlir",
|
2023-08-15 06:38:56 -07:00
|
|
|
":monitoring",
|
2024-07-29 16:13:01 -07:00
|
|
|
":path",
|
2023-08-15 06:38:56 -07:00
|
|
|
":profiler",
|
|
|
|
":traceback_util",
|
2024-01-11 23:37:22 -08:00
|
|
|
":xla_bridge",
|
2023-08-15 06:38:56 -07:00
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-03-28 18:30:36 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "core",
|
|
|
|
srcs = [
|
|
|
|
"_src/core.py",
|
|
|
|
"_src/errors.py",
|
|
|
|
"_src/linear_util.py",
|
|
|
|
],
|
|
|
|
deps = [
|
2024-05-17 16:31:23 -07:00
|
|
|
":compute_on",
|
2023-03-28 18:30:36 -07:00
|
|
|
":config",
|
2024-06-13 13:14:27 -07:00
|
|
|
":deprecations",
|
2023-05-17 07:58:19 -07:00
|
|
|
":dtypes",
|
2023-03-28 18:30:36 -07:00
|
|
|
":effects",
|
2024-11-20 13:06:39 -08:00
|
|
|
":mesh",
|
2024-12-10 18:02:42 -08:00
|
|
|
":partition_spec",
|
2023-03-28 18:30:36 -07:00
|
|
|
":pretty_printer",
|
|
|
|
":source_info_util",
|
|
|
|
":traceback_util",
|
|
|
|
":tree_util",
|
|
|
|
":typing",
|
|
|
|
":util",
|
2024-09-06 16:44:19 -07:00
|
|
|
":xla_metadata",
|
2023-03-28 18:30:36 -07:00
|
|
|
"//jax/_src/lib",
|
2023-05-17 07:58:19 -07:00
|
|
|
] + py_deps("numpy"),
|
2023-03-28 18:30:36 -07:00
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 12:25:25 -08:00
|
|
|
name = "custom_api_util",
|
|
|
|
srcs = ["_src/custom_api_util.py"],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 12:25:25 -08:00
|
|
|
name = "deprecations",
|
|
|
|
srcs = ["_src/deprecations.py"],
|
|
|
|
)
|
|
|
|
|
2023-05-17 07:58:19 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "dtypes",
|
|
|
|
srcs = [
|
|
|
|
"_src/dtypes.py",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":config",
|
|
|
|
":traceback_util",
|
|
|
|
":typing",
|
2023-10-17 15:07:28 -07:00
|
|
|
":util",
|
2025-02-06 14:55:57 +00:00
|
|
|
"//jax/_src/lib",
|
2023-05-17 07:58:19 -07:00
|
|
|
] + py_deps("ml_dtypes") + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 12:25:25 -08:00
|
|
|
name = "effects",
|
|
|
|
srcs = ["_src/effects.py"],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 12:25:25 -08:00
|
|
|
name = "environment_info",
|
|
|
|
srcs = ["_src/environment_info.py"],
|
|
|
|
deps = [
|
|
|
|
":version",
|
2023-03-29 08:23:18 -07:00
|
|
|
":xla_bridge",
|
2023-03-10 12:25:25 -08:00
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-12-15 20:35:58 +00:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "hardware_utils",
|
|
|
|
srcs = ["_src/hardware_utils.py"],
|
|
|
|
)
|
|
|
|
|
2023-03-30 06:12:10 -07:00
|
|
|
pytype_library(
|
|
|
|
name = "lax_reference",
|
|
|
|
srcs = ["_src/lax_reference.py"],
|
|
|
|
visibility = [":internal"] + jax_visibility("lax_reference"),
|
|
|
|
deps = [
|
|
|
|
":core",
|
|
|
|
":util",
|
|
|
|
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum"),
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 13:09:20 -08:00
|
|
|
name = "lazy_loader",
|
|
|
|
srcs = ["_src/lazy_loader.py"],
|
|
|
|
)
|
|
|
|
|
2023-08-08 10:08:19 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "jaxpr_util",
|
|
|
|
srcs = ["_src/jaxpr_util.py"],
|
|
|
|
deps = [
|
|
|
|
":core",
|
|
|
|
":source_info_util",
|
|
|
|
":util",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-13 08:31:16 -07:00
|
|
|
name = "mesh",
|
|
|
|
srcs = ["_src/mesh.py"],
|
|
|
|
deps = [
|
|
|
|
":config",
|
|
|
|
":util",
|
|
|
|
":xla_bridge",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-03-30 14:05:24 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "mlir",
|
|
|
|
srcs = ["_src/interpreters/mlir.py"],
|
|
|
|
deps = [
|
|
|
|
":ad_util",
|
2025-01-24 12:53:51 +02:00
|
|
|
":api_util",
|
2023-03-30 14:05:24 -07:00
|
|
|
":config",
|
|
|
|
":core",
|
2023-05-17 07:58:19 -07:00
|
|
|
":dtypes",
|
2023-03-30 14:05:24 -07:00
|
|
|
":effects",
|
2023-11-15 08:48:17 -08:00
|
|
|
":layout",
|
2023-04-06 08:31:47 -07:00
|
|
|
":op_shardings",
|
2023-03-30 14:05:24 -07:00
|
|
|
":partial_eval",
|
2024-12-10 18:02:42 -08:00
|
|
|
":partition_spec",
|
2023-12-18 21:24:59 -08:00
|
|
|
":path",
|
2023-06-09 13:34:09 -07:00
|
|
|
":pickle_util",
|
2024-06-05 09:06:36 -07:00
|
|
|
":sharding",
|
2023-04-10 10:15:08 -07:00
|
|
|
":sharding_impls",
|
2023-03-30 14:05:24 -07:00
|
|
|
":source_info_util",
|
2024-02-29 22:18:05 -08:00
|
|
|
":state_types",
|
2023-03-30 14:05:24 -07:00
|
|
|
":util",
|
|
|
|
":xla",
|
|
|
|
":xla_bridge",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 13:09:20 -08:00
|
|
|
name = "monitoring",
|
|
|
|
srcs = ["_src/monitoring.py"],
|
|
|
|
)
|
|
|
|
|
2023-04-06 08:31:47 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "op_shardings",
|
|
|
|
srcs = ["_src/op_shardings.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2025-01-29 15:01:07 -08:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "source_mapper",
|
|
|
|
srcs = glob(include = ["experimental/source_mapper/**/*.py"]),
|
|
|
|
visibility = [
|
|
|
|
"//visibility:public",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":config",
|
|
|
|
":core",
|
|
|
|
":jax",
|
|
|
|
":source_info_util",
|
|
|
|
] + py_deps("absl/flags"),
|
|
|
|
)
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas",
|
|
|
|
srcs = glob(
|
|
|
|
[
|
|
|
|
"experimental/pallas/**/*.py",
|
|
|
|
],
|
|
|
|
exclude = [
|
|
|
|
"experimental/pallas/gpu.py",
|
2024-10-07 04:04:16 -07:00
|
|
|
"experimental/pallas/mosaic_gpu.py",
|
2024-06-05 01:34:07 -07:00
|
|
|
"experimental/pallas/ops/gpu/**/*.py",
|
2024-01-11 14:41:21 -08:00
|
|
|
"experimental/pallas/ops/tpu/**/*.py",
|
2024-10-07 04:04:16 -07:00
|
|
|
"experimental/pallas/tpu.py",
|
|
|
|
"experimental/pallas/triton.py",
|
2023-08-01 16:42:26 -07:00
|
|
|
],
|
|
|
|
),
|
|
|
|
visibility = [
|
|
|
|
"//visibility:public",
|
|
|
|
],
|
|
|
|
deps = [
|
2024-07-09 10:42:24 -07:00
|
|
|
":deprecations",
|
2023-08-01 16:42:26 -07:00
|
|
|
":jax",
|
2023-10-17 09:05:07 -07:00
|
|
|
":source_info_util",
|
2023-08-01 16:42:26 -07:00
|
|
|
"//jax/_src/pallas",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_tpu",
|
|
|
|
srcs = ["experimental/pallas/tpu.py"],
|
|
|
|
visibility = [
|
|
|
|
":pallas_tpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
2024-06-10 05:58:35 -07:00
|
|
|
":pallas", # build_cleaner: keep
|
2023-09-18 10:53:10 -07:00
|
|
|
":tpu_custom_call",
|
2024-07-24 17:13:49 -07:00
|
|
|
"//jax/_src/pallas",
|
2024-06-10 05:58:35 -07:00
|
|
|
"//jax/_src/pallas/mosaic:core",
|
2025-01-15 18:34:15 -08:00
|
|
|
"//jax/_src/pallas/mosaic:helpers",
|
Start a new TPU interpret mode for Pallas.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.
The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to a Python functions that simulate these primitives.
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.
The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:
- Executing DMAs asynchronously.
- Padding in pallas_call.
- Propagating source info.
2024-11-22 10:49:17 -08:00
|
|
|
"//jax/_src/pallas/mosaic:interpret",
|
2024-06-10 05:58:35 -07:00
|
|
|
"//jax/_src/pallas/mosaic:lowering",
|
|
|
|
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
|
|
|
|
"//jax/_src/pallas/mosaic:pipeline",
|
|
|
|
"//jax/_src/pallas/mosaic:primitives",
|
2024-06-10 18:07:33 -07:00
|
|
|
"//jax/_src/pallas/mosaic:random",
|
2024-07-17 05:28:34 -07:00
|
|
|
"//jax/_src/pallas/mosaic:verification",
|
2023-08-01 16:42:26 -07:00
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2024-02-15 11:43:31 -08:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_gpu_ops",
|
2024-10-29 05:20:02 -07:00
|
|
|
srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"],
|
2024-02-15 11:43:31 -08:00
|
|
|
visibility = [
|
|
|
|
":pallas_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
":pallas",
|
|
|
|
":pallas_gpu",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2024-10-29 05:20:02 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_experimental_gpu_ops",
|
|
|
|
testonly = True,
|
|
|
|
srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"],
|
|
|
|
visibility = [
|
|
|
|
":mosaic_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
":mosaic_gpu",
|
|
|
|
":pallas",
|
|
|
|
":pallas_mosaic_gpu",
|
|
|
|
":test_util", # This is only to make them runnable as jax_multiplatform_test...
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-08-26 08:03:04 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_tpu_ops",
|
2024-01-11 14:41:21 -08:00
|
|
|
srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]),
|
2023-08-26 08:03:04 -07:00
|
|
|
visibility = [
|
|
|
|
":pallas_tpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
":pallas",
|
|
|
|
":pallas_tpu",
|
2024-01-11 14:41:21 -08:00
|
|
|
] + py_deps("numpy"),
|
2023-08-26 08:03:04 -07:00
|
|
|
)
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_gpu",
|
2024-10-07 05:43:14 -07:00
|
|
|
visibility = [
|
|
|
|
":pallas_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":pallas_triton",
|
|
|
|
# TODO(slebedev): Add :pallas_mosaic_gpu once it is ready.
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_triton",
|
2024-10-07 04:04:16 -07:00
|
|
|
srcs = [
|
|
|
|
"experimental/pallas/gpu.py",
|
|
|
|
"experimental/pallas/triton.py",
|
|
|
|
],
|
2023-08-01 16:42:26 -07:00
|
|
|
visibility = [
|
|
|
|
":pallas_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
2024-10-07 04:04:16 -07:00
|
|
|
":deprecations",
|
2024-09-04 13:31:35 -07:00
|
|
|
"//jax/_src/pallas/triton:core",
|
2024-06-10 05:58:35 -07:00
|
|
|
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
|
|
|
|
"//jax/_src/pallas/triton:primitives",
|
2024-05-30 01:45:31 -07:00
|
|
|
],
|
2023-08-01 16:42:26 -07:00
|
|
|
)
|
|
|
|
|
2024-10-07 04:04:16 -07:00
|
|
|
# Pallas Mosaic GPU backend: experimental/pallas/mosaic_gpu.py on top of the
# //jax/_src/pallas/mosaic_gpu implementation targets. Shares visibility with
# :mosaic_gpu rather than :pallas_gpu_users.
pytype_strict_library(
    name = "pallas_mosaic_gpu",
    srcs = ["experimental/pallas/mosaic_gpu.py"],
    visibility = [
        ":mosaic_gpu_users",
    ],
    deps = [
        "//jax/_src/pallas/mosaic_gpu:core",
        # Kept explicitly: build_cleaner must not drop this dep even if no
        # direct import of it is detected.
        "//jax/_src/pallas/mosaic_gpu:pallas_call_registration",  # build_cleaner: keep
        "//jax/_src/pallas/mosaic_gpu:pipeline",
        "//jax/_src/pallas/mosaic_gpu:primitives",
    ],
)
|
|
|
|
|
2024-04-18 04:03:03 -07:00
|
|
|
# This target only supports sm_90 GPUs.
|
|
|
|
# Mosaic GPU: all .py files directly in experimental/mosaic/gpu/ (non-recursive
# glob), plus the MLIR dialect bindings it builds on.
# Unlike most neighbouring targets this is a plain py_library, not a
# pytype_strict_library.
# NOTE(review): the //jaxlib/mlir deps are listed unconditionally here, whereas
# :tpu_custom_call wraps its //jaxlib/mlir deps in if_building_jaxlib(); confirm
# this difference is intended.
py_library(
    name = "mosaic_gpu",
    srcs = glob(["experimental/mosaic/gpu/*.py"]),
    visibility = [
        ":mosaic_gpu_users",
    ],
    deps = [
        ":config",
        ":core",
        ":jax",
        ":mlir",
        "//jax/_src/lib",
        "//jaxlib/mlir:arithmetic_dialect",
        "//jaxlib/mlir:builtin_dialect",
        "//jaxlib/mlir:func_dialect",
        "//jaxlib/mlir:gpu_dialect",
        "//jaxlib/mlir:ir",
        "//jaxlib/mlir:llvm_dialect",
        "//jaxlib/mlir:math_dialect",
        "//jaxlib/mlir:memref_dialect",
        "//jaxlib/mlir:nvgpu_dialect",
        "//jaxlib/mlir:nvvm_dialect",
        "//jaxlib/mlir:pass_manager",
        "//jaxlib/mlir:scf_dialect",
        "//jaxlib/mlir:vector_dialect",
    ] + py_deps("absl/flags") + py_deps("numpy"),
)
|
|
|
|
|
2023-03-30 06:12:10 -07:00
|
|
|
# Partial-evaluation interpreter (_src/interpreters/partial_eval.py).
pytype_strict_library(
    name = "partial_eval",
    srcs = ["_src/interpreters/partial_eval.py"],
    deps = [
        ":ad_util",
        ":api_util",
        ":compute_on",
        ":config",
        ":core",
        ":dtypes",
        ":effects",
        ":profiler",
        ":source_info_util",
        ":state_types",
        ":tree_util",
        ":util",
        ":xla_metadata",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-04-06 11:42:45 -07:00
|
|
|
# _src/partition_spec.py as a standalone target; it declares no deps.
pytype_strict_library(
    name = "partition_spec",
    srcs = ["_src/partition_spec.py"],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/path.py; its only dependency is the external `epath` package, pulled in
# through the py_deps() helper from //jaxlib:jax.bzl.
pytype_strict_library(
    name = "path",
    srcs = ["_src/path.py"],
    deps = py_deps("epath"),
)
|
|
|
|
|
2023-06-29 21:01:32 -07:00
|
|
|
# Publicly visible target for experimental/profiler.py.
pytype_library(
    name = "experimental_profiler",
    srcs = ["experimental/profiler.py"],
    visibility = ["//visibility:public"],
    deps = [
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2025-01-23 17:52:13 -08:00
|
|
|
# Target for experimental/transfer.py. No `visibility` attribute is set, so the
# package default visibility applies.
pytype_library(
    name = "experimental_transfer",
    srcs = ["experimental/transfer.py"],
    deps = [
        ":jax",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-06-09 13:34:09 -07:00
|
|
|
# _src/pickle_util.py; depends on :profiler and the external cloudpickle package.
pytype_strict_library(
    name = "pickle_util",
    srcs = ["_src/pickle_util.py"],
    deps = [":profiler"] + py_deps("cloudpickle"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/pretty_printer.py; uses the external colorama package.
pytype_strict_library(
    name = "pretty_printer",
    srcs = ["_src/pretty_printer.py"],
    deps = [
        ":config",
        ":util",
    ] + py_deps("colorama"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# Internal profiler module (_src/profiler.py). The user-facing
# experimental/profiler.py lives in :experimental_profiler.
pytype_strict_library(
    name = "profiler",
    srcs = ["_src/profiler.py"],
    deps = [
        ":traceback_util",
        ":xla_bridge",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/sharding.py. Concrete sharding implementations are in :sharding_impls,
# which depends on this target.
pytype_strict_library(
    name = "sharding",
    srcs = ["_src/sharding.py"],
    deps = [
        ":op_shardings",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2024-05-17 15:58:25 -07:00
|
|
|
# _src/compute_on.py.
pytype_strict_library(
    name = "compute_on",
    srcs = ["_src/compute_on.py"],
    deps = [
        ":config",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2024-09-06 16:44:19 -07:00
|
|
|
# _src/xla_metadata.py.
pytype_strict_library(
    name = "xla_metadata",
    srcs = ["_src/xla_metadata.py"],
    deps = [
        ":config",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-11-15 08:48:17 -08:00
|
|
|
# _src/layout.py; builds on both :sharding and :sharding_impls.
pytype_strict_library(
    name = "layout",
    srcs = ["_src/layout.py"],
    deps = [
        ":dtypes",
        ":sharding",
        ":sharding_impls",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-04-10 10:15:08 -07:00
|
|
|
# _src/sharding_impls.py: concrete sharding implementations on top of :sharding,
# :mesh, :partition_spec and :sharding_specs.
pytype_strict_library(
    name = "sharding_impls",
    srcs = ["_src/sharding_impls.py"],
    deps = [
        ":config",
        ":core",
        ":internal_mesh_utils",
        ":mesh",
        ":op_shardings",
        ":partition_spec",
        ":sharding",
        ":sharding_specs",
        ":source_info_util",
        ":tree_util",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-04-06 09:48:14 -07:00
|
|
|
# _src/sharding_specs.py.
pytype_strict_library(
    name = "sharding_specs",
    srcs = ["_src/sharding_specs.py"],
    deps = [
        ":config",
        ":op_shardings",
        ":util",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2024-09-03 14:30:37 -07:00
|
|
|
# Internal target for _src/mesh_utils.py. It is used by :sharding_impls and
# wrapped by the publicly visible :mesh_utils target (experimental/mesh_utils.py).
# Uses pytype_library, not pytype_strict_library.
pytype_library(
    name = "internal_mesh_utils",
    srcs = ["_src/mesh_utils.py"],
    deps = [
        ":xla_bridge",
    ],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/source_info_util.py. Visible to :internal plus whatever
# jax_visibility("source_info_util") (from //jaxlib:jax.bzl) adds.
pytype_strict_library(
    name = "source_info_util",
    srcs = ["_src/source_info_util.py"],
    visibility = [":internal"] + jax_visibility("source_info_util"),
    deps = [
        ":traceback_util",
        ":version",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-04-04 10:56:25 -07:00
|
|
|
# The _src/state package: its __init__ plus the indexing and types modules.
pytype_strict_library(
    name = "state_types",
    srcs = [
        "_src/state/__init__.py",
        "_src/state/indexing.py",
        "_src/state/types.py",
    ],
    deps = [
        ":core",
        ":dtypes",
        ":effects",
        ":pretty_printer",
        ":traceback_util",
        ":tree_util",
        ":typing",
        ":util",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/tree_util.py. Visible to :internal plus whatever
# jax_visibility("tree_util") (from //jaxlib:jax.bzl) adds.
pytype_strict_library(
    name = "tree_util",
    srcs = ["_src/tree_util.py"],
    visibility = [":internal"] + jax_visibility("tree_util"),
    deps = [
        ":traceback_util",
        ":util",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/traceback_util.py. Visible to :internal plus whatever
# jax_visibility("traceback_util") (from //jaxlib:jax.bzl) adds.
pytype_strict_library(
    name = "traceback_util",
    srcs = ["_src/traceback_util.py"],
    visibility = [":internal"] + jax_visibility("traceback_util"),
    deps = [
        ":config",
        ":util",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/typing.py; depends on :basearray and numpy.
pytype_strict_library(
    name = "typing",
    srcs = [
        "_src/typing.py",
    ],
    deps = [":basearray"] + py_deps("numpy"),
)
|
|
|
|
|
2023-07-20 18:28:18 -07:00
|
|
|
# _src/tpu_custom_call.py, restricted to :internal.
# The //jaxlib/mlir deps are wrapped in if_building_jaxlib() (a macro from
# //jaxlib:jax.bzl), so they are only added when that macro enables them,
# presumably when jaxlib is built from source. See jax.bzl to confirm.
pytype_strict_library(
    name = "tpu_custom_call",
    srcs = ["_src/tpu_custom_call.py"],
    visibility = [":internal"],
    deps = [
        ":config",
        ":core",
        ":jax",
        ":mlir",
        ":sharding_impls",
        "//jax/_src/lib",
        "//jax/_src/pallas",
    ] + if_building_jaxlib([
        "//jaxlib/mlir:ir",
        "//jaxlib/mlir:mhlo_dialect",
        "//jaxlib/mlir:pass_manager",
        "//jaxlib/mlir:stablehlo_dialect",
    ]) + py_deps("numpy") + py_deps("absl/flags"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/util.py; depends only on :config, //jax/_src/lib and numpy.
pytype_strict_library(
    name = "util",
    srcs = ["_src/util.py"],
    deps = [
        ":config",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-09 05:26:58 -08:00
|
|
|
# The package-level version.py (note: not under _src/); it declares no deps.
pytype_strict_library(
    name = "version",
    srcs = ["version.py"],
)
|
|
|
|
|
2023-03-30 14:05:24 -07:00
|
|
|
# XLA interpreter module (_src/interpreters/xla.py).
pytype_strict_library(
    name = "xla",
    srcs = ["_src/interpreters/xla.py"],
    deps = [
        ":abstract_arrays",
        ":config",
        ":core",
        ":dtypes",
        ":sharding_impls",
        ":source_info_util",
        ":typing",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-10 11:11:28 -08:00
|
|
|
# TODO(phawkins): break up this SCC.
|
2023-03-27 10:14:05 -07:00
|
|
|
# One target bundling _src/xla_bridge.py, _src/distributed.py and the
# _src/clusters/* cluster-detection modules (kept together per the SCC TODO
# above this rule).
pytype_strict_library(
    name = "xla_bridge",
    srcs = [
        "_src/clusters/__init__.py",
        "_src/clusters/cloud_tpu_cluster.py",
        "_src/clusters/cluster.py",
        "_src/clusters/k8s_cluster.py",
        "_src/clusters/mpi4py_cluster.py",
        "_src/clusters/ompi_cluster.py",
        "_src/clusters/slurm_cluster.py",
        "_src/distributed.py",
        "_src/xla_bridge.py",
    ],
    visibility = [":internal"] + jax_visibility("xla_bridge"),
    deps = [
        ":cloud_tpu_init",
        ":config",
        ":hardware_utils",
        ":traceback_util",
        ":util",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-03-09 08:50:49 -08:00
|
|
|
# Public JAX libraries below this point.
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Catch-all public target for the top-level experimental/ modules and the
# example_libraries/ modules (both globs are non-recursive).
# NOTE(review): these globs also match files that are srcs of finer-grained
# targets in this package (e.g. experimental/ode.py, experimental/jet.py,
# experimental/pjit.py, experimental/mesh_utils.py,
# example_libraries/stax.py, example_libraries/optimizers.py), so those
# files belong to two targets. Confirm that overlap is intended.
py_library_providing_imports_info(
    name = "experimental",
    srcs = glob(
        [
            "experimental/*.py",
            "example_libraries/*.py",
        ],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":jax",
    ] + py_deps("absl/logging") + py_deps("numpy"),
)
|
|
|
|
|
|
|
|
# Public target for example_libraries/stax.py.
pytype_library(
    name = "stax",
    srcs = [
        "example_libraries/stax.py",
    ],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)
|
|
|
|
|
|
|
|
# Public target for experimental/sparse/*.py (non-recursive). test_util.py is
# excluded here; it lives in the testonly :sparse_test_util target.
pytype_library(
    name = "experimental_sparse",
    srcs = glob(
        [
            "experimental/sparse/*.py",
        ],
        exclude = ["experimental/sparse/test_util.py"],
    ),
    visibility = ["//visibility:public"],
    deps = [":jax"],
)
|
|
|
|
|
2022-11-16 09:58:06 -08:00
|
|
|
# Test helpers for jax.experimental.sparse (experimental/sparse/test_util.py).
# Marked testonly, so only testonly targets and tests may depend on it.
pytype_library(
    name = "sparse_test_util",
    # Boolean attribute: written as True (not 1) to match the rest of this file.
    testonly = True,
    srcs = [
        "experimental/sparse/test_util.py",
    ],
    visibility = [":internal"],
    deps = [
        ":experimental_sparse",
        ":jax",
        ":test_util",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Public target for example_libraries/optimizers.py.
pytype_library(
    name = "optimizers",
    srcs = [
        "example_libraries/optimizers.py",
    ],
    visibility = ["//visibility:public"],
    deps = [":jax"] + py_deps("numpy"),
)
|
|
|
|
|
|
|
|
# Public target for experimental/ode.py.
pytype_library(
    name = "ode",
    srcs = ["experimental/ode.py"],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)
|
|
|
|
|
|
|
|
# TODO(apaszke): Remove this target
|
|
|
|
# Public target for experimental/pjit.py (slated for removal per the TODO above
# this rule). Depends on the broad :experimental target as well as :jax.
pytype_library(
    name = "pjit",
    srcs = ["experimental/pjit.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":experimental",
        ":jax",
    ],
)
|
|
|
|
|
|
|
|
# Public target for experimental/jet.py.
pytype_library(
    name = "jet",
    srcs = ["experimental/jet.py"],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)
|
|
|
|
|
|
|
|
# Public target for experimental/host_callback.py. experimental/__init__.py and
# experimental/x64_context.py are bundled in as well (see inline comments).
pytype_library(
    name = "experimental_host_callback",
    srcs = [
        "experimental/__init__.py",  # To support JAX_HOST_CALLBACK_LEGACY=False
        "experimental/host_callback.py",
        "experimental/x64_context.py",  # To support JAX_HOST_CALLBACK_LEGACY=False
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":jax",
    ],
)
|
|
|
|
|
|
|
|
# Public target for experimental/compilation_cache/compilation_cache.py.
pytype_library(
    name = "compilation_cache",
    srcs = [
        "experimental/compilation_cache/compilation_cache.py",
    ],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)
|
|
|
|
|
|
|
|
# Public target for experimental/mesh_utils.py; its only dep is the internal
# :internal_mesh_utils target (_src/mesh_utils.py).
pytype_library(
    name = "mesh_utils",
    srcs = ["experimental/mesh_utils.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":internal_mesh_utils",
    ],
)
|
2022-11-28 14:31:10 -08:00
|
|
|
|
2023-12-04 17:48:17 -08:00
|
|
|
# TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/...
|
|
|
|
# Aggregate target with no srcs: it just bundles the //jax/extend sub-targets
# for :jax_extend_users.
pytype_strict_library(
    name = "extend",
    visibility = [":jax_extend_users"],
    deps = [
        "//jax/extend",
        "//jax/extend:backend",
        "//jax/extend:core",
        "//jax/extend:linear_util",
        "//jax/extend:random",
        "//jax/extend:source_info_util",
    ],
)
|
|
|
|
|
2023-07-20 18:28:18 -07:00
|
|
|
# The experimental/mosaic package (its __init__ and dialects module), built on
# :tpu_custom_call. The GPU side lives in the separate :mosaic_gpu target.
pytype_library(
    name = "mosaic",
    srcs = [
        "experimental/mosaic/__init__.py",
        "experimental/mosaic/dialects.py",
    ],
    visibility = [":mosaic_users"],
    deps = [
        ":tpu_custom_call",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2022-11-28 14:31:10 -08:00
|
|
|
# Public target for experimental/rnn.py.
pytype_library(
    name = "rnn",
    srcs = ["experimental/rnn.py"],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)
|
2024-10-28 12:17:34 -07:00
|
|
|
|
|
|
|
# Public target for the experimental/colocated_python package. Pulls in IFRT
# programs from //jax/extend and the external cloudpickle package.
pytype_library(
    name = "experimental_colocated_python",
    srcs = [
        "experimental/colocated_python/__init__.py",
        "experimental/colocated_python/api.py",
        "experimental/colocated_python/func.py",
        "experimental/colocated_python/func_backend.py",
        "experimental/colocated_python/serialization.py",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":api_util",
        ":jax",
        ":traceback_util",
        ":tree_util",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
        "//jax/extend:ifrt_programs",
    ] + py_deps("numpy") + py_deps("cloudpickle"),
)
|