# Mirror of https://github.com/ROCm/jax.git
# (synced 2025-04-17 04:16:07 +00:00).
#
# Commit message shown by the mirror:
#   C++ static initialization acquires an internal mutex. It is unsafe to call
#   into Python code while holding that mutex, e.g., see the deadlock in
#   https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a. However,
#   in this case, there's a simpler thing we can do: eagerly initialize the
#   ::type() values during module initialization, rather than on-demand.
#   PiperOrigin-RevId: 744508279
#
# File size reported by the mirror: 1092 lines, 33 KiB (labelled "Python";
# the content is a Bazel BUILD file).

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Build macros and rules used by this package, all provided by //jaxlib:jax.bzl.
load(
    "//jaxlib:jax.bzl",
    "cc_proto_library",
    "if_oss",
    "jax_visibility",
    "nanobind_extension",
    "proto_library",
    "py_deps",
    "py_strict_library",
    "py_strict_test",
    "pytype_strict_library",
)

# Targets in this package are released under a notice-type license (Apache 2.0,
# per the header above).
licenses(["notice"])

# Package defaults: targets are visible only to //jax:internal unless a target
# sets its own `visibility`.
package(
    default_applicable_licenses = [],
    default_visibility = ["//jax:internal"],
)

# Visibility group used by :xla_client and :xla_client_test; currently it just
# includes the //jax:internal group.
package_group(
    name = "xla_python",
    includes = [
        "//jax:internal",
    ],
)

# The `xla_extension` Python extension module, built with nanobind from xla.cc.
# Type stubs come from xla_extension/*.pyi. The two `select`s append
# platform-specific CPU collectives:
#   * gloo: uv transport on macOS; TCP transport plus :py_socket_transfer by
#     default; nothing on Windows.
#   * MPI: only in OSS builds (via if_oss), and never on Windows.
nanobind_extension(
    name = "xla_extension",
    srcs = ["xla.cc"],
    pytype_deps = py_deps(["numpy"]),
    pytype_srcs = glob(["xla_extension/*.pyi"]),
    visibility = ["//visibility:public"],
    deps = [
        ":config",
        ":custom_call_sharding",
        ":dlpack",
        ":guard_lib",
        ":ifrt_proxy",
        ":jax_jit",
        ":mlir",
        ":nb_class_ptr",
        ":pjit",
        ":pmap_lib",
        ":py_client",
        ":python_ref_manager",
        ":pytree",
        ":sdy",
        ":traceback",
        ":util",
        ":weakref_lru_cache",
        ":xla_compiler",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:initialize",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@nanobind",
        "@tsl//tsl/platform",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla:literal",
        "@xla//xla:shape_util",
        "@xla//xla:types",
        "@xla//xla:util",
        "@xla//xla/backends/cpu/collectives:cpu_collectives",
        "@xla//xla/ffi:ffi_api",
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/pjrt:mlir_to_hlo",
        "@xla//xla/pjrt:pjrt_api",
        "@xla//xla/pjrt:pjrt_c_api_client",
        "@xla//xla/pjrt:pjrt_client",
        "@xla//xla/pjrt:pjrt_common",
        "@xla//xla/pjrt:pjrt_compiler",
        "@xla//xla/pjrt:pjrt_executable",
        "@xla//xla/pjrt:pjrt_layout",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
        "@xla//xla/pjrt/distributed",
        "@xla//xla/pjrt/distributed:client",
        "@xla//xla/pjrt/distributed:key_value_store_interface",
        "@xla//xla/pjrt/distributed:protocol_proto_cc",
        "@xla//xla/pjrt/distributed:service",
        "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
        "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
        "@xla//xla/python:logging",
        "@xla//xla/python:nb_absl_flat_hash_map",
        "@xla//xla/python:nb_absl_span",
        "@xla//xla/python:ops",
        "@xla//xla/python:pprof_profile_builder",
        "@xla//xla/python:profiler",
        "@xla//xla/python:refine_polymorphic_shapes",
        "@xla//xla/python:types",
        "@xla//xla/python:version",
        "@xla//xla/python/ifrt",
        "@xla//xla/python/ifrt:plugin_program",
        "@xla//xla/python/ifrt:plugin_program_serdes",
        "@xla//xla/python/pjrt_ifrt",
        "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
        "@xla//xla/python/pjrt_ifrt:xla_ifrt",
        "@xla//xla/tsl/concurrency:ref_count",
        "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:status",
        "@xla//xla/tsl/platform:statusor",
        "@xla//xla/tsl/platform/cloud:gcs_file_system",
        "@xla//xla/tsl/python/lib/core:numpy",
    ] + select({
        # gloo tcp transport only builds on linux; macOS uses the uv transport.
        "@xla//xla/tsl:macos": [
            "@gloo//:transport_uv",
            "@xla//xla/backends/cpu/collectives:gloo_collectives",
            "@xla//xla/backends/cpu/collectives:gloo_kv_store",
        ],
        "@xla//xla/tsl:windows": [],
        "//conditions:default": [
            ":py_socket_transfer",
            "@gloo//:transport_tcp",
            "@xla//xla/backends/cpu/collectives:gloo_collectives",
            "@xla//xla/backends/cpu/collectives:gloo_kv_store",
        ],
    }) + select({
        # mpitrampoline does not build on windows
        "@xla//xla/tsl:windows": [],
        # we support MPI collectives only in OSS builds
        "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]),
    }),
)

