rocm_jax/jax/BUILD
Adam Paszke 3649da56fc [Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length
We can now perform the conversion in groups of 2, 4 or even 8 elements at a time.

PiperOrigin-RevId: 737626600
2025-03-17 08:37:17 -07:00

1308 lines
30 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# JAX is Autograd and XLA
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"if_building_jaxlib",
"jax_export_file_visibility",
"jax_extend_internal_users",
"jax_extra_deps",
"jax_internal_export_back_compat_test_util_visibility",
"jax_internal_packages",
"jax_internal_test_harnesses_visibility",
"jax_test_util_visibility",
"jax_visibility",
"mosaic_gpu_internal_users",
"mosaic_internal_users",
"pallas_fuser_users",
"pallas_gpu_internal_users",
"pallas_tpu_internal_users",
"py_deps",
"py_library_providing_imports_info",
"pytype_library",
"pytype_strict_library",
"serialize_executable_internal_users",
)
# Package-wide defaults: a target that does not set its own visibility is
# visible only to the :internal package group declared later in this file.
package(
default_applicable_licenses = [],
default_visibility = [":internal"],
)
# License type declared for this package.
licenses(["notice"])
# The flag controls whether jaxlib should be built by Bazel.
# If ":build_jaxlib=true", then jaxlib will be built.
# If ":build_jaxlib=false", then jaxlib is not built. It is assumed that the pre-built jaxlib wheel
# is available in the "dist" folder.
# If ":build_jaxlib=wheel", then jaxlib wheel will be built as a py_import rule attribute.
# The py_import rule unpacks the wheel and provides its content as a py_library.
# The default is "true"; `values` lists the only settings the flag accepts.
string_flag(
name = "build_jaxlib",
build_setting_default = "true",
values = [
"true",
"false",
"wheel",
],
)
# Matches when the :build_jaxlib flag is "true" (that flag's default value).
# The "false" and "wheel" settings do not match this config_setting.
config_setting(
name = "enable_jaxlib_build",
flag_values = {
":build_jaxlib": "true",
},
)
# Export these source files so that they can be referenced by label.
exports_files([
"LICENSE",
"version.py",
"py.typed",
])
# Export the export-serialization flatbuffer schema; who may reference it is
# controlled by jax_export_file_visibility (loaded from //jaxlib:jax.bzl).
exports_files(
["_src/export/serialization.fbs"],
visibility = jax_export_file_visibility,
)
# Packages that have access to JAX-internal implementation details.
# "//..." covers every package of this repository; jax_internal_packages
# (loaded from //jaxlib:jax.bzl) can extend the list further.
# This group is also the package's default_visibility.
package_group(
name = "internal",
packages = [
"//...",
] + jax_internal_packages,
)
# Packages allowed to depend on the :extend umbrella target: everything in
# :internal, the tests, and jax_extend_internal_users from //jaxlib:jax.bzl.
# NOTE(review): :internal already lists "//...", so the explicit "//tests/..."
# entry looks redundant with the current definition of :internal — confirm.
package_group(
name = "jax_extend_users",
includes = [":internal"],
packages = [
# Intentionally avoid jax dependencies on jax.extend.
# See https://jax.readthedocs.io/en/latest/jep/15856-jex.html
"//tests/...",
] + jax_extend_internal_users,
)
# Visibility group used by the :mosaic target in this file:
# :internal plus mosaic_internal_users from //jaxlib:jax.bzl.
package_group(
name = "mosaic_users",
includes = [":internal"],
packages = mosaic_internal_users,
)
# Visibility group used by :pallas_gpu, :pallas_gpu_ops and :pallas_triton:
# :internal plus pallas_gpu_internal_users from //jaxlib:jax.bzl.
package_group(
name = "pallas_gpu_users",
includes = [":internal"],
packages = pallas_gpu_internal_users,
)
# Visibility group used by :pallas_tpu and :pallas_tpu_ops:
# :internal plus pallas_tpu_internal_users from //jaxlib:jax.bzl.
package_group(
name = "pallas_tpu_users",
includes = [":internal"],
packages = pallas_tpu_internal_users,
)
# Visibility group used by :pallas_fuser:
# :internal plus pallas_fuser_users from //jaxlib:jax.bzl.
package_group(
name = "pallas_fuser_users",
includes = [":internal"],
packages = pallas_fuser_users,
)
# Visibility group used by :mosaic_gpu, :pallas_mosaic_gpu and
# :pallas_experimental_gpu_ops: :internal plus mosaic_gpu_internal_users
# from //jaxlib:jax.bzl.
package_group(
name = "mosaic_gpu_users",
includes = [":internal"],
packages = mosaic_gpu_internal_users,
)
# Visibility group used by :serialize_executable:
# :internal plus serialize_executable_internal_users from //jaxlib:jax.bzl.
package_group(
name = "serialize_executable_users",
includes = [":internal"],
packages = serialize_executable_internal_users,
)
# JAX-private test utilities.
py_library(
# This build target is required in order to use private test utilities in jax._src.test_util,
# and its visibility is intentionally restricted to discourage its use outside JAX itself.
# JAX does provide some public test utilities (see jax/test_util.py);
# these are available in jax.test_util via the standard :jax target.
name = "test_util",
srcs = [
"_src/test_util.py",
"_src/test_warning_util.py",
],
# :internal plus whatever jax_test_util_visibility (from //jaxlib:jax.bzl) adds.
visibility = [
":internal",
] + jax_test_util_visibility,
deps = [
":compilation_cache_internal",
":jax",
] + py_deps("absl/testing") + py_deps("numpy"),
)
# TODO(necula): break the internal_test_util into smaller build targets.
# Internal-only test helpers: three modules listed by name plus every .py file
# directly under _src/internal_test_util/lazy_loader_module/ (non-recursive glob).
py_library(
name = "internal_test_util",
srcs = [
"_src/internal_test_util/__init__.py",
"_src/internal_test_util/deprecation_module.py",
"_src/internal_test_util/lax_test_util.py",
] + glob(
[
"_src/internal_test_util/lazy_loader_module/*.py",
],
),
visibility = [":internal"],
deps = [
":jax",
] + py_deps("numpy"),
)
# Target for _src/internal_test_util/test_harnesses.py. Visibility is :internal,
# widened by jax_internal_test_harnesses_visibility from //jaxlib:jax.bzl.
py_library(
name = "internal_test_harnesses",
srcs = ["_src/internal_test_util/test_harnesses.py"],
visibility = [":internal"] + jax_internal_test_harnesses_visibility,
deps = [
":ad_util",
":config",
":jax",
":test_util",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Target for _src/internal_test_util/export_back_compat_test_util.py.
# Visibility is :internal, widened by
# jax_internal_export_back_compat_test_util_visibility from //jaxlib:jax.bzl.
py_library(
name = "internal_export_back_compat_test_util",
srcs = ["_src/internal_test_util/export_back_compat_test_util.py"],
visibility = [
":internal",
] + jax_internal_export_back_compat_test_util_visibility,
deps = [
":jax",
":test_util",
] + py_deps("numpy"),
)
# Python data modules for the export back-compat tests: the .py files directly
# under export_back_compat_test_data/ and its pallas/ subdirectory.
# Marked testonly, so only tests and other testonly targets may depend on it.
py_library(
name = "internal_export_back_compat_test_data",
testonly = 1,
srcs = glob([
"_src/internal_test_util/export_back_compat_test_data/*.py",
"_src/internal_test_util/export_back_compat_test_data/pallas/*.py",
]),
visibility = [
":internal",
],
deps = py_deps("numpy"),
)
# The main, publicly visible JAX library target.
# srcs is assembled from three pieces:
#   1. _src modules listed one by one,
#   2. globs over the top-level modules and whole (sub)packages, minus tests
#      and the internal test utilities,
#   3. a handful of experimental modules listed by name.
# Modules that have their own fine-grained target in this file are pulled in
# through deps instead.
py_library_providing_imports_info(
name = "jax",
srcs = [
"_src/__init__.py",
"_src/ad_checkpoint.py",
"_src/api.py",
"_src/array.py",
"_src/blocked_sampler.py",
"_src/callback.py",
"_src/checkify.py",
"_src/custom_batching.py",
"_src/custom_dce.py",
"_src/custom_derivatives.py",
"_src/custom_partitioning.py",
"_src/custom_partitioning_sharding_rule.py",
"_src/custom_transpose.py",
"_src/debugging.py",
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/error_check.py",
"_src/ffi.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",
"_src/interpreters/ad.py",
"_src/interpreters/batching.py",
"_src/interpreters/pxla.py",
"_src/pjit.py",
"_src/prng.py",
"_src/public_test_util.py",
"_src/random.py",
"_src/shard_alike.py",
"_src/sourcemap.py",
"_src/stages.py",
"_src/tree.py",
] + glob(
# "*.py" picks up the modules directly in this package; the "**" patterns
# pick up whole directory trees.
[
"*.py",
"_src/cudnn/**/*.py",
"_src/debugger/**/*.py",
"_src/extend/**/*.py",
"_src/image/**/*.py",
"_src/export/**/*.py",
"_src/lax/**/*.py",
"_src/nn/**/*.py",
"_src/numpy/**/*.py",
"_src/ops/**/*.py",
"_src/scipy/**/*.py",
"_src/state/**/*.py",
"_src/third_party/**/*.py",
"experimental/key_reuse/**/*.py",
"experimental/roofline/**/*.py",
"image/**/*.py",
"interpreters/**/*.py",
"lax/**/*.py",
"lib/**/*.py",
"nn/**/*.py",
"numpy/**/*.py",
"ops/**/*.py",
"scipy/**/*.py",
"third_party/**/*.py",
],
# Keep test files and the internal test utilities out of the library.
exclude = [
"*_test.py",
"**/*_test.py",
"_src/internal_test_util/**",
],
) + [
"experimental/attrs.py",
"experimental/pjit.py",
"experimental/multihost_utils.py",
"experimental/shard_map.py",
# until checkify is moved out of experimental
"experimental/checkify.py",
"experimental/compilation_cache/compilation_cache.py",
],
# Rule passed to the py_library_providing_imports_info macro
# (defined in //jaxlib:jax.bzl; see there for how it is used).
lib_rule = pytype_library,
# Type stubs. _src/basearray.pyi is left out here because the :basearray
# target lists it in its own pytype_srcs.
pytype_srcs = glob(
[
"numpy/*.pyi",
"_src/**/*.pyi",
],
exclude = [
"_src/basearray.pyi",
],
),
visibility = ["//visibility:public"],
deps = [
":abstract_arrays",
":ad_util",
":api_util",
":basearray",
":cloud_tpu_init",
":compilation_cache_internal",
":compiler",
":compute_on",
":config",
":core",
":custom_api_util",
":deprecations",
":dtypes",
":effects",
":environment_info",
":internal_mesh_utils",
":jaxpr_util",
":layout",
":lazy_loader",
":mesh",
":mlir",
":monitoring",
":named_sharding",
":op_shardings",
":partial_eval",
":partition_spec",
":path",
":pickle_util",
":pretty_printer",
":profiler",
":sharding",
":sharding_impls",
":sharding_specs",
":source_info_util",
":traceback_util",
":tree_util",
":typing",
":util",
":version",
":xla",
":xla_bridge",
":xla_metadata",
"//jax/_src/lib",
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
)
# Fine-grained target for _src/abstract_arrays.py.
pytype_strict_library(
name = "abstract_arrays",
srcs = ["_src/abstract_arrays.py"],
deps = [
":ad_util",
":core",
":dtypes",
":traceback_util",
] + py_deps("numpy"),
)
# Fine-grained target for _src/ad_util.py.
pytype_strict_library(
name = "ad_util",
srcs = ["_src/ad_util.py"],
deps = [
":core",
":traceback_util",
":tree_util",
":typing",
":util",
],
)
# Fine-grained target for _src/api_util.py.
pytype_strict_library(
name = "api_util",
srcs = ["_src/api_util.py"],
deps = [
":abstract_arrays",
":config",
":core",
":dtypes",
":state_types",
":traceback_util",
":tree_util",
":util",
] + py_deps("numpy"),
)
# Fine-grained target for _src/basearray.py together with its type stub
# _src/basearray.pyi (which :jax explicitly excludes from its own stub glob).
pytype_strict_library(
name = "basearray",
srcs = ["_src/basearray.py"],
pytype_srcs = ["_src/basearray.pyi"],
deps = [
":partition_spec",
":sharding",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Fine-grained target for _src/cloud_tpu_init.py.
pytype_strict_library(
name = "cloud_tpu_init",
srcs = ["_src/cloud_tpu_init.py"],
deps = [
":config",
":hardware_utils",
":version",
],
)
# Target for the private _src/compilation_cache.py module. The public
# jax.experimental.compilation_cache package is the separate
# :compilation_cache target.
# Visibility is :internal plus whatever jax_visibility("compilation_cache")
# (helper from //jaxlib:jax.bzl) returns.
pytype_strict_library(
name = "compilation_cache_internal",
srcs = ["_src/compilation_cache.py"],
visibility = [":internal"] + jax_visibility("compilation_cache"),
deps = [
":cache_key",
":compilation_cache_interface",
":config",
":lru_cache",
":monitoring",
":path",
"//jax/_src/lib",
] + py_deps("numpy") + py_deps("zstandard"),
)
# Fine-grained target for _src/cache_key.py; shares the
# jax_visibility("compilation_cache") visibility extension with
# :compilation_cache_internal.
pytype_strict_library(
name = "cache_key",
srcs = ["_src/cache_key.py"],
visibility = [":internal"] + jax_visibility("compilation_cache"),
deps = [
":config",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Fine-grained target for _src/compilation_cache_interface.py; a dependency of
# both :compilation_cache_internal and :lru_cache.
pytype_strict_library(
name = "compilation_cache_interface",
srcs = ["_src/compilation_cache_interface.py"],
deps = [
":path",
":util",
],
)
# Fine-grained target for _src/lru_cache.py; needs the third-party "filelock"
# dependency in addition to the in-package targets.
pytype_strict_library(
name = "lru_cache",
srcs = ["_src/lru_cache.py"],
deps = [
":compilation_cache_interface",
":path",
] + py_deps("filelock"),
)
# Fine-grained target for _src/config.py.
pytype_strict_library(
name = "config",
srcs = ["_src/config.py"],
deps = [
":logging_config",
"//jax/_src/lib",
],
)
# Leaf target for _src/logging_config.py; declares no deps.
pytype_strict_library(
name = "logging_config",
srcs = ["_src/logging_config.py"],
)
# Fine-grained target for _src/compiler.py.
pytype_strict_library(
name = "compiler",
srcs = ["_src/compiler.py"],
deps = [
":cache_key",
":compilation_cache_internal",
":config",
":mlir",
":monitoring",
":path",
":profiler",
":traceback_util",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Bundles _src/core.py, _src/errors.py and _src/linear_util.py into a single
# target.
pytype_strict_library(
name = "core",
srcs = [
"_src/core.py",
"_src/errors.py",
"_src/linear_util.py",
],
deps = [
":compute_on",
":config",
":deprecations",
":dtypes",
":effects",
":mesh",
":named_sharding",
":partition_spec",
":pretty_printer",
":source_info_util",
":traceback_util",
":tree_util",
":typing",
":util",
":xla_metadata",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Leaf target for _src/custom_api_util.py; declares no deps.
pytype_strict_library(
name = "custom_api_util",
srcs = ["_src/custom_api_util.py"],
)
# Leaf target for _src/deprecations.py; declares no deps.
pytype_strict_library(
name = "deprecations",
srcs = ["_src/deprecations.py"],
)
# Fine-grained target for _src/dtypes.py; needs the third-party "ml_dtypes"
# and "numpy" dependencies.
pytype_strict_library(
name = "dtypes",
srcs = [
"_src/dtypes.py",
],
deps = [
":config",
":traceback_util",
":typing",
":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"),
)
# Leaf target for _src/effects.py; declares no deps.
pytype_strict_library(
name = "effects",
srcs = ["_src/effects.py"],
)
# Fine-grained target for _src/environment_info.py.
pytype_strict_library(
name = "environment_info",
srcs = ["_src/environment_info.py"],
deps = [
":version",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Leaf target for _src/hardware_utils.py; declares no deps.
pytype_strict_library(
name = "hardware_utils",
srcs = ["_src/hardware_utils.py"],
)
# Target for _src/lax_reference.py. Uses pytype_library rather than the
# pytype_strict_library used by the neighbouring targets. It is not in the
# deps of :jax; visibility is :internal plus jax_visibility("lax_reference").
pytype_library(
name = "lax_reference",
srcs = ["_src/lax_reference.py"],
visibility = [":internal"] + jax_visibility("lax_reference"),
deps = [
":core",
":util",
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum"),
)
# Leaf target for _src/lazy_loader.py; declares no deps.
pytype_strict_library(
name = "lazy_loader",
srcs = ["_src/lazy_loader.py"],
)
# Fine-grained target for _src/jaxpr_util.py.
pytype_strict_library(
name = "jaxpr_util",
srcs = ["_src/jaxpr_util.py"],
deps = [
":core",
":source_info_util",
":util",
"//jax/_src/lib",
],
)
# Fine-grained target for _src/mesh.py.
pytype_strict_library(
name = "mesh",
srcs = ["_src/mesh.py"],
deps = [
":config",
":util",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Fine-grained target for _src/interpreters/mlir.py. The other interpreter
# modules are either listed in :jax (ad, batching, pxla) or have their own
# targets (:partial_eval, :xla).
pytype_strict_library(
name = "mlir",
srcs = ["_src/interpreters/mlir.py"],
deps = [
":ad_util",
":api_util",
":config",
":core",
":dtypes",
":effects",
":layout",
":op_shardings",
":partial_eval",
":partition_spec",
":path",
":pickle_util",
":sharding",
":sharding_impls",
":source_info_util",
":state_types",
":util",
":xla",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Leaf target for _src/monitoring.py; declares no deps.
pytype_strict_library(
name = "monitoring",
srcs = ["_src/monitoring.py"],
)
# Fine-grained target for _src/op_shardings.py.
pytype_strict_library(
name = "op_shardings",
srcs = ["_src/op_shardings.py"],
deps = [
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Target for experimental/serialize_executable.py; only packages in the
# :serialize_executable_users group may depend on it.
pytype_strict_library(
name = "serialize_executable",
srcs = ["experimental/serialize_executable.py"],
visibility = [":serialize_executable_users"],
deps = [
":jax",
"//jax/_src/lib",
],
)
# Publicly visible target holding every .py file under
# experimental/source_mapper/ (recursive glob).
pytype_strict_library(
name = "source_mapper",
srcs = glob(include = ["experimental/source_mapper/**/*.py"]),
visibility = [
"//visibility:public",
],
deps = [
":config",
":core",
":jax",
":source_info_util",
] + py_deps("absl/flags"),
)
# Publicly visible Pallas target: everything under experimental/pallas/ except
# the backend-specific pieces, which have their own targets in this file:
#   gpu.py, triton.py -> :pallas_triton
#   mosaic_gpu.py     -> :pallas_mosaic_gpu
#   tpu.py            -> :pallas_tpu
#   fuser.py          -> :pallas_fuser
#   ops/tpu/**        -> :pallas_tpu_ops
#   ops/gpu/**        -> :pallas_gpu_ops and :pallas_experimental_gpu_ops
#                        (via labels in //jax/experimental/pallas/ops/gpu)
pytype_strict_library(
name = "pallas",
srcs = glob(
[
"experimental/pallas/**/*.py",
],
exclude = [
"experimental/pallas/gpu.py",
"experimental/pallas/mosaic_gpu.py",
"experimental/pallas/ops/gpu/**/*.py",
"experimental/pallas/ops/tpu/**/*.py",
"experimental/pallas/tpu.py",
"experimental/pallas/fuser.py",
"experimental/pallas/triton.py",
],
),
visibility = [
"//visibility:public",
],
deps = [
":deprecations",
":jax",
":source_info_util",
"//jax/_src/pallas",
] + py_deps("numpy"),
)
# Target for experimental/pallas/tpu.py, restricted to :pallas_tpu_users.
# The "build_cleaner: keep" markers presumably ask automated BUILD cleanup
# tooling to leave those deps in place — confirm with the tool's docs.
pytype_strict_library(
name = "pallas_tpu",
srcs = ["experimental/pallas/tpu.py"],
visibility = [
":pallas_tpu_users",
],
deps = [
":pallas", # build_cleaner: keep
":tpu_custom_call",
"//jax/_src/pallas",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:helpers",
"//jax/_src/pallas/mosaic:interpret",
"//jax/_src/pallas/mosaic:lowering",
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic:pipeline",
"//jax/_src/pallas/mosaic:primitives",
"//jax/_src/pallas/mosaic:random",
"//jax/_src/pallas/mosaic:verification",
],
)
# Target for experimental/pallas/fuser.py, restricted to :pallas_fuser_users.
# Its implementation deps all live in //jax/_src/pallas/fuser.
pytype_strict_library(
name = "pallas_fuser",
srcs = ["experimental/pallas/fuser.py"],
visibility = [
":pallas_fuser_users",
],
deps = [
":pallas", # build_cleaner: keep
"//jax/_src/pallas/fuser:block_spec",
"//jax/_src/pallas/fuser:custom_evaluate",
"//jax/_src/pallas/fuser:fusable",
"//jax/_src/pallas/fuser:fusion",
"//jax/_src/pallas/fuser:jaxpr_fusion",
],
)
# Library over the Triton-based Pallas GPU ops. srcs is a label in another
# package rather than a file list (presumably a filegroup of the op sources —
# confirm in //jax/experimental/pallas/ops/gpu).
pytype_strict_library(
name = "pallas_gpu_ops",
srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"],
visibility = [
":pallas_gpu_users",
],
deps = [
":jax",
":pallas",
":pallas_gpu",
] + py_deps("numpy"),
)
# Library over the Mosaic GPU Pallas ops. srcs is a label in another package
# rather than a file list (presumably a filegroup of the op sources — confirm
# in //jax/experimental/pallas/ops/gpu).
pytype_strict_library(
name = "pallas_experimental_gpu_ops",
srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"],
visibility = [
":mosaic_gpu_users",
],
deps = [
":jax",
":mosaic_gpu",
":pallas",
":pallas_mosaic_gpu",
":test_util", # This is only to make them runnable as jax_multiplatform_test...
] + py_deps("numpy"),
)
# Every .py file under experimental/pallas/ops/tpu/ (recursive glob);
# restricted to :pallas_tpu_users.
pytype_strict_library(
name = "pallas_tpu_ops",
srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]),
visibility = [
":pallas_tpu_users",
],
deps = [
":jax",
":pallas",
":pallas_tpu",
] + py_deps("numpy"),
)
# Source-less aggregation target for the Pallas GPU backends. For now it only
# forwards to :pallas_triton (see the TODO below).
pytype_strict_library(
name = "pallas_gpu",
visibility = [
":pallas_gpu_users",
],
deps = [
":pallas_triton",
# TODO(slebedev): Add :pallas_mosaic_gpu once it is ready.
],
)
# Target for experimental/pallas/gpu.py and experimental/pallas/triton.py,
# backed by the //jax/_src/pallas/triton implementation targets.
pytype_strict_library(
name = "pallas_triton",
srcs = [
"experimental/pallas/gpu.py",
"experimental/pallas/triton.py",
],
visibility = [
":pallas_gpu_users",
],
deps = [
":deprecations",
"//jax/_src/pallas/triton:core",
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:primitives",
],
)
# Target for experimental/pallas/mosaic_gpu.py, backed by :mosaic_gpu and the
# //jax/_src/pallas/mosaic_gpu implementation targets.
pytype_strict_library(
name = "pallas_mosaic_gpu",
srcs = ["experimental/pallas/mosaic_gpu.py"],
visibility = [
":mosaic_gpu_users",
],
deps = [
":mosaic_gpu",
"//jax/_src/pallas/mosaic_gpu:core",
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic_gpu:pipeline",
"//jax/_src/pallas/mosaic_gpu:primitives",
],
)
# This target only supports sm_90 GPUs.
# Holds the .py files directly under experimental/mosaic/gpu/ (non-recursive
# glob) and depends on the MLIR Python bindings under //jaxlib/mlir plus the
# Mosaic GPU dialect from //jaxlib/mosaic/python.
py_library_providing_imports_info(
name = "mosaic_gpu",
srcs = glob(["experimental/mosaic/gpu/*.py"]),
visibility = [
":mosaic_gpu_users",
],
deps = [
":config",
":core",
":jax",
":mlir",
"//jax/_src/lib",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:func_dialect",
"//jaxlib/mlir:gpu_dialect",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:llvm_dialect",
"//jaxlib/mlir:math_dialect",
"//jaxlib/mlir:memref_dialect",
"//jaxlib/mlir:nvgpu_dialect",
"//jaxlib/mlir:nvvm_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic/python:gpu_dialect",
] + py_deps("absl/flags") + py_deps("numpy"),
)
# Fine-grained target for _src/interpreters/partial_eval.py.
pytype_strict_library(
name = "partial_eval",
srcs = ["_src/interpreters/partial_eval.py"],
deps = [
":ad_util",
":api_util",
":compute_on",
":config",
":core",
":dtypes",
":effects",
":profiler",
":source_info_util",
":state_types",
":tree_util",
":util",
":xla_metadata",
] + py_deps("numpy"),
)
# Leaf target for _src/partition_spec.py; declares no deps.
pytype_strict_library(
name = "partition_spec",
srcs = ["_src/partition_spec.py"],
)
# Target for _src/path.py; its only dependency is the third-party "epath"
# package.
pytype_strict_library(
name = "path",
srcs = ["_src/path.py"],
deps = py_deps("epath"),
)
# Publicly visible target for experimental/profiler.py. It depends only on
# //jax/_src/lib, not on :jax or the private :profiler target.
pytype_library(
name = "experimental_profiler",
srcs = ["experimental/profiler.py"],
visibility = ["//visibility:public"],
deps = [
"//jax/_src/lib",
],
)
# Target for experimental/transfer.py. Unlike the neighbouring experimental_*
# targets it sets no visibility, so the package default_visibility applies.
pytype_library(
name = "experimental_transfer",
srcs = ["experimental/transfer.py"],
deps = [
":jax",
"//jax/_src/lib",
],
)
# Target for _src/pickle_util.py; needs the third-party "cloudpickle" package.
pytype_strict_library(
name = "pickle_util",
srcs = ["_src/pickle_util.py"],
deps = [":profiler"] + py_deps("cloudpickle"),
)
# Target for _src/pretty_printer.py; needs the third-party "colorama" package.
pytype_strict_library(
name = "pretty_printer",
srcs = ["_src/pretty_printer.py"],
deps = [
":config",
":util",
] + py_deps("colorama"),
)
# Fine-grained target for the private _src/profiler.py module (the public
# experimental/profiler.py is the separate :experimental_profiler target).
pytype_strict_library(
name = "profiler",
srcs = ["_src/profiler.py"],
deps = [
":traceback_util",
":xla_bridge",
"//jax/_src/lib",
],
)
# Fine-grained target for _src/sharding.py.
pytype_strict_library(
name = "sharding",
srcs = ["_src/sharding.py"],
deps = [
":op_shardings",
":util",
":xla_bridge",
"//jax/_src/lib",
],
)
# Fine-grained target for _src/compute_on.py.
pytype_strict_library(
name = "compute_on",
srcs = ["_src/compute_on.py"],
deps = [
":config",
"//jax/_src/lib",
],
)
# Fine-grained target for _src/xla_metadata.py.
pytype_strict_library(
name = "xla_metadata",
srcs = ["_src/xla_metadata.py"],
deps = [
":config",
"//jax/_src/lib",
],
)
# Fine-grained target for _src/layout.py.
pytype_strict_library(
name = "layout",
srcs = ["_src/layout.py"],
deps = [
":dtypes",
":sharding",
":sharding_impls",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Fine-grained target for _src/sharding_impls.py; depends on the other
# sharding-related targets (:sharding, :named_sharding, :sharding_specs,
# :op_shardings, :partition_spec) as well as :mesh and :core.
pytype_strict_library(
name = "sharding_impls",
srcs = ["_src/sharding_impls.py"],
deps = [
":config",
":core",
":internal_mesh_utils",
":mesh",
":named_sharding",
":op_shardings",
":partition_spec",
":sharding",
":sharding_specs",
":source_info_util",
":tree_util",
":util",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Fine-grained target for _src/named_sharding.py.
pytype_strict_library(
name = "named_sharding",
srcs = ["_src/named_sharding.py"],
deps = [
":config",
":mesh",
":partition_spec",
":sharding",
":util",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Fine-grained target for _src/sharding_specs.py.
pytype_strict_library(
name = "sharding_specs",
srcs = ["_src/sharding_specs.py"],
deps = [
":config",
":op_shardings",
":util",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Target for the private _src/mesh_utils.py module; the public
# experimental/mesh_utils.py is the separate :mesh_utils target, which depends
# on this one.
pytype_library(
name = "internal_mesh_utils",
srcs = ["_src/mesh_utils.py"],
deps = [
":xla_bridge",
],
)
# Fine-grained target for _src/source_info_util.py. Visibility is :internal
# plus whatever jax_visibility("source_info_util") returns.
pytype_strict_library(
name = "source_info_util",
srcs = ["_src/source_info_util.py"],
visibility = [":internal"] + jax_visibility("source_info_util"),
deps = [
":traceback_util",
":version",
"//jax/_src/lib",
],
)
# The type-level part of the _src/state package: __init__.py, indexing.py and
# types.py.
# NOTE(review): these three files also match the "_src/state/**/*.py" glob in
# the srcs of :jax, so they are listed by two targets — confirm this is
# intended.
pytype_strict_library(
name = "state_types",
srcs = [
"_src/state/__init__.py",
"_src/state/indexing.py",
"_src/state/types.py",
],
deps = [
":core",
":dtypes",
":effects",
":pretty_printer",
":traceback_util",
":tree_util",
":typing",
":util",
] + py_deps("numpy"),
)
# Fine-grained target for _src/tree_util.py. Visibility is :internal plus
# whatever jax_visibility("tree_util") returns.
pytype_strict_library(
name = "tree_util",
srcs = ["_src/tree_util.py"],
visibility = [":internal"] + jax_visibility("tree_util"),
deps = [
":traceback_util",
":util",
"//jax/_src/lib",
],
)
# Fine-grained target for _src/traceback_util.py. Visibility is :internal plus
# whatever jax_visibility("traceback_util") returns.
pytype_strict_library(
name = "traceback_util",
srcs = ["_src/traceback_util.py"],
visibility = [":internal"] + jax_visibility("traceback_util"),
deps = [
":config",
":util",
"//jax/_src/lib",
],
)
# Fine-grained target for _src/typing.py.
pytype_strict_library(
name = "typing",
srcs = [
"_src/typing.py",
],
deps = [":basearray"] + py_deps("numpy"),
)
# Target for _src/tpu_custom_call.py.
# The //jaxlib/mlir deps are wrapped in if_building_jaxlib(...), a helper
# loaded from //jaxlib:jax.bzl; presumably they are only added when jaxlib is
# built by Bazel (see the :build_jaxlib flag) — confirm in jax.bzl.
pytype_strict_library(
name = "tpu_custom_call",
srcs = ["_src/tpu_custom_call.py"],
visibility = [":internal"],
deps = [
":config",
":core",
":jax",
":mlir",
":sharding_impls",
"//jax/_src/lib",
"//jax/_src/pallas",
] + if_building_jaxlib([
"//jaxlib/mlir:ir",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:stablehlo_dialect",
]) + py_deps("numpy") + py_deps("absl/flags"),
)
# Fine-grained target for _src/util.py.
pytype_strict_library(
name = "util",
srcs = ["_src/util.py"],
deps = [
":config",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# Leaf target for the top-level version.py (also exported as a plain file via
# exports_files at the top of this BUILD file); declares no deps.
pytype_strict_library(
name = "version",
srcs = ["version.py"],
)
# Fine-grained target for _src/interpreters/xla.py.
pytype_strict_library(
name = "xla",
srcs = ["_src/interpreters/xla.py"],
deps = [
":abstract_arrays",
":config",
":core",
":dtypes",
":sharding_impls",
":source_info_util",
":typing",
":util",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
# TODO(phawkins): break up this SCC.
# xla_bridge.py, distributed.py and the _src/clusters package are built as one
# target. The TODO above calls them an SCC — presumably these modules import
# each other; confirm before splitting.
pytype_strict_library(
name = "xla_bridge",
srcs = [
"_src/clusters/__init__.py",
"_src/clusters/cloud_tpu_cluster.py",
"_src/clusters/cluster.py",
"_src/clusters/k8s_cluster.py",
"_src/clusters/mpi4py_cluster.py",
"_src/clusters/ompi_cluster.py",
"_src/clusters/slurm_cluster.py",
"_src/distributed.py",
"_src/xla_bridge.py",
],
visibility = [":internal"] + jax_visibility("xla_bridge"),
deps = [
":cloud_tpu_init",
":config",
":hardware_utils",
":traceback_util",
":util",
"//jax/_src/lib",
],
)
# Public JAX libraries below this point.
# Catch-all public target for the modules directly under experimental/ and
# example_libraries/ (non-recursive globs, so subpackages are not included).
py_library_providing_imports_info(
name = "experimental",
srcs = glob(
[
"experimental/*.py",
"example_libraries/*.py",
],
),
visibility = ["//visibility:public"],
deps = [
":jax",
] + py_deps("absl/logging") + py_deps("numpy"),
)
# Public target for example_libraries/stax.py alone (the file also matches the
# "example_libraries/*.py" glob of :experimental).
pytype_library(
name = "stax",
srcs = [
"example_libraries/stax.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# Public target for the modules directly under experimental/sparse/, except
# test_util.py, which is the separate, internal-only :sparse_test_util target.
pytype_library(
name = "experimental_sparse",
srcs = glob(
[
"experimental/sparse/*.py",
],
exclude = ["experimental/sparse/test_util.py"],
),
visibility = ["//visibility:public"],
deps = [":jax"],
)
# Internal-only target for experimental/sparse/test_util.py, which is excluded
# from :experimental_sparse; it depends on the private :test_util target.
pytype_library(
name = "sparse_test_util",
srcs = [
"experimental/sparse/test_util.py",
],
visibility = [":internal"],
deps = [
":experimental_sparse",
":jax",
":test_util",
] + py_deps("numpy"),
)
# Public target for example_libraries/optimizers.py alone (the file also
# matches the "example_libraries/*.py" glob of :experimental).
pytype_library(
name = "optimizers",
srcs = [
"example_libraries/optimizers.py",
],
visibility = ["//visibility:public"],
deps = [":jax"] + py_deps("numpy"),
)
# Public target for experimental/ode.py alone.
pytype_library(
name = "ode",
srcs = ["experimental/ode.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# TODO(apaszke): Remove this target
# Public target for experimental/pjit.py. The same file is already listed in
# the srcs of :jax, which this target also depends on.
pytype_library(
name = "pjit",
srcs = ["experimental/pjit.py"],
visibility = ["//visibility:public"],
deps = [
":experimental",
":jax",
],
)
# Public target for experimental/jet.py alone.
pytype_library(
name = "jet",
srcs = ["experimental/jet.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# Public target for experimental/host_callback.py. The two extra sources are
# bundled for the reason given in the inline comments.
pytype_library(
name = "experimental_host_callback",
srcs = [
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
"experimental/host_callback.py",
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
],
visibility = ["//visibility:public"],
deps = [
":jax",
],
)
# Public target for the experimental/compilation_cache package (the private
# implementation is :compilation_cache_internal).
# compilation_cache.py is also listed in the srcs of :jax.
pytype_library(
name = "compilation_cache",
srcs = [
"experimental/compilation_cache/__init__.py",
"experimental/compilation_cache/compilation_cache.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# Public target for experimental/mesh_utils.py; its only dependency is the
# private :internal_mesh_utils target (_src/mesh_utils.py).
pytype_library(
name = "mesh_utils",
srcs = ["experimental/mesh_utils.py"],
visibility = ["//visibility:public"],
deps = [
":internal_mesh_utils",
],
)
# TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/...
# Source-less umbrella target that forwards to the //jax/extend targets listed
# below; only packages in :jax_extend_users may depend on it.
pytype_strict_library(
name = "extend",
visibility = [":jax_extend_users"],
deps = [
"//jax/extend",
"//jax/extend:backend",
"//jax/extend:core",
"//jax/extend:linear_util",
"//jax/extend:random",
"//jax/extend:source_info_util",
],
)
# Target for experimental/mosaic/__init__.py and dialects.py, restricted to
# :mosaic_users. The experimental/mosaic/gpu/ files belong to :mosaic_gpu.
pytype_library(
name = "mosaic",
srcs = [
"experimental/mosaic/__init__.py",
"experimental/mosaic/dialects.py",
],
visibility = [":mosaic_users"],
deps = [
":tpu_custom_call",
"//jax/_src/lib",
],
)
# Public target for experimental/rnn.py alone.
pytype_library(
name = "rnn",
srcs = ["experimental/rnn.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# Public target for the experimental/colocated_python package, with its
# modules listed explicitly rather than globbed. Besides :jax it depends on
# several private targets directly and on //jax/extend:ifrt_programs.
pytype_library(
name = "experimental_colocated_python",
srcs = [
"experimental/colocated_python/__init__.py",
"experimental/colocated_python/api.py",
"experimental/colocated_python/func.py",
"experimental/colocated_python/func_backend.py",
"experimental/colocated_python/obj.py",
"experimental/colocated_python/obj_backend.py",
"experimental/colocated_python/serialization.py",
],
visibility = ["//visibility:public"],
deps = [
":api_util",
":jax",
":traceback_util",
":tree_util",
":util",
":xla_bridge",
"//jax/_src/lib",
"//jax/extend:ifrt_programs",
] + py_deps("numpy") + py_deps("cloudpickle"),
)