2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2022-07-01 15:06:54 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
# JAX is Autograd and XLA
|
|
|
|
|
2025-03-11 08:29:45 -07:00
|
|
|
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
|
2024-03-27 10:26:38 -07:00
|
|
|
load("@rules_python//python:defs.bzl", "py_library")
|
2022-07-01 15:06:54 -07:00
|
|
|
load(
|
|
|
|
"//jaxlib:jax.bzl",
|
2023-12-05 08:04:41 -08:00
|
|
|
"if_building_jaxlib",
|
2025-01-06 10:56:57 -08:00
|
|
|
"jax_export_file_visibility",
|
2023-08-24 14:40:10 -07:00
|
|
|
"jax_extend_internal_users",
|
2022-07-01 15:06:54 -07:00
|
|
|
"jax_extra_deps",
|
2023-11-17 02:04:49 -08:00
|
|
|
"jax_internal_export_back_compat_test_util_visibility",
|
2022-07-01 15:06:54 -07:00
|
|
|
"jax_internal_packages",
|
2023-11-09 13:57:30 -08:00
|
|
|
"jax_internal_test_harnesses_visibility",
|
2022-07-01 15:06:54 -07:00
|
|
|
"jax_test_util_visibility",
|
2023-03-10 08:40:28 -08:00
|
|
|
"jax_visibility",
|
2024-05-17 04:29:43 -07:00
|
|
|
"mosaic_gpu_internal_users",
|
2023-07-20 18:28:18 -07:00
|
|
|
"mosaic_internal_users",
|
2025-03-03 17:25:59 -08:00
|
|
|
"pallas_fuser_users",
|
2023-08-01 16:42:26 -07:00
|
|
|
"pallas_gpu_internal_users",
|
|
|
|
"pallas_tpu_internal_users",
|
2022-08-05 07:48:40 -07:00
|
|
|
"py_deps",
|
2022-07-01 15:06:54 -07:00
|
|
|
"py_library_providing_imports_info",
|
|
|
|
"pytype_library",
|
2023-03-09 05:26:58 -08:00
|
|
|
"pytype_strict_library",
|
2025-03-14 13:53:57 -07:00
|
|
|
"serialize_executable_internal_users",
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2023-04-19 13:26:24 -07:00
|
|
|
# Package defaults: no applicable licenses are attached by default, and targets
# that do not set `visibility` themselves are visible only to :internal.
package(
    default_applicable_licenses = [],
    default_visibility = [":internal"],
)
|
2022-07-01 15:06:54 -07:00
|
|
|
|
2023-02-01 21:25:46 +00:00
|
|
|
# License class for this package (the file header declares Apache License 2.0).
licenses(["notice"])
|
|
|
|
|
2025-03-11 08:29:45 -07:00
|
|
|
# Controls how jaxlib is provided to the build:
#   "true"  - jaxlib is built by Bazel (the default).
#   "false" - jaxlib is not built; a pre-built jaxlib wheel is assumed to be
#             available in the "dist" folder.
#   "wheel" - the jaxlib wheel is built and passed as a py_import rule
#             attribute; py_import unpacks the wheel and provides its contents
#             as a py_library.
string_flag(
    name = "build_jaxlib",
    build_setting_default = "true",
    values = ["true", "false", "wheel"],
)
|
|
|
|
|
|
|
|
# Matches when the :build_jaxlib flag is set to "true".
config_setting(
    name = "enable_jaxlib_build",
    flag_values = {":build_jaxlib": "true"},
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Plain files other packages may reference directly by label.
exports_files(["LICENSE", "version.py", "py.typed"])
|
|
|
|
|
2025-01-06 10:56:57 -08:00
|
|
|
# The export serialization schema is exported with its own visibility list
# (jax_export_file_visibility, loaded from //jaxlib:jax.bzl).
exports_files(
    ["_src/export/serialization.fbs"],
    visibility = jax_export_file_visibility,
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Packages that may use JAX-internal implementation details: every package in
# this repository plus the extra packages listed in jax_internal_packages.
package_group(
    name = "internal",
    packages = ["//..."] + jax_internal_packages,
)
|
|
|
|
|
2023-08-24 14:40:10 -07:00
|
|
|
# Packages allowed to use jax.extend: everything in :internal, the tests tree,
# and jax_extend_internal_users.
package_group(
    name = "jax_extend_users",
    includes = [":internal"],
    packages = [
        # Intentionally avoid jax dependencies on jax.extend.
        # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html
        "//tests/...",
    ] + jax_extend_internal_users,
)
|
|
|
|
|
2023-07-20 18:28:18 -07:00
|
|
|
# :internal plus the packages listed in mosaic_internal_users.
package_group(
    name = "mosaic_users",
    includes = [":internal"],
    packages = mosaic_internal_users,
)
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
# :internal plus the packages listed in pallas_gpu_internal_users.
package_group(
    name = "pallas_gpu_users",
    includes = [":internal"],
    packages = pallas_gpu_internal_users,
)
|
|
|
|
|
|
|
|
# :internal plus the packages listed in pallas_tpu_internal_users.
package_group(
    name = "pallas_tpu_users",
    includes = [":internal"],
    packages = pallas_tpu_internal_users,
)
|
|
|
|
|
2025-03-03 17:25:59 -08:00
|
|
|
# :internal plus the packages listed in pallas_fuser_users (from jax.bzl).
package_group(
    name = "pallas_fuser_users",
    includes = [":internal"],
    packages = pallas_fuser_users,
)
|
|
|
|
|
2024-04-18 04:03:03 -07:00
|
|
|
# :internal plus the packages listed in mosaic_gpu_internal_users.
package_group(
    name = "mosaic_gpu_users",
    includes = [":internal"],
    packages = mosaic_gpu_internal_users,
)
|
|
|
|
|
2025-03-14 13:53:57 -07:00
|
|
|
# :internal plus the packages listed in serialize_executable_internal_users.
package_group(
    name = "serialize_executable_users",
    includes = [":internal"],
    packages = serialize_executable_internal_users,
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# JAX-private test utilities.
#
# This target is required in order to use the private test utilities in
# jax._src.test_util; its visibility is intentionally restricted to discourage
# use outside JAX itself. JAX's public test utilities (jax/test_util.py) are
# available as jax.test_util through the standard :jax target instead.
py_library(
    name = "test_util",
    srcs = [
        "_src/test_util.py",
        "_src/test_warning_util.py",
    ],
    visibility = [":internal"] + jax_test_util_visibility,
    deps = [
        ":compilation_cache_internal",
        ":jax",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
|
|
|
|
2023-11-17 02:04:49 -08:00
|
|
|
# Internal-only test helpers under _src/internal_test_util (the lazy-loader
# test modules are picked up by glob).
# TODO(necula): break the internal_test_util into smaller build targets.
py_library(
    name = "internal_test_util",
    srcs = [
        "_src/internal_test_util/__init__.py",
        "_src/internal_test_util/deprecation_module.py",
        "_src/internal_test_util/lax_test_util.py",
    ] + glob(["_src/internal_test_util/lazy_loader_module/*.py"]),
    visibility = [":internal"],
    deps = [":jax"] + py_deps("numpy"),
)
|
|
|
|
|
2023-11-09 13:57:30 -08:00
|
|
|
# _src/internal_test_util/test_harnesses.py, visible to :internal plus
# jax_internal_test_harnesses_visibility.
py_library(
    name = "internal_test_harnesses",
    srcs = ["_src/internal_test_util/test_harnesses.py"],
    visibility = [":internal"] + jax_internal_test_harnesses_visibility,
    deps = [
        ":ad_util",
        ":config",
        ":jax",
        ":test_util",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-11-17 02:04:49 -08:00
|
|
|
# _src/internal_test_util/export_back_compat_test_util.py, visible to :internal
# plus jax_internal_export_back_compat_test_util_visibility.
py_library(
    name = "internal_export_back_compat_test_util",
    srcs = ["_src/internal_test_util/export_back_compat_test_util.py"],
    visibility = [":internal"] + jax_internal_export_back_compat_test_util_visibility,
    deps = [
        ":jax",
        ":test_util",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-12-18 09:49:11 -08:00
|
|
|
# Python data modules under _src/internal_test_util/export_back_compat_test_data
# (including its pallas/ subdirectory). Marked testonly, so only testonly
# targets and tests may depend on it.
py_library(
    name = "internal_export_back_compat_test_data",
    # `testonly` is a Boolean attribute; use the Boolean literal rather than 1.
    testonly = True,
    srcs = glob([
        "_src/internal_test_util/export_back_compat_test_data/*.py",
        "_src/internal_test_util/export_back_compat_test_data/pallas/*.py",
    ]),
    visibility = [":internal"],
    deps = py_deps("numpy"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# The main, publicly visible JAX library.
#
# srcs = the explicitly listed _src modules
#      + the globbed package trees (minus *_test.py and _src/internal_test_util)
#      + a handful of jax.experimental modules.
# The library itself is built with pytype_library (lib_rule). pytype stubs
# (.pyi) are attached via pytype_srcs, except _src/basearray.pyi, which
# belongs to :basearray.
py_library_providing_imports_info(
    name = "jax",
    srcs = [
        "_src/__init__.py",
        "_src/ad_checkpoint.py",
        "_src/api.py",
        "_src/array.py",
        "_src/blocked_sampler.py",
        "_src/callback.py",
        "_src/checkify.py",
        "_src/custom_batching.py",
        "_src/custom_dce.py",
        "_src/custom_derivatives.py",
        "_src/custom_partitioning.py",
        "_src/custom_partitioning_sharding_rule.py",
        "_src/custom_transpose.py",
        "_src/debugging.py",
        "_src/dispatch.py",
        "_src/dlpack.py",
        "_src/earray.py",
        "_src/error_check.py",
        "_src/ffi.py",
        "_src/flatten_util.py",
        "_src/interpreters/__init__.py",
        "_src/interpreters/ad.py",
        "_src/interpreters/batching.py",
        "_src/interpreters/pxla.py",
        "_src/pjit.py",
        "_src/prng.py",
        "_src/public_test_util.py",
        "_src/random.py",
        "_src/shard_alike.py",
        "_src/sourcemap.py",
        "_src/stages.py",
        "_src/tree.py",
    ] + glob(
        [
            "*.py",
            "_src/cudnn/**/*.py",
            "_src/debugger/**/*.py",
            "_src/export/**/*.py",
            "_src/extend/**/*.py",
            "_src/image/**/*.py",
            "_src/lax/**/*.py",
            "_src/nn/**/*.py",
            "_src/numpy/**/*.py",
            "_src/ops/**/*.py",
            "_src/scipy/**/*.py",
            "_src/state/**/*.py",
            "_src/third_party/**/*.py",
            "experimental/key_reuse/**/*.py",
            "experimental/roofline/**/*.py",
            "image/**/*.py",
            "interpreters/**/*.py",
            "lax/**/*.py",
            "lib/**/*.py",
            "nn/**/*.py",
            "numpy/**/*.py",
            "ops/**/*.py",
            "scipy/**/*.py",
            "third_party/**/*.py",
        ],
        exclude = [
            "*_test.py",
            "**/*_test.py",
            "_src/internal_test_util/**",
        ],
    ) + [
        "experimental/attrs.py",
        # until checkify is moved out of experimental
        "experimental/checkify.py",
        "experimental/compilation_cache/compilation_cache.py",
        "experimental/multihost_utils.py",
        "experimental/pjit.py",
        "experimental/shard_map.py",
    ],
    lib_rule = pytype_library,
    pytype_srcs = glob(
        [
            "_src/**/*.pyi",
            "numpy/*.pyi",
        ],
        exclude = ["_src/basearray.pyi"],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":abstract_arrays",
        ":ad_util",
        ":api_util",
        ":basearray",
        ":cloud_tpu_init",
        ":compilation_cache_internal",
        ":compiler",
        ":compute_on",
        ":config",
        ":core",
        ":custom_api_util",
        ":deprecations",
        ":dtypes",
        ":effects",
        ":environment_info",
        ":internal_mesh_utils",
        ":jaxpr_util",
        ":layout",
        ":lazy_loader",
        ":mesh",
        ":mlir",
        ":monitoring",
        ":named_sharding",
        ":op_shardings",
        ":partial_eval",
        ":partition_spec",
        ":path",
        ":pickle_util",
        ":pretty_printer",
        ":profiler",
        ":sharding",
        ":sharding_impls",
        ":sharding_specs",
        ":source_info_util",
        ":traceback_util",
        ":tree_util",
        ":typing",
        ":util",
        ":version",
        ":xla",
        ":xla_bridge",
        ":xla_metadata",
        "//jax/_src/lib",
    ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
)
|
|
|
|
|
|
|
|
# _src/abstract_arrays.py.
pytype_strict_library(
    name = "abstract_arrays",
    srcs = ["_src/abstract_arrays.py"],
    deps = [
        ":ad_util",
        ":core",
        ":dtypes",
        ":traceback_util",
    ] + py_deps("numpy"),
)
|
|
|
|
|
|
|
|
# _src/ad_util.py; depends only on other targets in this package.
pytype_strict_library(
    name = "ad_util",
    srcs = ["_src/ad_util.py"],
    deps = [
        ":core",
        ":traceback_util",
        ":tree_util",
        ":typing",
        ":util",
    ],
)
|
|
|
|
|
|
|
|
# _src/api_util.py.
pytype_strict_library(
    name = "api_util",
    srcs = ["_src/api_util.py"],
    deps = [
        ":abstract_arrays",
        ":config",
        ":core",
        ":dtypes",
        ":state_types",
        ":traceback_util",
        ":tree_util",
        ":util",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/basearray.py with its hand-written stub _src/basearray.pyi (the :jax
# target excludes this stub from its own pytype_srcs).
pytype_strict_library(
    name = "basearray",
    srcs = ["_src/basearray.py"],
    pytype_srcs = ["_src/basearray.pyi"],
    deps = [
        ":partition_spec",
        ":sharding",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/cloud_tpu_init.py.
pytype_strict_library(
    name = "cloud_tpu_init",
    srcs = ["_src/cloud_tpu_init.py"],
    deps = [
        ":config",
        ":hardware_utils",
        ":version",
    ],
)
|
|
|
|
|
2023-04-17 08:35:26 -07:00
|
|
|
# _src/compilation_cache.py, visible to :internal plus
# jax_visibility("compilation_cache").
pytype_strict_library(
    name = "compilation_cache_internal",
    srcs = ["_src/compilation_cache.py"],
    visibility = [":internal"] + jax_visibility("compilation_cache"),
    deps = [
        ":cache_key",
        ":compilation_cache_interface",
        ":config",
        ":lru_cache",
        ":monitoring",
        ":path",
        "//jax/_src/lib",
    ] + py_deps("numpy") + py_deps("zstandard"),
)
|
|
|
|
|
2023-07-27 23:00:26 -07:00
|
|
|
# _src/cache_key.py; shares the compilation-cache visibility list.
pytype_strict_library(
    name = "cache_key",
    srcs = ["_src/cache_key.py"],
    visibility = [":internal"] + jax_visibility("compilation_cache"),
    deps = [
        ":config",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-04-17 08:35:26 -07:00
|
|
|
# _src/compilation_cache_interface.py.
pytype_strict_library(
    name = "compilation_cache_interface",
    srcs = ["_src/compilation_cache_interface.py"],
    deps = [
        ":path",
        ":util",
    ],
)
|
|
|
|
|
2024-05-30 17:59:05 +04:00
|
|
|
# _src/lru_cache.py; uses the third-party filelock package.
pytype_strict_library(
    name = "lru_cache",
    srcs = ["_src/lru_cache.py"],
    deps = [
        ":compilation_cache_interface",
        ":path",
    ] + py_deps("filelock"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/config.py; depends on :logging_config and the jaxlib bindings.
pytype_strict_library(
    name = "config",
    srcs = ["_src/config.py"],
    deps = [
        ":logging_config",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
# _src/logging_config.py; no BUILD-level deps.
pytype_strict_library(
    name = "logging_config",
    srcs = ["_src/logging_config.py"],
)
|
|
|
|
|
2023-08-15 06:38:56 -07:00
|
|
|
# _src/compiler.py.
pytype_strict_library(
    name = "compiler",
    srcs = ["_src/compiler.py"],
    deps = [
        ":cache_key",
        ":compilation_cache_internal",
        ":config",
        ":mlir",
        ":monitoring",
        ":path",
        ":profiler",
        ":traceback_util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-28 18:30:36 -07:00
|
|
|
# _src/core.py together with _src/errors.py and _src/linear_util.py, built as a
# single target.
pytype_strict_library(
    name = "core",
    srcs = [
        "_src/core.py",
        "_src/errors.py",
        "_src/linear_util.py",
    ],
    deps = [
        ":compute_on",
        ":config",
        ":deprecations",
        ":dtypes",
        ":effects",
        ":mesh",
        ":named_sharding",
        ":partition_spec",
        ":pretty_printer",
        ":source_info_util",
        ":traceback_util",
        ":tree_util",
        ":typing",
        ":util",
        ":xla_metadata",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/custom_api_util.py; no BUILD-level deps.
pytype_strict_library(
    name = "custom_api_util",
    srcs = ["_src/custom_api_util.py"],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/deprecations.py; no BUILD-level deps.
pytype_strict_library(
    name = "deprecations",
    srcs = ["_src/deprecations.py"],
)
|
|
|
|
|
2023-05-17 07:58:19 -07:00
|
|
|
# _src/dtypes.py; uses the third-party ml_dtypes and numpy packages.
pytype_strict_library(
    name = "dtypes",
    srcs = ["_src/dtypes.py"],
    deps = [
        ":config",
        ":traceback_util",
        ":typing",
        ":util",
        "//jax/_src/lib",
    ] + py_deps("ml_dtypes") + py_deps("numpy"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/effects.py; no BUILD-level deps.
pytype_strict_library(
    name = "effects",
    srcs = ["_src/effects.py"],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/environment_info.py.
pytype_strict_library(
    name = "environment_info",
    srcs = ["_src/environment_info.py"],
    deps = [
        ":version",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-12-15 20:35:58 +00:00
|
|
|
# _src/hardware_utils.py; no BUILD-level deps.
pytype_strict_library(
    name = "hardware_utils",
    srcs = ["_src/hardware_utils.py"],
)
|
|
|
|
|
2023-03-30 06:12:10 -07:00
|
|
|
# _src/lax_reference.py, built with the non-strict pytype_library rule and
# visible to :internal plus jax_visibility("lax_reference").
pytype_library(
    name = "lax_reference",
    srcs = ["_src/lax_reference.py"],
    visibility = [":internal"] + jax_visibility("lax_reference"),
    deps = [
        ":core",
        ":util",
    ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/lazy_loader.py; no BUILD-level deps.
pytype_strict_library(
    name = "lazy_loader",
    srcs = ["_src/lazy_loader.py"],
)
|
|
|
|
|
2023-08-08 10:08:19 -07:00
|
|
|
# _src/jaxpr_util.py.
pytype_strict_library(
    name = "jaxpr_util",
    srcs = ["_src/jaxpr_util.py"],
    deps = [
        ":core",
        ":source_info_util",
        ":util",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/mesh.py.
pytype_strict_library(
    name = "mesh",
    srcs = ["_src/mesh.py"],
    deps = [
        ":config",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-30 14:05:24 -07:00
|
|
|
# _src/interpreters/mlir.py as a standalone target (the :jax target depends on
# it rather than globbing it).
pytype_strict_library(
    name = "mlir",
    srcs = ["_src/interpreters/mlir.py"],
    deps = [
        ":ad_util",
        ":api_util",
        ":config",
        ":core",
        ":dtypes",
        ":effects",
        ":layout",
        ":op_shardings",
        ":partial_eval",
        ":partition_spec",
        ":path",
        ":pickle_util",
        ":sharding",
        ":sharding_impls",
        ":source_info_util",
        ":state_types",
        ":util",
        ":xla",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
# _src/monitoring.py; no BUILD-level deps.
pytype_strict_library(
    name = "monitoring",
    srcs = ["_src/monitoring.py"],
)
|
|
|
|
|
2023-04-06 08:31:47 -07:00
|
|
|
# _src/op_shardings.py.
pytype_strict_library(
    name = "op_shardings",
    srcs = ["_src/op_shardings.py"],
    deps = ["//jax/_src/lib"] + py_deps("numpy"),
)
|
|
|
|
|
2025-03-14 13:53:57 -07:00
|
|
|
# experimental/serialize_executable.py, restricted to :serialize_executable_users.
pytype_strict_library(
    name = "serialize_executable",
    srcs = ["experimental/serialize_executable.py"],
    visibility = [":serialize_executable_users"],
    deps = [
        ":jax",
        "//jax/_src/lib",
    ],
)
|
|
|
|
|
2025-01-29 15:01:07 -08:00
|
|
|
# All Python files under experimental/source_mapper/, publicly visible.
pytype_strict_library(
    name = "source_mapper",
    srcs = glob(include = ["experimental/source_mapper/**/*.py"]),
    visibility = ["//visibility:public"],
    deps = [
        ":config",
        ":core",
        ":jax",
        ":source_info_util",
    ] + py_deps("absl/flags"),
)
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
# Backend-independent part of jax.experimental.pallas. The backend-specific
# entry points and op libraries excluded here are built by the dedicated
# targets :pallas_fuser, :pallas_triton, :pallas_tpu and :pallas_tpu_ops
# (mosaic_gpu.py and ops/gpu are likewise not part of this target).
pytype_strict_library(
    name = "pallas",
    srcs = glob(
        ["experimental/pallas/**/*.py"],
        exclude = [
            "experimental/pallas/fuser.py",
            "experimental/pallas/gpu.py",
            "experimental/pallas/mosaic_gpu.py",
            "experimental/pallas/ops/gpu/**/*.py",
            "experimental/pallas/ops/tpu/**/*.py",
            "experimental/pallas/tpu.py",
            "experimental/pallas/triton.py",
        ],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":deprecations",
        ":jax",
        ":source_info_util",
        "//jax/_src/pallas",
    ] + py_deps("numpy"),
)
|
|
|
|
|
|
|
|
# experimental/pallas/tpu.py plus the Mosaic (TPU) Pallas implementation
# targets, restricted to :pallas_tpu_users.
pytype_strict_library(
    name = "pallas_tpu",
    srcs = ["experimental/pallas/tpu.py"],
    visibility = [":pallas_tpu_users"],
    deps = [
        ":pallas",  # build_cleaner: keep
        ":tpu_custom_call",
        "//jax/_src/pallas",
        "//jax/_src/pallas/mosaic:core",
        "//jax/_src/pallas/mosaic:helpers",
        "//jax/_src/pallas/mosaic:interpret",
        "//jax/_src/pallas/mosaic:lowering",
        "//jax/_src/pallas/mosaic:pallas_call_registration",  # build_cleaner: keep
        "//jax/_src/pallas/mosaic:pipeline",
        "//jax/_src/pallas/mosaic:primitives",
        "//jax/_src/pallas/mosaic:random",
        "//jax/_src/pallas/mosaic:verification",
    ],
)
|
|
|
|
|
2025-03-03 17:25:59 -08:00
|
|
|
# experimental/pallas/fuser.py plus the //jax/_src/pallas/fuser targets,
# restricted to :pallas_fuser_users.
pytype_strict_library(
    name = "pallas_fuser",
    srcs = ["experimental/pallas/fuser.py"],
    visibility = [":pallas_fuser_users"],
    deps = [
        ":pallas",  # build_cleaner: keep
        "//jax/_src/pallas/fuser:block_spec",
        "//jax/_src/pallas/fuser:custom_evaluate",
        "//jax/_src/pallas/fuser:fusable",
        "//jax/_src/pallas/fuser:fusion",
        "//jax/_src/pallas/fuser:jaxpr_fusion",
    ],
)
|
|
|
|
|
2024-02-15 11:43:31 -08:00
|
|
|
# GPU op sources provided by //jax/experimental/pallas/ops/gpu:triton_ops,
# restricted to :pallas_gpu_users.
pytype_strict_library(
    name = "pallas_gpu_ops",
    srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"],
    visibility = [":pallas_gpu_users"],
    deps = [
        ":jax",
        ":pallas",
        ":pallas_gpu",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2024-10-29 05:20:02 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_experimental_gpu_ops",
|
|
|
|
srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"],
|
|
|
|
visibility = [
|
|
|
|
":mosaic_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
":mosaic_gpu",
|
|
|
|
":pallas",
|
|
|
|
":pallas_mosaic_gpu",
|
|
|
|
":test_util", # This is only to make them runnable as jax_multiplatform_test...
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-08-26 08:03:04 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_tpu_ops",
|
2024-01-11 14:41:21 -08:00
|
|
|
srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]),
|
2023-08-26 08:03:04 -07:00
|
|
|
visibility = [
|
|
|
|
":pallas_tpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
":pallas",
|
|
|
|
":pallas_tpu",
|
2024-01-11 14:41:21 -08:00
|
|
|
] + py_deps("numpy"),
|
2023-08-26 08:03:04 -07:00
|
|
|
)
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_gpu",
|
2024-10-07 05:43:14 -07:00
|
|
|
visibility = [
|
|
|
|
":pallas_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":pallas_triton",
|
|
|
|
# TODO(slebedev): Add :pallas_mosaic_gpu once it is ready.
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_triton",
|
2024-10-07 04:04:16 -07:00
|
|
|
srcs = [
|
|
|
|
"experimental/pallas/gpu.py",
|
|
|
|
"experimental/pallas/triton.py",
|
|
|
|
],
|
2023-08-01 16:42:26 -07:00
|
|
|
visibility = [
|
|
|
|
":pallas_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
2024-10-07 04:04:16 -07:00
|
|
|
":deprecations",
|
2024-09-04 13:31:35 -07:00
|
|
|
"//jax/_src/pallas/triton:core",
|
2024-06-10 05:58:35 -07:00
|
|
|
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
|
|
|
|
"//jax/_src/pallas/triton:primitives",
|
2024-05-30 01:45:31 -07:00
|
|
|
],
|
2023-08-01 16:42:26 -07:00
|
|
|
)
|
|
|
|
|
2024-10-07 04:04:16 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pallas_mosaic_gpu",
|
|
|
|
srcs = ["experimental/pallas/mosaic_gpu.py"],
|
|
|
|
visibility = [
|
|
|
|
":mosaic_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
2025-02-20 15:05:51 -08:00
|
|
|
":mosaic_gpu",
|
2024-10-07 04:04:16 -07:00
|
|
|
"//jax/_src/pallas/mosaic_gpu:core",
|
|
|
|
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
|
2024-10-28 08:25:08 -07:00
|
|
|
"//jax/_src/pallas/mosaic_gpu:pipeline",
|
2024-10-07 04:04:16 -07:00
|
|
|
"//jax/_src/pallas/mosaic_gpu:primitives",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2024-04-18 04:03:03 -07:00
|
|
|
# This target only supports sm_90 GPUs.
|
2025-03-17 08:36:36 -07:00
|
|
|
py_library_providing_imports_info(
|
2024-04-18 04:03:03 -07:00
|
|
|
name = "mosaic_gpu",
|
|
|
|
srcs = glob(["experimental/mosaic/gpu/*.py"]),
|
|
|
|
visibility = [
|
|
|
|
":mosaic_gpu_users",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":config",
|
2024-05-22 04:40:54 -07:00
|
|
|
":core",
|
2024-04-18 04:03:03 -07:00
|
|
|
":jax",
|
|
|
|
":mlir",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
"//jaxlib/mlir:arithmetic_dialect",
|
|
|
|
"//jaxlib/mlir:builtin_dialect",
|
|
|
|
"//jaxlib/mlir:func_dialect",
|
|
|
|
"//jaxlib/mlir:gpu_dialect",
|
|
|
|
"//jaxlib/mlir:ir",
|
|
|
|
"//jaxlib/mlir:llvm_dialect",
|
|
|
|
"//jaxlib/mlir:math_dialect",
|
|
|
|
"//jaxlib/mlir:memref_dialect",
|
|
|
|
"//jaxlib/mlir:nvgpu_dialect",
|
|
|
|
"//jaxlib/mlir:nvvm_dialect",
|
|
|
|
"//jaxlib/mlir:pass_manager",
|
|
|
|
"//jaxlib/mlir:scf_dialect",
|
|
|
|
"//jaxlib/mlir:vector_dialect",
|
2025-03-17 08:36:36 -07:00
|
|
|
"//jaxlib/mosaic/python:gpu_dialect",
|
2024-05-30 01:45:31 -07:00
|
|
|
] + py_deps("absl/flags") + py_deps("numpy"),
|
2024-04-18 04:03:03 -07:00
|
|
|
)
|
|
|
|
|
2023-03-30 06:12:10 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "partial_eval",
|
|
|
|
srcs = ["_src/interpreters/partial_eval.py"],
|
|
|
|
deps = [
|
2023-04-12 15:11:22 -07:00
|
|
|
":ad_util",
|
2023-03-30 06:12:10 -07:00
|
|
|
":api_util",
|
2024-05-17 15:58:25 -07:00
|
|
|
":compute_on",
|
2023-03-30 06:12:10 -07:00
|
|
|
":config",
|
|
|
|
":core",
|
2023-05-17 07:58:19 -07:00
|
|
|
":dtypes",
|
2023-03-30 06:12:10 -07:00
|
|
|
":effects",
|
|
|
|
":profiler",
|
|
|
|
":source_info_util",
|
2023-04-04 10:56:25 -07:00
|
|
|
":state_types",
|
2023-03-30 06:12:10 -07:00
|
|
|
":tree_util",
|
|
|
|
":util",
|
2024-09-06 16:44:19 -07:00
|
|
|
":xla_metadata",
|
2023-03-30 06:12:10 -07:00
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-04-06 11:42:45 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "partition_spec",
|
|
|
|
srcs = ["_src/partition_spec.py"],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 13:09:20 -08:00
|
|
|
name = "path",
|
|
|
|
srcs = ["_src/path.py"],
|
|
|
|
deps = py_deps("epath"),
|
|
|
|
)
|
|
|
|
|
2023-06-29 21:01:32 -07:00
|
|
|
pytype_library(
|
|
|
|
name = "experimental_profiler",
|
|
|
|
srcs = ["experimental/profiler.py"],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2025-01-23 17:52:13 -08:00
|
|
|
pytype_library(
|
|
|
|
name = "experimental_transfer",
|
|
|
|
srcs = ["experimental/transfer.py"],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-06-09 13:34:09 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "pickle_util",
|
|
|
|
srcs = ["_src/pickle_util.py"],
|
|
|
|
deps = [":profiler"] + py_deps("cloudpickle"),
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 08:50:49 -08:00
|
|
|
name = "pretty_printer",
|
|
|
|
srcs = ["_src/pretty_printer.py"],
|
2024-01-26 13:52:47 +00:00
|
|
|
deps = [
|
|
|
|
":config",
|
|
|
|
":util",
|
|
|
|
] + py_deps("colorama"),
|
2023-03-09 08:50:49 -08:00
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 11:38:08 -08:00
|
|
|
name = "profiler",
|
|
|
|
srcs = ["_src/profiler.py"],
|
|
|
|
deps = [
|
|
|
|
":traceback_util",
|
|
|
|
":xla_bridge",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-13 08:49:39 -07:00
|
|
|
name = "sharding",
|
|
|
|
srcs = ["_src/sharding.py"],
|
|
|
|
deps = [
|
2024-06-05 08:02:39 -07:00
|
|
|
":op_shardings",
|
2023-03-13 08:49:39 -07:00
|
|
|
":util",
|
2023-04-14 08:46:17 -07:00
|
|
|
":xla_bridge",
|
2023-03-13 08:49:39 -07:00
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2024-05-17 15:58:25 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "compute_on",
|
|
|
|
srcs = ["_src/compute_on.py"],
|
2025-01-26 09:24:01 -08:00
|
|
|
deps = [
|
|
|
|
":config",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
2024-05-17 15:58:25 -07:00
|
|
|
)
|
|
|
|
|
2024-09-06 16:44:19 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "xla_metadata",
|
|
|
|
srcs = ["_src/xla_metadata.py"],
|
|
|
|
deps = [
|
|
|
|
":config",
|
2025-01-26 12:08:12 -08:00
|
|
|
"//jax/_src/lib",
|
2024-09-06 16:44:19 -07:00
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-11-15 08:48:17 -08:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "layout",
|
|
|
|
srcs = ["_src/layout.py"],
|
2024-04-03 16:12:43 -07:00
|
|
|
deps = [
|
2024-06-27 16:46:44 -07:00
|
|
|
":dtypes",
|
2024-04-03 16:12:43 -07:00
|
|
|
":sharding",
|
|
|
|
":sharding_impls",
|
|
|
|
"//jax/_src/lib",
|
2024-06-27 16:46:44 -07:00
|
|
|
] + py_deps("numpy"),
|
2023-11-15 08:48:17 -08:00
|
|
|
)
|
|
|
|
|
2023-04-10 10:15:08 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "sharding_impls",
|
|
|
|
srcs = ["_src/sharding_impls.py"],
|
|
|
|
deps = [
|
2023-08-04 16:26:31 -07:00
|
|
|
":config",
|
2024-06-03 14:52:08 -07:00
|
|
|
":core",
|
2024-09-03 14:30:37 -07:00
|
|
|
":internal_mesh_utils",
|
2023-04-10 10:15:08 -07:00
|
|
|
":mesh",
|
2025-02-13 11:21:41 -08:00
|
|
|
":named_sharding",
|
2023-04-10 10:15:08 -07:00
|
|
|
":op_shardings",
|
|
|
|
":partition_spec",
|
|
|
|
":sharding",
|
|
|
|
":sharding_specs",
|
2025-01-08 11:10:37 -08:00
|
|
|
":source_info_util",
|
2023-04-10 10:15:08 -07:00
|
|
|
":tree_util",
|
|
|
|
":util",
|
|
|
|
":xla_bridge",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2025-02-13 11:21:41 -08:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "named_sharding",
|
|
|
|
srcs = ["_src/named_sharding.py"],
|
|
|
|
deps = [
|
2025-02-26 18:16:45 -08:00
|
|
|
":config",
|
2025-02-13 11:21:41 -08:00
|
|
|
":mesh",
|
|
|
|
":partition_spec",
|
|
|
|
":sharding",
|
|
|
|
":util",
|
2025-02-26 18:16:45 -08:00
|
|
|
":xla_bridge",
|
2025-02-13 11:21:41 -08:00
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-04-06 09:48:14 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "sharding_specs",
|
|
|
|
srcs = ["_src/sharding_specs.py"],
|
|
|
|
deps = [
|
Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.
i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user visible impacts from this change.
PiperOrigin-RevId: 599787818
2024-01-19 03:53:01 -08:00
|
|
|
":config",
|
2023-04-06 09:48:14 -07:00
|
|
|
":op_shardings",
|
|
|
|
":util",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2024-09-03 14:30:37 -07:00
|
|
|
pytype_library(
|
|
|
|
name = "internal_mesh_utils",
|
|
|
|
srcs = ["_src/mesh_utils.py"],
|
|
|
|
deps = [
|
|
|
|
":xla_bridge",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 08:40:28 -08:00
|
|
|
name = "source_info_util",
|
|
|
|
srcs = ["_src/source_info_util.py"],
|
|
|
|
visibility = [":internal"] + jax_visibility("source_info_util"),
|
|
|
|
deps = [
|
|
|
|
":traceback_util",
|
|
|
|
":version",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-04-04 10:56:25 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "state_types",
|
|
|
|
srcs = [
|
|
|
|
"_src/state/__init__.py",
|
2024-01-02 15:52:57 -08:00
|
|
|
"_src/state/indexing.py",
|
2023-04-04 10:56:25 -07:00
|
|
|
"_src/state/types.py",
|
|
|
|
],
|
|
|
|
deps = [
|
|
|
|
":core",
|
[Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```
Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
b[:,:] <- c
in () }
```
Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:52:43 -07:00
|
|
|
":dtypes",
|
2023-04-04 10:56:25 -07:00
|
|
|
":effects",
|
|
|
|
":pretty_printer",
|
2024-10-01 10:25:53 -07:00
|
|
|
":traceback_util",
|
2024-01-02 15:52:57 -08:00
|
|
|
":tree_util",
|
2023-09-25 14:44:27 +01:00
|
|
|
":typing",
|
2023-04-04 10:56:25 -07:00
|
|
|
":util",
|
2024-01-02 15:52:57 -08:00
|
|
|
] + py_deps("numpy"),
|
2023-04-04 10:56:25 -07:00
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 14:51:08 -08:00
|
|
|
name = "tree_util",
|
|
|
|
srcs = ["_src/tree_util.py"],
|
|
|
|
visibility = [":internal"] + jax_visibility("tree_util"),
|
|
|
|
deps = [
|
|
|
|
":traceback_util",
|
|
|
|
":util",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 10:33:11 -08:00
|
|
|
name = "traceback_util",
|
|
|
|
srcs = ["_src/traceback_util.py"],
|
2023-03-10 08:40:28 -08:00
|
|
|
visibility = [":internal"] + jax_visibility("traceback_util"),
|
2023-03-09 10:33:11 -08:00
|
|
|
deps = [
|
|
|
|
":config",
|
|
|
|
":util",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-23 10:29:11 -07:00
|
|
|
name = "typing",
|
|
|
|
srcs = [
|
|
|
|
"_src/typing.py",
|
|
|
|
],
|
|
|
|
deps = [":basearray"] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-07-20 18:28:18 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "tpu_custom_call",
|
|
|
|
srcs = ["_src/tpu_custom_call.py"],
|
|
|
|
visibility = [":internal"],
|
|
|
|
deps = [
|
2023-07-26 03:58:59 -07:00
|
|
|
":config",
|
2023-07-20 18:28:18 -07:00
|
|
|
":core",
|
|
|
|
":jax",
|
2023-12-06 08:19:14 -08:00
|
|
|
":mlir",
|
2024-02-06 15:46:31 -08:00
|
|
|
":sharding_impls",
|
2023-07-20 18:28:18 -07:00
|
|
|
"//jax/_src/lib",
|
2024-10-25 12:16:18 -07:00
|
|
|
"//jax/_src/pallas",
|
2023-12-05 08:04:41 -08:00
|
|
|
] + if_building_jaxlib([
|
|
|
|
"//jaxlib/mlir:ir",
|
2024-03-19 14:17:34 -07:00
|
|
|
"//jaxlib/mlir:mhlo_dialect",
|
|
|
|
"//jaxlib/mlir:pass_manager",
|
2023-12-05 08:04:41 -08:00
|
|
|
"//jaxlib/mlir:stablehlo_dialect",
|
|
|
|
]) + py_deps("numpy") + py_deps("absl/flags"),
|
2023-07-20 18:28:18 -07:00
|
|
|
)
|
|
|
|
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-09 05:26:58 -08:00
|
|
|
name = "util",
|
|
|
|
srcs = ["_src/util.py"],
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
2023-03-09 05:26:58 -08:00
|
|
|
":config",
|
|
|
|
"//jax/_src/lib",
|
2023-03-27 10:14:05 -07:00
|
|
|
] + py_deps("numpy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2023-03-09 05:26:58 -08:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "version",
|
|
|
|
srcs = ["version.py"],
|
|
|
|
)
|
|
|
|
|
2023-03-30 14:05:24 -07:00
|
|
|
pytype_strict_library(
|
|
|
|
name = "xla",
|
|
|
|
srcs = ["_src/interpreters/xla.py"],
|
|
|
|
deps = [
|
|
|
|
":abstract_arrays",
|
|
|
|
":config",
|
|
|
|
":core",
|
2023-05-17 07:58:19 -07:00
|
|
|
":dtypes",
|
2023-04-10 10:15:08 -07:00
|
|
|
":sharding_impls",
|
2023-03-30 14:05:24 -07:00
|
|
|
":source_info_util",
|
|
|
|
":typing",
|
|
|
|
":util",
|
|
|
|
":xla_bridge",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2023-03-10 11:11:28 -08:00
|
|
|
# TODO(phawkins): break up this SCC.
|
2023-03-27 10:14:05 -07:00
|
|
|
pytype_strict_library(
|
2023-03-10 11:11:28 -08:00
|
|
|
name = "xla_bridge",
|
|
|
|
srcs = [
|
|
|
|
"_src/clusters/__init__.py",
|
|
|
|
"_src/clusters/cloud_tpu_cluster.py",
|
|
|
|
"_src/clusters/cluster.py",
|
2024-07-08 05:08:25 +00:00
|
|
|
"_src/clusters/k8s_cluster.py",
|
2024-07-08 12:19:18 -07:00
|
|
|
"_src/clusters/mpi4py_cluster.py",
|
2023-03-10 11:11:28 -08:00
|
|
|
"_src/clusters/ompi_cluster.py",
|
|
|
|
"_src/clusters/slurm_cluster.py",
|
|
|
|
"_src/distributed.py",
|
|
|
|
"_src/xla_bridge.py",
|
|
|
|
],
|
2023-03-10 14:22:55 -08:00
|
|
|
visibility = [":internal"] + jax_visibility("xla_bridge"),
|
2023-03-10 11:11:28 -08:00
|
|
|
deps = [
|
|
|
|
":cloud_tpu_init",
|
|
|
|
":config",
|
2024-01-12 17:12:16 -08:00
|
|
|
":hardware_utils",
|
2023-03-10 11:11:28 -08:00
|
|
|
":traceback_util",
|
2023-03-29 08:23:18 -07:00
|
|
|
":util",
|
2023-03-10 11:11:28 -08:00
|
|
|
"//jax/_src/lib",
|
2024-06-26 13:43:15 -04:00
|
|
|
],
|
2023-03-10 11:11:28 -08:00
|
|
|
)
|
|
|
|
|
2023-03-09 08:50:49 -08:00
|
|
|
# Public JAX libraries below this point.
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
py_library_providing_imports_info(
|
|
|
|
name = "experimental",
|
2023-02-16 13:29:50 -08:00
|
|
|
srcs = glob(
|
|
|
|
[
|
|
|
|
"experimental/*.py",
|
|
|
|
"example_libraries/*.py",
|
|
|
|
],
|
|
|
|
),
|
2022-07-01 15:06:54 -07:00
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
2022-08-05 07:48:40 -07:00
|
|
|
] + py_deps("absl/logging") + py_deps("numpy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "stax",
|
|
|
|
srcs = [
|
|
|
|
"example_libraries/stax.py",
|
|
|
|
],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [":jax"],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "experimental_sparse",
|
2022-11-16 09:58:06 -08:00
|
|
|
srcs = glob(
|
|
|
|
[
|
|
|
|
"experimental/sparse/*.py",
|
|
|
|
],
|
|
|
|
exclude = ["experimental/sparse/test_util.py"],
|
|
|
|
),
|
2022-07-01 15:06:54 -07:00
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [":jax"],
|
|
|
|
)
|
|
|
|
|
2022-11-16 09:58:06 -08:00
|
|
|
pytype_library(
|
|
|
|
name = "sparse_test_util",
|
|
|
|
srcs = [
|
|
|
|
"experimental/sparse/test_util.py",
|
|
|
|
],
|
|
|
|
visibility = [":internal"],
|
|
|
|
deps = [
|
|
|
|
":experimental_sparse",
|
|
|
|
":jax",
|
|
|
|
":test_util",
|
|
|
|
] + py_deps("numpy"),
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
pytype_library(
|
|
|
|
name = "optimizers",
|
|
|
|
srcs = [
|
|
|
|
"example_libraries/optimizers.py",
|
|
|
|
],
|
|
|
|
visibility = ["//visibility:public"],
|
2023-11-16 14:21:04 -08:00
|
|
|
deps = [":jax"] + py_deps("numpy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "ode",
|
|
|
|
srcs = ["experimental/ode.py"],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [":jax"],
|
|
|
|
)
|
|
|
|
|
|
|
|
# TODO(apaszke): Remove this target
|
|
|
|
pytype_library(
|
|
|
|
name = "pjit",
|
|
|
|
srcs = ["experimental/pjit.py"],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [
|
|
|
|
":experimental",
|
|
|
|
":jax",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "jet",
|
|
|
|
srcs = ["experimental/jet.py"],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [":jax"],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "experimental_host_callback",
|
2024-03-29 13:36:20 +02:00
|
|
|
srcs = [
|
|
|
|
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
|
|
|
|
"experimental/host_callback.py",
|
|
|
|
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
|
|
|
|
],
|
2022-07-01 15:06:54 -07:00
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [
|
|
|
|
":jax",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "compilation_cache",
|
|
|
|
srcs = [
|
2025-02-25 09:28:35 -08:00
|
|
|
"experimental/compilation_cache/__init__.py",
|
2022-07-01 15:06:54 -07:00
|
|
|
"experimental/compilation_cache/compilation_cache.py",
|
|
|
|
],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [":jax"],
|
|
|
|
)
|
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "mesh_utils",
|
|
|
|
srcs = ["experimental/mesh_utils.py"],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [
|
2024-09-03 14:30:37 -07:00
|
|
|
":internal_mesh_utils",
|
2022-07-01 15:06:54 -07:00
|
|
|
],
|
|
|
|
)
|
2022-11-28 14:31:10 -08:00
|
|
|
|
2023-12-04 17:48:17 -08:00
|
|
|
# TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/...
|
|
|
|
pytype_strict_library(
|
2023-08-24 14:40:10 -07:00
|
|
|
name = "extend",
|
|
|
|
visibility = [":jax_extend_users"],
|
2023-12-04 17:48:17 -08:00
|
|
|
deps = [
|
|
|
|
"//jax/extend",
|
2024-03-14 16:09:02 -07:00
|
|
|
"//jax/extend:backend",
|
2023-12-04 17:48:17 -08:00
|
|
|
"//jax/extend:core",
|
|
|
|
"//jax/extend:linear_util",
|
|
|
|
"//jax/extend:random",
|
|
|
|
"//jax/extend:source_info_util",
|
|
|
|
],
|
2023-08-24 14:40:10 -07:00
|
|
|
)
|
|
|
|
|
2023-07-20 18:28:18 -07:00
|
|
|
pytype_library(
|
|
|
|
name = "mosaic",
|
|
|
|
srcs = [
|
|
|
|
"experimental/mosaic/__init__.py",
|
|
|
|
"experimental/mosaic/dialects.py",
|
|
|
|
],
|
|
|
|
visibility = [":mosaic_users"],
|
|
|
|
deps = [
|
|
|
|
":tpu_custom_call",
|
|
|
|
"//jax/_src/lib",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-11-28 14:31:10 -08:00
|
|
|
pytype_library(
|
|
|
|
name = "rnn",
|
|
|
|
srcs = ["experimental/rnn.py"],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [":jax"],
|
|
|
|
)
|
2024-10-28 12:17:34 -07:00
|
|
|
|
|
|
|
pytype_library(
|
|
|
|
name = "experimental_colocated_python",
|
|
|
|
srcs = [
|
|
|
|
"experimental/colocated_python/__init__.py",
|
|
|
|
"experimental/colocated_python/api.py",
|
|
|
|
"experimental/colocated_python/func.py",
|
|
|
|
"experimental/colocated_python/func_backend.py",
|
2025-02-21 12:57:49 -08:00
|
|
|
"experimental/colocated_python/obj.py",
|
|
|
|
"experimental/colocated_python/obj_backend.py",
|
2024-10-28 12:17:34 -07:00
|
|
|
"experimental/colocated_python/serialization.py",
|
|
|
|
],
|
|
|
|
visibility = ["//visibility:public"],
|
|
|
|
deps = [
|
|
|
|
":api_util",
|
|
|
|
":jax",
|
|
|
|
":traceback_util",
|
|
|
|
":tree_util",
|
|
|
|
":util",
|
|
|
|
":xla_bridge",
|
|
|
|
"//jax/_src/lib",
|
2024-11-26 13:30:31 -08:00
|
|
|
"//jax/extend:ifrt_programs",
|
2024-10-28 12:17:34 -07:00
|
|
|
] + py_deps("numpy") + py_deps("cloudpickle"),
|
|
|
|
)
|