# C++ library built from callback.cc/callback.h. Like the other nanobind-based
# libraries here, it is compiled with -fexceptions and -fno-strict-aliasing.
cc_library(
    name = "callback",
    srcs = [
        "callback.cc",
    ],
    hdrs = [
        "callback.h",
    ],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":python_ref_manager",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla:comparison_util",
        "@xla//xla:xla_data_proto_cc",
        "@xla//xla/pjrt:host_callback",
        "@xla//xla/pjrt:transpose",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/tsl/platform:statusor",
        "@xla//xla/tsl/python/lib/core:numpy",
    ],
)

# C++ library built from config.cc/config.h (nanobind, Python runtime headers).
cc_library(
    name = "config",
    srcs = ["config.cc"],
    hdrs = ["config.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":python_ref_manager",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla/tsl/platform:logging",
    ],
)

# C++ library built from custom_call_sharding.cc/.h. It depends on XLA's
# custom-partitioner C API extension and its custom-partition callback helpers.
cc_library(
    name = "custom_call_sharding",
    srcs = ["custom_call_sharding.cc"],
    hdrs = ["custom_call_sharding.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",
        "@xla//xla:shape_util",
        "@xla//xla:util",
        "@xla//xla/hlo/ir:hlo",
        "@xla//xla/hlo/utils:hlo_sharding_util",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs",
        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
        "@xla//xla/pjrt/c:pjrt_c_api_helpers",
        "@xla//xla/python:custom_call_batch_partitioner",
        "@xla//xla/python:custom_partition_callback",
        "@xla//xla/python:inspect_sharding",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# C++ library built from dlpack.cc/dlpack.h; depends on the external @dlpack
# package and on :py_client.
cc_library(
    name = "dlpack",
    srcs = ["dlpack.cc"],
    hdrs = ["dlpack.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":nb_class_ptr",
        ":py_client",
        ":python_ref_manager",
        ":traceback",
        ":util",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@dlpack",
        "@llvm-project//llvm:Support",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla:shape_util",
        "@xla//xla:status_macros",
        "@xla//xla:util",
        "@xla//xla:xla_data_proto_cc",
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/pjrt:pjrt_client",
        "@xla//xla/pjrt:pjrt_common",
        "@xla//xla/pjrt:pjrt_compiler",
        "@xla//xla/python:types",
        "@xla//xla/python/ifrt",
        "@xla//xla/python/pjrt_ifrt",
        "@xla//xla/tsl/platform:errors",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# C++ library built from guard_lib.cc/guard_lib.h (nanobind + absl).
cc_library(
    name = "guard_lib",
    srcs = ["guard_lib.cc"],
    hdrs = ["guard_lib.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@nanobind",
        "@xla//xla:util",
    ],
)

# C++ library built from ifrt_proxy.cc/.h; depends on the IFRT proxy gRPC
# client and client registry from XLA.
cc_library(
    name = "ifrt_proxy",
    srcs = ["ifrt_proxy.cc"],
    hdrs = ["ifrt_proxy.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":nb_class_ptr",
        ":py_client",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/log:log_entry",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@nanobind",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/python/ifrt",
        "@xla//xla/python/ifrt:attribute_map",
        "@xla//xla/python/ifrt_proxy/client:grpc_client",
        "@xla//xla/python/ifrt_proxy/client:registry",
        "@xla//xla/tsl/platform:env",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# C++ library built from jax_jit.cc/jax_jit.h; depends on :py_client and
# :pytree.
cc_library(
    name = "jax_jit",
    srcs = ["jax_jit.cc"],
    hdrs = ["jax_jit.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":py_client",
        ":python_ref_manager",
        ":pytree",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@tsl//tsl/profiler/lib:traceme",
        # Same keep pragma spelling as every other target in this file.
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla/pjrt:pjrt_client",
        "@xla//xla/pjrt:pjrt_layout",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/python:nb_absl_inlined_vector",
        "@xla//xla/python:nb_absl_span",
        "@xla//xla/python:types",
        "@xla//xla/tsl/platform:logging",
    ],
)

# C++ library built from mlir.cc/mlir.h; depends on MLIR parsing, passes and
# bytecode writing, StableHLO serialization and XLA's MLIR-to-HLO conversion.
cc_library(
    name = "mlir",
    srcs = ["mlir.cc"],
    hdrs = ["mlir.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:BytecodeWriter",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ReconcileUnrealizedCasts",
        "@llvm-project//mlir:Support",
        "@nanobind",
        "@stablehlo//:stablehlo_serialization",
        "@xla//xla/hlo/builder:xla_computation",
        "@xla//xla/hlo/translate:stablehlo",
        "@xla//xla/mlir_hlo:mhlo_passes",
        "@xla//xla/pjrt:mlir_to_hlo",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/python:refine_polymorphic_shapes",
        "@xla//xla/tsl/platform:errors",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# Header-only library (nb_class_ptr.h, no srcs) on top of nanobind. Visibility
# is computed by jax_visibility().
cc_library(
    name = "nb_class_ptr",
    hdrs = ["nb_class_ptr.h"],
    copts = ["-fexceptions"],
    features = ["-use_header_modules"],
    visibility = jax_visibility("jaxlib/xla/nb_class_ptr"),
    deps = ["@nanobind"],
)

# C++ library built from pjit.cc/pjit.h; builds on :jax_jit, :pytree and
# XLA's LRU cache.
cc_library(
    name = "pjit",
    srcs = ["pjit.cc"],
    hdrs = ["pjit.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":config",
        ":guard_lib",
        ":jax_jit",
        ":nb_class_ptr",
        ":py_client",
        ":python_ref_manager",
        ":pytree",
        ":traceback",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@tsl//tsl/profiler/lib:traceme",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla:shape_util",
        "@xla//xla:util",
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/pjrt:lru_cache",
        "@xla//xla/python:nb_helpers",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/python/ifrt",
        "@xla//xla/tsl/concurrency:ref_count",
        "@xla//xla/tsl/platform:env",
        "@xla//xla/tsl/platform:errors",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# C++ library built from pmap_lib.cc/pmap_lib.h; builds on :jax_jit and
# :pytree.
cc_library(
    name = "pmap_lib",
    srcs = ["pmap_lib.cc"],
    hdrs = ["pmap_lib.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":config",
        ":jax_jit",
        ":nb_class_ptr",
        ":py_client",
        ":python_ref_manager",
        ":pytree",
        ":traceback",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@tsl//tsl/profiler/lib:traceme",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla:status_macros",
        "@xla//xla:util",
        "@xla//xla:xla_data_proto_cc",
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/python:nb_helpers",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/python:types",
        "@xla//xla/python/ifrt",
        "@xla//xla/tsl/concurrency:ref_count",
        "@xla//xla/tsl/platform:env",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:statusor",
        "@xla//xla/tsl/python/lib/core:numpy",
    ],
)

# The Python-facing client library: arrays, client, compile-only client,
# devices, device lists, executables, memory spaces, programs, values and
# shardings. It links the PjRt/IFRT runtimes and :py_client_cpu /
# :py_host_callback. Visibility is computed by jax_visibility().
cc_library(
    name = "py_client",
    srcs = [
        "py_array.cc",
        "py_client.cc",
        "py_compile_only_client.cc",
        "py_device.cc",
        "py_device_list.cc",
        "py_executable.cc",
        "py_memory_space.cc",
        "py_program.cc",
        "py_values.cc",
        "sharding.cc",
        "to_ifrt_sharding.cc",
    ],
    hdrs = [
        "py_array.h",
        "py_client.h",
        "py_compile_only_client.h",
        "py_device.h",
        "py_device_list.h",
        "py_executable.h",
        "py_memory_space.h",
        "py_program.h",
        "py_values.h",
        "sharded_device_array.h",
        "sharding.h",
        "to_ifrt_sharding.h",
    ],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    visibility = jax_visibility("jaxlib/xla/py_client"),
    deps = [
        ":guard_lib",
        ":nb_class_ptr",
        ":py_client_cpu",
        ":py_host_callback",
        ":python_ref_manager",
        ":traceback",
        ":util",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@nanobind",
        "@tsl//tsl/platform:fingerprint",
        "@tsl//tsl/platform:ml_dtypes",
        "@tsl//tsl/profiler/lib:traceme",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla:literal",
        "@xla//xla:shape_util",
        "@xla//xla:status_macros",
        "@xla//xla:types",
        "@xla//xla:util",
        "@xla//xla:xla_data_proto_cc",
        "@xla//xla/hlo/ir:hlo",
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/pjrt:lru_cache",
        "@xla//xla/pjrt:mlir_to_hlo",
        "@xla//xla/pjrt:pjrt_client",
        "@xla//xla/pjrt:pjrt_compiler",
        "@xla//xla/pjrt:pjrt_executable",
        "@xla//xla/pjrt:pjrt_future",
        "@xla//xla/pjrt:pjrt_layout",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/python:nb_absl_span",
        "@xla//xla/python:nb_helpers",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/python:pprof_profile_builder",
        "@xla//xla/python:types",
        "@xla//xla/python/compile_only_ifrt:client",
        "@xla//xla/python/ifrt",
        "@xla//xla/python/ifrt:attribute_map",
        "@xla//xla/python/ifrt:custom_call_program",
        "@xla//xla/python/ifrt:plugin_program",
        "@xla//xla/python/ifrt:plugin_program_serdes",
        "@xla//xla/python/ifrt:user_context",
        "@xla//xla/python/ifrt/hlo:hlo_program",
        "@xla//xla/python/pjrt_ifrt",
        "@xla//xla/python/pjrt_ifrt:pjrt_dtype",
        "@xla//xla/python/pjrt_ifrt:xla_ifrt",
        "@xla//xla/service:platform_util",
        "@xla//xla/tsl/concurrency:ref_count",
        "@xla//xla/tsl/framework:allocator",
        "@xla//xla/tsl/platform:env",
        "@xla//xla/tsl/platform:errors",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:status",
        "@xla//xla/tsl/platform:statusor",
        "@xla//xla/tsl/python/lib/core:numpy",
    ],
)

# C++ library built from py_client_cpu.cc/.h; depends on the XLA FFI API.
# alwayslink = 1 keeps its object code linked even when no symbol is
# referenced directly. NOTE(review): presumably this is for statically
# registered FFI handlers; confirm against py_client_cpu.cc.
cc_library(
    name = "py_client_cpu",
    srcs = ["py_client_cpu.cc"],
    hdrs = ["py_client_cpu.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla:shape_util",
        "@xla//xla:util",
        "@xla//xla:xla_data_proto_cc",
        "@xla//xla/ffi:ffi_api",
        "@xla//xla/ffi/api:ffi",
        "@xla//xla/pjrt:host_callback",
        "@xla//xla/pjrt:transpose",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/python:types",
    ],
    alwayslink = 1,
)

# C++ library built from py_host_callback.cc/.h; builds on :callback and the
# :py_host_callback_cc_proto messages.
cc_library(
    name = "py_host_callback",
    srcs = ["py_host_callback.cc"],
    hdrs = ["py_host_callback.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":callback",
        ":py_host_callback_cc_proto",
        ":python_ref_manager",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@nanobind",
        "@xla//xla:shape_util",
        "@xla//xla:status_macros",
        "@xla//xla:util",
        "@xla//xla:xla_data_proto_cc",
        "@xla//xla/pjrt:host_callback",
        "@xla//xla/python:types",
        "@xla//xla/python/ifrt",
        "@xla//xla/python/pjrt_ifrt",
        "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc",
        "@xla//xla/tsl/concurrency:ref_count",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# Protocol buffer definitions from py_host_callback.proto.
proto_library(
    name = "py_host_callback_proto",
    srcs = ["py_host_callback.proto"],
)

# C++ bindings for :py_host_callback_proto (consumed by :py_host_callback).
cc_proto_library(
    name = "py_host_callback_cc_proto",
    visibility = jax_visibility("jaxlib/xla/py_host_callback_cc_proto"),
    deps = [":py_host_callback_proto"],
)

# C++ library built from py_socket_transfer.cc/.h; depends on XLA's socket
# transfer stack (event loop, socket server, bulk transport, streaming).
# Only :xla_extension's default (non-macOS, non-Windows) branch links it.
cc_library(
    name = "py_socket_transfer",
    srcs = ["py_socket_transfer.cc"],
    hdrs = ["py_socket_transfer.h"],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":nb_class_ptr",
        ":py_client",
        ":traceback",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@llvm-project//llvm:Support",
        "@nanobind",
        "@tsl//tsl/platform:casts",
        "@xla//xla:util",
        "@xla//xla/pjrt:pjrt_client",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/python:types",
        "@xla//xla/python/ifrt",
        "@xla//xla/python/pjrt_ifrt",
        "@xla//xla/python/pjrt_ifrt:pjrt_dtype",
        "@xla//xla/python/transfer:event_loop",
        "@xla//xla/python/transfer:socket-server",
        "@xla//xla/python/transfer:socket_bulk_transport",
        "@xla//xla/python/transfer:streaming",
        "@xla//xla/python/transfer:streaming_ifrt",
        "@xla//xla/python/transfer:transfer_socket_proto_cc",
        "@xla//xla/tsl/concurrency:ref_count",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# C++ library built from python_ref_manager.cc/.h (nanobind, absl
# synchronization). Visibility is computed by jax_visibility().
cc_library(
    name = "python_ref_manager",
    srcs = ["python_ref_manager.cc"],
    hdrs = ["python_ref_manager.h"],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    visibility = jax_visibility("jaxlib/xla/python_ref_manager"),
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
    ],
)

# Protocol buffer definitions from pytree.proto.
proto_library(
    name = "pytree_proto",
    srcs = ["pytree.proto"],
)

# C++ bindings for :pytree_proto (consumed by :pytree).
cc_proto_library(
    name = "pytree_cc_proto",
    deps = [":pytree_proto"],
)

# C++ library built from pytree.cc/pytree.h; uses the :pytree_cc_proto
# messages. Visibility is computed by jax_visibility().
cc_library(
    name = "pytree",
    srcs = ["pytree.cc"],
    hdrs = ["pytree.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    visibility = jax_visibility("jaxlib/xla/pytree"),
    deps = [
        ":nb_class_ptr",
        ":pytree_cc_proto",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/tsl/platform:logging",
    ],
)

# C++ library built from sdy.cc/sdy.h; depends on the Shardy (sdy) dialect and
# XLA's Shardy round-trip import/pipeline utilities.
cc_library(
    name = "sdy",
    srcs = ["sdy.cc"],
    hdrs = ["sdy.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:BytecodeWriter",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@nanobind",
        "@shardy//shardy/dialect/sdy/ir:dialect",
        "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
        "@xla//xla/mlir_hlo:all_passes",
        "@xla//xla/pjrt:mlir_to_hlo",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/service/spmd/shardy:constants",
        "@xla//xla/service/spmd/shardy:utils",
        "@xla//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs",
        "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines",
        "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler",
    ],
)

|
|
|
|
# C++ library built from traceback.cc/traceback.h. Visibility is computed by
# jax_visibility().
cc_library(
    name = "traceback",
    srcs = ["traceback.cc"],
    hdrs = ["traceback.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    visibility = jax_visibility("jaxlib/xla/traceback"),
    deps = [
        ":nb_class_ptr",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@nanobind",
        "@tsl//tsl/platform",
        "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/tsl/platform:logging",
    ],
)

# C++ library built from util.cc/util.h; depends on IFRT and ref-counting
# utilities but not on nanobind.
cc_library(
    name = "util",
    srcs = ["util.cc"],
    hdrs = ["util.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/types:span",
        "@xla//xla:util",
        "@xla//xla/python/ifrt",
        "@xla//xla/tsl/concurrency:ref_count",
    ],
)

# C++ library built from weakref_lru_cache.cc/.h; builds on XLA's LRU cache.
cc_library(
    name = "weakref_lru_cache",
    srcs = ["weakref_lru_cache.cc"],
    hdrs = ["weakref_lru_cache.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",
        "@xla//xla/pjrt:lru_cache",
        "@xla//xla/tsl/platform:logging",
    ],
)

# C++ library built from xla_compiler.cc/.h; depends on the XLA builder,
# HLO IR/parser/passes, the FFI API and the service-level HLO utilities.
cc_library(
    name = "xla_compiler",
    srcs = ["xla_compiler.cc"],
    hdrs = ["xla_compiler.h"],
    compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":dlpack",
        ":py_client",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@nanobind",
        "@xla//xla:array",
        "@xla//xla:debug_options_flags",
        "@xla//xla:literal",
        "@xla//xla:shape_util",
        "@xla//xla:util",
        "@xla//xla:xla_data_proto_cc",
        "@xla//xla:xla_proto_cc",
        "@xla//xla/client:executable_build_options",
        "@xla//xla/ffi",
        "@xla//xla/ffi:ffi_api",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/hlo/builder:xla_builder",
        "@xla//xla/hlo/builder:xla_computation",
        "@xla//xla/hlo/ir:hlo",
        "@xla//xla/hlo/ir:hlo_module_group",
        "@xla//xla/hlo/parser:hlo_parser",
        "@xla//xla/hlo/pass:hlo_pass",
        "@xla//xla/hlo/transforms/simplifiers:flatten_call_graph",
        "@xla//xla/hlo/transforms/simplifiers:hlo_dce",
        "@xla//xla/hlo/transforms/simplifiers:tuple_simplifier",
        "@xla//xla/pjrt:compile_options_proto_cc",
        "@xla//xla/pjrt:exceptions",
        "@xla//xla/pjrt:pjrt_executable",
        "@xla//xla/pjrt:status_casters",
        "@xla//xla/python:nb_absl_span",
        "@xla//xla/python:nb_helpers",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/python:types",
        "@xla//xla/service:call_inliner",
        "@xla//xla/service:computation_placer",
        "@xla//xla/service:custom_call_target_registry",
        "@xla//xla/service:hlo_graph_dumper",
        "@xla//xla/service:hlo_module_config",
        "@xla//xla/service:hlo_proto_cc",
        "@xla//xla/service:name_uniquer",
        "@xla//xla/tsl/lib/strings:proto_serialization",
        "@xla//xla/tsl/platform:env",
        "@xla//xla/tsl/platform:errors",
        "@xla//xla/tsl/platform:logging",
        "@xla//xla/tsl/platform:statusor",
    ],
)

# Python library xla_client.py (type stubs in xla_client.pyi), layered on the
# :xla_extension module. Visible to the :xla_python package group.
pytype_strict_library(
    name = "xla_client",
    srcs = ["xla_client.py"],
    pytype_srcs = ["xla_client.pyi"],
    visibility = [":xla_python"],
    deps = py_deps([
        "numpy",
        "ml_dtypes",
    ]) + [":xla_extension"],
)

# Test target running xla_client_backend_independent_test.py against
# :xla_client.
py_strict_test(
    name = "xla_client_backend_independent_test",
    srcs = ["xla_client_backend_independent_test.py"],
    deps = [
        ":xla_client",
    ] + py_deps([
        "absl/testing",
        "numpy",
        "portpicker",
    ]),
)

# xla_client_test.py packaged as a testonly library visible to :xla_python.
# The same source is also run directly by :xla_client_test_cpu.
py_strict_library(
    name = "xla_client_test",
    testonly = 1,
    srcs = ["xla_client_test.py"],
    visibility = [":xla_python"],
    deps = [
        ":xla_client",
        "//jax",
        "//jax:test_util",
        "//jaxlib",
    ] + py_deps([
        "absl/flags",
        "absl/logging",
        "absl/testing",
        "ml_dtypes",
        "numpy",
    ]),
)

# Testonly nanobind extension built from custom_calls_testlib.cc against the
# XLA FFI C/C++ API; used as a dependency of :xla_client_test_cpu.
nanobind_extension(
    name = "custom_calls_testlib",
    testonly = 1,
    srcs = ["custom_calls_testlib.cc"],
    deps = [
        "@com_google_absl//absl/status",
        "@nanobind",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
    ],
)

# Runs xla_client_test.py with --backend=cpu. XLA_FLAGS forces 4 host-platform
# devices for the run.
py_strict_test(
    name = "xla_client_test_cpu",
    srcs = ["xla_client_test.py"],
    args = ["--backend=cpu"],
    env = {
        "XLA_FLAGS": "--xla_force_host_platform_device_count=4",
    },
    main = "xla_client_test.py",
    deps = [
        ":custom_calls_testlib",
        ":xla_client",
        "//jax",
        "//jax:test_util",
        "//jaxlib",
    ] + py_deps([
        "absl/flags",
        "absl/logging",
        "absl/testing",
        "ml_dtypes",
        "numpy",
    ]),
)

# Test target running weakref_lru_cache_test.py against :xla_client.
py_strict_test(
    name = "weakref_lru_cache_test",
    srcs = ["weakref_lru_cache_test.py"],
    deps = [
        ":xla_client",
    ] + py_deps([
        "absl/flags",
        "absl/logging",
        "absl/testing",
    ]),
)

# Test target running pytree_test.py against :xla_client.
py_strict_test(
    name = "pytree_test",
    srcs = ["pytree_test.py"],
    deps = [
        ":xla_client",
    ] + py_deps([
        "absl/flags",
        "absl/logging",
        "absl/testing",
    ]),
)

# Test target running config_test.py against :xla_client.
py_strict_test(
    name = "config_test",
    srcs = ["config_test.py"],
    deps = [
        ":xla_client",
    ] + py_deps([
        "absl/flags",
        "absl/logging",
        "absl/testing",
    ]),
)

# Test target running jax_jit_test.py against :xla_client (also needs numpy).
py_strict_test(
    name = "jax_jit_test",
    srcs = ["jax_jit_test.py"],
    deps = [
        ":xla_client",
    ] + py_deps([
        "absl/flags",
        "absl/logging",
        "absl/testing",
        "numpy",
    ]),
)