[JAX] [XLA:Python] Migrate xla_extension and its type stubs into jaxlib.

Future changes will migrate many of its dependent modules.

PiperOrigin-RevId: 739361786
Peter Hawkins 2025-03-21 18:52:12 -07:00 committed by jax authors
parent 2692c5ff98
commit 55e408471c
22 changed files with 3268 additions and 16 deletions
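
For context: the previous jaxlib build generated a one-line shim (the genrule removed below) so that jaxlib.xla_extension simply re-exported xla.xla.python.xla_extension; after this change the extension is compiled directly as //jaxlib/xla:xla_extension, with its .pyi type stubs alongside it. A hedged sketch of the resulting import under a Bazel build (the exact Python package path is an assumption inferred from the BUILD changes below):

# Hypothetical import; the package path is inferred from the new
# //jaxlib/xla:xla_extension target added in this change.
from jaxlib.xla import xla_extension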


@ -45,6 +45,7 @@ py_library_providing_imports_info(
"//jaxlib:cpu_feature_guard",
"//jaxlib:utils",
"//jaxlib/xla:xla_client",
"//jaxlib/xla:xla_extension",
"//jaxlib/triton",
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
"//jaxlib/mlir:arithmetic_dialect",
@ -61,6 +62,5 @@ py_library_providing_imports_info(
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",
# xla_extension
]),
)


@ -29,13 +29,6 @@ package(
default_visibility = ["//jax:internal"],
)
# This makes xla_extension module accessible from jax._src.lib.
genrule(
name = "xla_extension_py",
outs = ["xla_extension.py"],
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
)
py_library_providing_imports_info(
name = "jaxlib",
srcs = [
@ -51,8 +44,8 @@ py_library_providing_imports_info(
"lapack.py",
"plugin_support.py",
"xla_client.py",
"xla_extension.py",
":version",
":xla_extension_py",
],
data = [":ffi_headers"],
lib_rule = pytype_library,
@ -82,6 +75,7 @@ py_library_providing_imports_info(
"//jaxlib/mosaic",
"//jaxlib/triton",
"//jaxlib/xla:xla_client",
"//jaxlib/xla:xla_extension",
],
)


@ -610,3 +610,12 @@ def jax_py_test(
if "PYTHONWARNINGS" not in env:
env["PYTHONWARNINGS"] = "error"
py_test(name = name, env = env, **kwargs)
def if_oss(oss_value, google_value = []):
"""Returns one of the arguments based on the non-configurable build env.
Specifically, it does not return a `select`, and can be used to e.g.
compute elements of list attributes.
"""
_ = (google_value, oss_value) # buildifier: disable=unused-variable
return oss_value


@ -62,11 +62,11 @@ py_binary(
"//jaxlib",
"//jaxlib:README.md",
"//jaxlib:setup.py",
"//jaxlib/xla:xla_client.py",
"//jaxlib/xla:xla_extension",
"@xla//xla/ffi/api:api.h",
"@xla//xla/ffi/api:c_api.h",
"@xla//xla/ffi/api:ffi.h",
"@xla//xla/python:xla_client.py",
"@xla//xla/python:xla_extension",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]),


@ -110,7 +110,7 @@ def patch_copy_xla_extension_stubs(dst_dir):
xla_extension_dir = os.path.join(dst_dir, "xla_extension")
os.makedirs(xla_extension_dir)
for stub_name in _XLA_EXTENSION_STUBS:
stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name)
stub_path = r.Rlocation("__main__/jaxlib/xla/xla_extension/" + stub_name)
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
continue
@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack():
if not _is_mac():
return
nm = subprocess.run(
["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")],
["nm", "-g", r.Rlocation("__main/jaxlib/xla/xla_extension.so")],
capture_output=True,
text=True,
check=False,
@ -198,7 +198,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
"__main__/jaxlib/plugin_support.py",
"__main__/jaxlib/version.py",
"__main__/jaxlib/xla/xla_client.py",
f"xla/xla/python/xla_extension.{pyext}",
f"__main__/jaxlib/xla/xla_extension.{pyext}",
],
)
# This file is required by PEP-561. It marks jaxlib as package containing


@ -14,6 +14,7 @@
load(
"//jaxlib:jax.bzl",
"if_oss",
"nanobind_extension",
"py_deps",
"py_strict_library",
@ -35,6 +36,114 @@ package_group(
],
)
nanobind_extension(
name = "xla_extension",
srcs = ["xla.cc"],
pytype_deps = py_deps(["numpy"]),
pytype_srcs = glob(["xla_extension/*.pyi"]),
visibility = ["//visibility:public"],
deps = [
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log:initialize",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@nanobind",
"@tsl//tsl/platform",
"@xla//third_party/python_runtime:headers", # buildcleaner: keep
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla:types",
"@xla//xla:util",
"@xla//xla/backends/cpu/collectives:cpu_collectives",
"@xla//xla/ffi:ffi_api",
"@xla//xla/pjrt:exceptions",
"@xla//xla/pjrt:mlir_to_hlo",
"@xla//xla/pjrt:pjrt_api",
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:pjrt_common",
"@xla//xla/pjrt:pjrt_compiler",
"@xla//xla/pjrt:pjrt_executable",
"@xla//xla/pjrt:pjrt_layout",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/distributed",
"@xla//xla/pjrt/distributed:client",
"@xla//xla/pjrt/distributed:key_value_store_interface",
"@xla//xla/pjrt/distributed:protocol_proto_cc",
"@xla//xla/pjrt/distributed:service",
"@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"@xla//xla/python:config",
"@xla//xla/python:custom_call_sharding",
"@xla//xla/python:dlpack",
"@xla//xla/python:guard_lib",
"@xla//xla/python:jax_jit",
"@xla//xla/python:logging",
"@xla//xla/python:mlir",
"@xla//xla/python:nb_absl_flat_hash_map",
"@xla//xla/python:nb_absl_span",
"@xla//xla/python:nb_class_ptr",
"@xla//xla/python:ops",
"@xla//xla/python:pjit",
"@xla//xla/python:pmap_lib",
"@xla//xla/python:pprof_profile_builder",
"@xla//xla/python:profiler",
"@xla//xla/python:py_client",
"@xla//xla/python:python_ref_manager",
"@xla//xla/python:pytree",
"@xla//xla/python:refine_polymorphic_shapes",
"@xla//xla/python:sdy",
"@xla//xla/python:traceback",
"@xla//xla/python:types",
"@xla//xla/python:util",
"@xla//xla/python:weakref_lru_cache",
"@xla//xla/python:xla_compiler",
"@xla//xla/python/ifrt",
"@xla//xla/python/ifrt:plugin_program",
"@xla//xla/python/ifrt:plugin_program_serdes",
"@xla//xla/python/ifrt_proxy/client:py_module",
"@xla//xla/python/pjrt_ifrt",
"@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"@xla//xla/python/pjrt_ifrt:xla_ifrt",
"@xla//xla/tsl/concurrency:ref_count",
"@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager",
"@xla//xla/tsl/platform:logging",
"@xla//xla/tsl/platform:status",
"@xla//xla/tsl/platform:statusor",
"@xla//xla/tsl/platform/cloud:gcs_file_system",
"@xla//xla/tsl/python/lib/core:numpy",
] + select({
# gloo tcp transport only builds on linux
"@xla//xla/tsl:macos": [
"@gloo//:transport_uv",
"@xla//xla/backends/cpu/collectives:gloo_collectives",
"@xla//xla/backends/cpu/collectives:gloo_kv_store",
],
"@xla//xla/tsl:windows": [],
"//conditions:default": [
"@gloo//:transport_tcp",
"@xla//xla/backends/cpu/collectives:gloo_collectives",
"@xla//xla/backends/cpu/collectives:gloo_kv_store",
"@xla//xla/python/transfer:py_socket_transfer",
],
}) + select({
# mpitrampoline does not build on windows
"@xla//xla/tsl:windows": [],
# we support MPI collectives only in OSS builds
"//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]),
}),
)
pytype_strict_library(
name = "xla_client",
srcs = ["xla_client.py"],
@ -43,7 +152,7 @@ pytype_strict_library(
deps = py_deps([
"numpy",
"ml_dtypes",
]) + ["@xla//xla/python:xla_extension"],
]) + [":xla_extension"],
)
py_strict_test(

jaxlib/xla/xla.cc (new file, 965 lines)

@ -0,0 +1,965 @@
/* Copyright 2019 The JAX Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <Python.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
#include "nanobind/nb_defs.h"
#include "nanobind/stl/function.h" // IWYU pragma: keep
#include "nanobind/stl/optional.h" // IWYU pragma: keep
#include "nanobind/stl/pair.h" // IWYU pragma: keep
#include "nanobind/stl/set.h" // IWYU pragma: keep
#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/variant.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/pjrt/distributed/service.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/topology.h"
#include "xla/python/ifrt_proxy/client/py_module.h"
#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h"
#include "xla/python/py_client.h"
#include "xla/python/py_program.h"
#include "xla/python/sdy.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/python/lib/core/numpy.h" // NOLINT
#if defined(__linux__)
#include "gloo/transport/tcp/attr.h"
#include "gloo/transport/tcp/device.h"
#include "xla/backends/cpu/collectives/gloo_collectives.h"
#include "xla/backends/cpu/collectives/gloo_kv_store.h"
#include "xla/python/transfer/py_socket_transfer.h"
#elif defined(__APPLE__)
#include "gloo/transport/uv/device.h"
#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT
#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT
#endif // defined(__linux__)
#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE)
#include "xla/backends/cpu/collectives/mpi_collectives.h"
#endif // !_WIN32 && !PLATFORM_GOOGLE
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/exceptions.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/config.h"
#include "xla/python/custom_call_sharding.h"
#include "xla/python/dlpack.h"
#include "xla/python/guard_lib.h"
#include "xla/python/jax_jit.h"
#include "xla/python/logging.h" // IWYU pragma: keep
#include "xla/python/mlir.h"
#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep
#include "xla/python/nb_absl_span.h" // IWYU pragma: keep
#include "xla/python/nb_class_ptr.h"
#include "xla/python/ops.h"
#include "xla/python/pjit.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "xla/python/pjrt_ifrt/pjrt_executable.h"
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
#include "xla/python/pmap_lib.h"
#include "xla/python/pprof_profile_builder.h"
#include "xla/python/profiler.h"
#include "xla/python/py_array.h"
#include "xla/python/py_compile_only_client.h"
#include "xla/python/py_device.h"
#include "xla/python/py_device_list.h"
#include "xla/python/py_executable.h"
#include "xla/python/py_memory_space.h"
#include "xla/python/python_ref_manager.h"
#include "xla/python/pytree.h"
#include "xla/python/sharding.h"
#include "xla/python/traceback.h"
#include "xla/python/weakref_lru_cache.h"
#include "xla/python/xla_compiler.h"
#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h"
#include "xla/tsl/platform/status.h"
#include "tsl/platform/platform.h"
// TODO(phawkins): remove host_id properties after JAX is updated to avoid them.
namespace xla {
namespace {
namespace nb = nanobind;
bool IsOptimizedBuild() {
#if NDEBUG
return true;
#else
return false;
#endif // NDEBUG
}
// Is*san reports whether the build is under that particular sanitizer.
bool IsAsan() {
#if defined(ADDRESS_SANITIZER)
return true;
#else // defined(ADDRESS_SANITIZER)
return false;
#endif
}
bool IsMsan() {
#if defined(MEMORY_SANITIZER)
return true;
#else // defined(MEMORY_SANITIZER)
return false;
#endif
}
bool IsTsan() {
#if defined(THREAD_SANITIZER)
return true;
#else // defined(THREAD_SANITIZER)
return false;
#endif
}
// IsSanitized reports whether the build is under any sanitizer.
bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); }
} // namespace
NB_MODULE(xla_extension, m) {
// Initialize ABSL logging because code within XLA uses it.
#ifndef PLATFORM_GOOGLE
InitializeAbslLogging();
#endif // PLATFORM_GOOGLE
// We seem to get a fair number of leak warnings from nanobind. It's unclear
// whether these are false positives or not.
nb::set_leak_warnings(false);
tsl::ImportNumpy();
// Exceptions
nb::exception<XlaRuntimeError> xla_runtime_error(m, "XlaRuntimeError",
PyExc_RuntimeError);
xla_runtime_error.attr("__doc__") = nb::str(
"Runtime errors thrown by the JAX runtime. While the JAX runtime may "
"raise other exceptions as well, most exceptions thrown by the runtime "
"are instances of this class.");
// Types
nb::enum_<PrimitiveType>(m, "PrimitiveType", nb::is_arithmetic())
.value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
.value("PRED", PRED)
.value("S4", S4)
.value("S8", S8)
.value("S16", S16)
.value("S32", S32)
.value("S64", S64)
.value("U4", U4)
.value("U8", U8)
.value("U16", U16)
.value("U32", U32)
.value("U64", U64)
.value("F16", F16)
.value("F4E2M1FN", F4E2M1FN)
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
// .value("F8E3M4", F8E3M4)
// .value("F8E4M3", F8E4M3)
.value("F8E8M0FNU", F8E8M0FNU)
.value("F8E4M3FN", F8E4M3FN)
.value("F8E4M3B11FNUZ", F8E4M3B11FNUZ)
.value("F8E4M3FNUZ", F8E4M3FNUZ)
.value("F8E5M2", F8E5M2)
.value("F8E5M2FNUZ", F8E5M2FNUZ)
.value("BF16", BF16)
.value("F32", F32)
.value("F64", F64)
.value("C64", C64)
.value("C128", C128)
.value("TUPLE", TUPLE)
.value("OPAQUE_TYPE", OPAQUE_TYPE)
.value("TOKEN", TOKEN);
// Must be before PyClient.compile.
BuildXlaCompilerSubmodule(m);
PyDevice::RegisterPythonType(m);
PyMemorySpace::RegisterPythonType(m);
PyClient::RegisterPythonTypes(m);
nb::enum_<ifrt::ArrayCopySemantics>(m, "ArrayCopySemantics",
nb::is_arithmetic())
.value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy)
.value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput)
.value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput);
nb::class_<PjRtLayout>(m, "PjRtLayout")
.def("__str__", &PjRtLayout::ToString)
.def("__eq__", [](const PjRtLayout& layout,
const PjRtLayout& other) { return layout == other; })
.def("__hash__",
[](const PjRtLayout& layout) { return absl::HashOf(layout); })
.def("_xla_layout", &PjRtLayout::xla_layout)
.def("__getstate__",
[](const PjRtLayout& layout) -> nb::tuple {
absl::StatusOr<std::string> serialized = layout.Serialize();
ThrowIfError(serialized.status());
return nb::make_tuple(
nb::bytes(serialized->data(), serialized->size()));
})
.def("__setstate__", [](PjRtLayout* self, nb::tuple t) {
nb::bytes serialized = nb::cast<nb::bytes>(t[0]);
absl::StatusOr<std::shared_ptr<const PjRtLayout>> layout =
PjRtLayout::Deserialize(
absl::string_view(serialized.c_str(), serialized.size()));
ThrowIfError(layout.status());
new (self) PjRtLayout((*layout)->xla_layout());
});
jax::BuildWeakrefLRUCacheAPI(m);
nb::class_<xla::cpu::CpuCollectives> cpu_collectives(m, "CpuCollectives");
m.def(
"make_gloo_tcp_collectives",
[](std::shared_ptr<DistributedRuntimeClient> distributed_client,
std::optional<std::string> hostname,
std::optional<std::string> interface)
-> std::shared_ptr<xla::cpu::CpuCollectives> {
#if defined(__linux__)
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
if (distributed_client != nullptr) {
kv_store = GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"cpu:");
}
auto gloo_kv_store = std::make_unique<cpu::GlooKeyValueStore>(kv_store);
auto tcp_attrs = gloo::transport::tcp::attr();
if (hostname) {
tcp_attrs.hostname = *hostname;
}
if (interface) {
tcp_attrs.iface = *interface;
}
auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs);
return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
std::move(tcp_device));
#elif defined(__APPLE__)
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
if (distributed_client != nullptr) {
kv_store = GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"cpu:");
}
auto gloo_kv_store = std::make_unique<cpu::GlooKeyValueStore>(kv_store);
auto uv_attrs = gloo::transport::uv::attr();
if (hostname) {
uv_attrs.hostname = *hostname;
}
if (interface) {
uv_attrs.iface = *interface;
}
auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs);
return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
std::move(uv_device));
#else // defined(__linux__)
throw xla::XlaRuntimeError(
"make_gloo_tcp_collectives only implemented for linux and macos");
#endif // defined(__linux__)
},
nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt,
nb::arg("interface").none() = std::nullopt);
#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE)
nb::class_<cpu::MpiCollectives> mpi_collectives(m, "MpiCollectives",
cpu_collectives);
mpi_collectives.def("Init", &cpu::MpiCollectives::Init);
mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize);
m.def("make_mpi_collectives", []() -> std::shared_ptr<cpu::MpiCollectives> {
return std::make_shared<cpu::MpiCollectives>();
});
#else // !_WIN32 && !PLATFORM_GOOGLE
m.def("make_mpi_collectives",
[]() -> std::shared_ptr<xla::cpu::CpuCollectives> {
throw xla::XlaRuntimeError(
"make_mpi_collectives is not implemented for Windows");
});
#endif // !_WIN32 && !PLATFORM_GOOGLE
m.def(
"get_tfrt_cpu_client",
[](bool asynchronous,
std::shared_ptr<DistributedRuntimeClient> distributed_client,
int node_id, int num_nodes,
std::shared_ptr<xla::cpu::CpuCollectives> collectives,
std::optional<int> num_devices) -> nb_class_ptr<PyClient> {
std::unique_ptr<ifrt::PjRtClient> ifrt_client;
{
nb::gil_scoped_release gil_release;
xla::CpuClientOptions options;
options.asynchronous = asynchronous;
options.collectives = std::move(collectives);
options.process_id = node_id;
options.cpu_device_count = num_devices;
std::unique_ptr<PjRtClient> client =
xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options)));
ifrt::PjRtClient::CreateOptions ifrt_options;
ifrt_options.pjrt_client =
std::shared_ptr<PjRtClient>(std::move(client));
if (distributed_client != nullptr) {
ifrt_options.kv_store =
GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"cpu:");
ifrt_options.process_id = node_id;
ifrt_options.num_processes = num_nodes;
}
ifrt_client =
ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options)));
}
return PyClient::Make(std::move(ifrt_client));
},
nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr,
nb::arg("node_id") = 0, nb::arg("num_nodes") = 1,
nb::arg("collectives").none() =
std::shared_ptr<xla::cpu::CpuCollectives>(),
nb::arg("num_devices").none() = std::nullopt);
m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool {
absl::StatusOr<const PJRT_Api*> pjrt_api = pjrt::PjrtApi(platform_name);
return pjrt_api.ok();
});
m.def(
"load_pjrt_plugin",
[](std::string platform_name, std::optional<std::string> library_path,
std::optional<nb::capsule> c_api) -> nb::capsule {
if (library_path.has_value()) {
const PJRT_Api* api = xla::ValueOrThrow(
pjrt::LoadPjrtPlugin(platform_name, *library_path));
return nb::capsule(absl::bit_cast<void*>(api), "pjrt_c_api");
}
if (absl::string_view(c_api->name()) != "pjrt_c_api") {
throw nb::value_error(
"c_api argument to load_pjrt_plugin is not a pjrt_c_api "
"capsule.");
}
xla::ThrowIfError(pjrt::SetPjrtApi(
platform_name, static_cast<const PJRT_Api*>(c_api->data())));
return *c_api;
},
nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt,
nb::arg("c_api").none() = std::nullopt);
m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool {
return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name));
});
m.def("initialize_pjrt_plugin", [](std::string platform_name) {
return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name));
});
m.def(
"get_c_api_client",
[](std::string platform_name,
const absl::flat_hash_map<std::string, PjRtValueType>& options,
std::shared_ptr<DistributedRuntimeClient> distributed_client)
-> nb_class_ptr<PyClient> {
std::unique_ptr<ifrt::PjRtClient> ifrt_client;
{
nb::gil_scoped_release gil_release;
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
if (distributed_client != nullptr) {
kv_store = GetDistributedKeyValueStore(
distributed_client,
/*key_prefix=*/absl::StrCat(platform_name, ":"));
}
std::unique_ptr<PjRtClient> c_api_client = xla::ValueOrThrow(
GetCApiClient(platform_name, options, kv_store));
ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client));
}
return PyClient::Make(std::move(ifrt_client));
},
nb::arg("platform_name"),
nb::arg("options") = absl::flat_hash_map<std::string, PjRtValueType>(),
nb::arg("distributed_client").none() = nullptr);
// TODO(b/322357665): Delete this method after TPU plugin changes to use the
// standard registration.
m.def("get_default_c_api_topology",
[](std::string platform_name, std::string topology_name,
const absl::flat_hash_map<std::string, PjRtValueType>& options)
-> std::shared_ptr<ifrt::Topology> {
return std::make_shared<ifrt::PjRtTopology>(xla::ValueOrThrow(
GetCApiTopology(platform_name, topology_name, options)));
});
m.def("get_c_api_topology",
[](nb::capsule c_api, std::string topology_name,
const absl::flat_hash_map<std::string, PjRtValueType>& options)
-> std::shared_ptr<ifrt::Topology> {
if (absl::string_view(c_api.name()) != "pjrt_c_api") {
throw nb::value_error(
"Argument to get_c_api_topology was not a pjrt_c_api capsule.");
}
return std::make_shared<ifrt::PjRtTopology>(xla::ValueOrThrow(
GetCApiTopology(static_cast<const PJRT_Api*>(c_api.data()),
topology_name, options)));
});
m.def("get_topology_for_devices",
[](const std::vector<nb_class_ptr<PyDevice>>& py_devices) {
if (py_devices.empty()) {
throw nb::value_error(
"get_topology_for_devices requires >= 1 devices.");
}
auto client = py_devices[0]->client();
absl::InlinedVector<ifrt::Device*, 1> ifrt_devices;
ifrt_devices.reserve(py_devices.size());
for (const auto& py_device : py_devices) {
if (py_device->client().get() != client.get()) {
throw nb::value_error(
"devices passed to get_topology_for_devices come from "
"different clients.");
}
ifrt_devices.push_back(py_device->device());
}
ifrt::DeviceListRef device_list =
client->ifrt_client()->MakeDeviceList(ifrt_devices);
return xla::ValueOrThrow(
client->ifrt_client()->GetTopologyForDevices(device_list));
});
TF_CHECK_OK(PyArray::RegisterTypes(m));
jax::PyDeviceList::Register(m);
jax::RegisterSharding(m);
nb::class_<CompiledMemoryStats>(m, "CompiledMemoryStats")
.def_rw("generated_code_size_in_bytes",
&CompiledMemoryStats::generated_code_size_in_bytes)
.def_rw("argument_size_in_bytes",
&CompiledMemoryStats::argument_size_in_bytes)
.def_rw("output_size_in_bytes",
&CompiledMemoryStats::output_size_in_bytes)
.def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes)
.def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes)
.def_rw("host_generated_code_size_in_bytes",
&CompiledMemoryStats::host_generated_code_size_in_bytes)
.def_rw("host_argument_size_in_bytes",
&CompiledMemoryStats::host_argument_size_in_bytes)
.def_rw("host_output_size_in_bytes",
&CompiledMemoryStats::host_output_size_in_bytes)
.def_rw("host_alias_size_in_bytes",
&CompiledMemoryStats::host_alias_size_in_bytes)
.def_rw("host_temp_size_in_bytes",
&CompiledMemoryStats::host_temp_size_in_bytes)
.def_prop_ro("serialized_hlo_proto",
[](const CompiledMemoryStats& cms) -> nb::bytes {
return nb::bytes(cms.serialized_hlo_proto.data(),
cms.serialized_hlo_proto.size());
})
.def("__str__", &CompiledMemoryStats::DebugString);
nb::class_<PyExecuteResults>(m, "ExecuteResults")
.def("__len__", [](PyExecuteResults& results) { return results.Size(); })
.def("disassemble_into_single_device_arrays",
&PyExecuteResults::DisassembleIntoSingleDeviceArrays)
.def("disassemble_prefix_into_single_device_arrays",
&PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays)
.def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers)
.def("consume_token", &PyExecuteResults::ConsumeToken);
nb::class_<PyLoadedExecutable>(m, "LoadedExecutable")
.def_prop_ro("client", &PyLoadedExecutable::client)
.def("local_devices", &PyLoadedExecutable::AddressableDevices)
.def("size_of_generated_code_in_bytes",
&PyLoadedExecutable::SizeOfGeneratedCodeInBytes)
.def(
"get_compiled_memory_stats",
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats))
.def("delete", &PyLoadedExecutable::Delete)
.def("execute_sharded_on_local_devices",
xla::ValueOrThrowWrapper(
&PyLoadedExecutable::ExecuteShardedOnLocalDevices),
nb::arg("arguments"))
.def("execute_sharded_on_local_devices_with_tokens",
xla::ValueOrThrowWrapper(
&PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens),
nb::arg("arguments"))
// TODO(parkers): Switch execute_sharded_on_local_devices* to this.
.def("execute_sharded",
xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded),
nb::arg("arguments"), nb::arg("with_tokens") = false)
.def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules))
.def("get_output_memory_kinds",
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds))
.def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings)
.def("get_parameter_layouts",
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts))
.def("get_output_layouts",
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts))
.def("get_parameter_shardings",
&PyLoadedExecutable::GetParameterShardings)
.def("keep_alive", &PyLoadedExecutable::KeepAlive)
.def("cost_analysis",
[](const PyLoadedExecutable& self) {
auto map = ValueOrThrow(self.GetCostAnalysis());
return ifrt::ToPjRtAttributeMap(std::move(map));
})
.def_prop_ro("traceback", &PyLoadedExecutable::traceback)
.def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object {
if (exec->fingerprint().has_value()) {
return nb::bytes(exec->fingerprint()->data(),
exec->fingerprint()->size());
} else {
return nb::none();
}
});
nb::class_<PyToken> token(m, "Token");
token.def("block_until_ready",
[](PyToken& self) { xla::ThrowIfError(self.Await()); });
nb::class_<PyShardedToken> sharded_token(m, "ShardedToken");
sharded_token.def("block_until_ready", [](PyShardedToken& self) {
xla::ThrowIfError(self.Await());
});
sharded_token.def("get_token", &PyShardedToken::GetPyToken);
m.def("buffer_to_dlpack_managed_tensor",
xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor),
nb::arg("buffer"), nb::arg("stream").none() = nb::none());
m.def(
"dlpack_managed_tensor_to_buffer",
[](const nb::capsule& tensor, nb_class_ptr<PyDevice> device,
std::optional<std::intptr_t> stream) {
return xla::ValueOrThrow(DLPackManagedTensorToBuffer(
tensor, device->device(), device->client(), stream));
},
nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none());
// Legacy overload
m.def(
"dlpack_managed_tensor_to_buffer",
[](const nb::capsule& tensor,
std::optional<nb_class_ptr<PyClient>> cpu_client,
std::optional<nb_class_ptr<PyClient>> gpu_client) {
return xla::ValueOrThrow(DLPackManagedTensorToBuffer(
tensor, std::move(cpu_client), std::move(gpu_client)));
},
nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(),
nb::arg("gpu_backend").none() = nb::none());
m.def("cuda_array_interface_to_buffer",
xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"),
nb::arg("gpu_backend").none() = nb::none(),
nb::arg("device_id").none() = nb::none());
jax::BuildConfigSubmodule(m);
BuildIfrtProgramsSubmodule(m);
BuildProfilerSubmodule(m);
BuildOpsSubmodule(m);
BuildPytreeSubmodule(m);
jax::BuildGuardSubmodule(m);
jax::BuildJaxjitSubmodule(m);
jax::BuildPmapSubmodule(m);
jax::BuildPjitSubmodule(m);
BuildTracebackSubmodule(m);
BuildMlirSubmodule(m);
BuildSdySubmodule(m);
BuildCustomCallShardingPybindAPI(m);
#if defined(__linux__)
aux::RegisterTransferServerTypes(m);
#endif // defined(__linux__)
// The following uses python bindings for PyClient defined above using
// pybind11, and hence needs pybind11::module_ (not just nanobind::module_).
xla::ifrt::proxy::BuildIfrtProxySubmodule(m);
nb::class_<tsl::PreemptionSyncManager> preemption_sync_manager(
m, "PreemptionSyncManager");
preemption_sync_manager
.def(
"initialize",
[](tsl::PreemptionSyncManager& manager,
DistributedRuntimeClient* client) {
tsl::CoordinationServiceAgent* agent =
xla::ValueOrThrow(client->GetCoordinationServiceAgent());
xla::ThrowIfError(manager.Initialize(agent));
},
nb::arg("distributed_client"))
.def("reached_sync_point",
[](tsl::PreemptionSyncManager& manager, int step_counter) {
return manager.ReachedSyncPoint(step_counter);
});
m.def("create_preemption_sync_manager",
[]() { return tsl::CreatePreemptionSyncManager(); });
nb::class_<DistributedRuntimeService> distributed_runtime_service(
m, "DistributedRuntimeService");
distributed_runtime_service.def("shutdown",
&DistributedRuntimeService::Shutdown,
nb::call_guard<nb::gil_scoped_release>());
nb::class_<DistributedRuntimeClient> distributed_runtime_client(
m, "DistributedRuntimeClient");
distributed_runtime_client
.def("connect",
[](DistributedRuntimeClient& self) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(self.Connect());
})
.def("shutdown",
[](DistributedRuntimeClient& self) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(self.Shutdown());
})
// This method assumes that the value is a Python string. Use
// `blocking_key_value_get_bytes()` if key_value_set() was called with a
// Python bytes object as its value.
.def(
"blocking_key_value_get",
[](DistributedRuntimeClient& client, std::string key,
int64_t timeout_in_ms) {
nb::gil_scoped_release gil_release;
return xla::ValueOrThrow(client.BlockingKeyValueGet(
key, absl::Milliseconds(timeout_in_ms)));
},
nb::arg("key"), nb::arg("timeout_in_ms"))
// Same as `blocking_key_value_get()`, but retrieves the raw Python byte
// values explicitly.
.def(
"blocking_key_value_get_bytes",
[](DistributedRuntimeClient& client, std::string key,
int64_t timeout_in_ms) -> nb::bytes {
std::string result;
{
nb::gil_scoped_release gil_release;
result = xla::ValueOrThrow(client.BlockingKeyValueGet(
key, absl::Milliseconds(timeout_in_ms)));
}
return nb::bytes(result.data(), result.size());
},
nb::arg("key"), nb::arg("timeout_in_ms"))
.def(
"key_value_try_get",
[](DistributedRuntimeClient& client, std::string key) {
nb::gil_scoped_release gil_release;
return xla::ValueOrThrow(client.KeyValueTryGet(key));
},
nb::arg("key"))
.def(
"key_value_try_get_bytes",
[](DistributedRuntimeClient& client, std::string key) -> nb::bytes {
std::string result;
{
nb::gil_scoped_release gil_release;
result = xla::ValueOrThrow(client.KeyValueTryGet(key));
}
return nb::bytes(result.data(), result.size());
},
nb::arg("key"))
.def(
"wait_at_barrier",
[](DistributedRuntimeClient& client, std::string barrier_id,
int64_t timeout_in_ms,
std::optional<std::vector<int32_t>> process_ids) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(client.WaitAtBarrier(
barrier_id, absl::Milliseconds(timeout_in_ms), process_ids));
},
nb::arg("barrier_id"), nb::arg("timeout_in_ms"),
nb::arg("process_ids") = std::nullopt)
.def(
"get_live_nodes",
[](DistributedRuntimeClient& client,
std::vector<int32_t> process_ids) {
nb::gil_scoped_release gil_release;
return xla::ValueOrThrow(client.GetLiveNodes(process_ids));
},
nb::arg("process_ids"))
// The key must be a string, but the value can either be a Python string
// or bytes object.
// With Python string values, use `key_value_set()` and
// `blocking_key_value_get()`.
// With Python byte object values, use `key_value_set()` and
// `blocking_key_value_get_bytes()`.
.def(
"key_value_set",
[](DistributedRuntimeClient& client, absl::string_view key,
absl::string_view value, bool allow_overwrite) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite));
},
nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false)
// The key must be a string, but the value must be a
// Python bytes object.
// Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`.
.def(
"key_value_set_bytes",
[](DistributedRuntimeClient& client, absl::string_view key,
nb::bytes value, bool allow_overwrite) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(client.KeyValueSet(
key, absl::string_view(value.c_str(), value.size()),
allow_overwrite));
},
nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false)
// Assumes that all values in the directory are Python strings.
.def(
"key_value_dir_get",
[](DistributedRuntimeClient& client, absl::string_view key) {
nb::gil_scoped_release gil_release;
return xla::ValueOrThrow(client.KeyValueDirGet(key));
},
nb::arg("key"))
// Assumes that all values in the directory are Python byte objects.
// Same as `key_value_dir_get()`, but retrieves Python byte values
// explicitly.
.def(
"key_value_dir_get_bytes",
[](DistributedRuntimeClient& client, absl::string_view key)
-> std::vector<std::pair<std::string, nb::bytes>> {
std::vector<std::pair<std::string, std::string>> result;
{
nb::gil_scoped_release gil_release;
result = xla::ValueOrThrow(client.KeyValueDirGet(key));
}
// Convert std::string values to nb::bytes.
std::vector<std::pair<std::string, nb::bytes>> kvs;
kvs.reserve(result.size());
for (auto& kv : result) {
kvs.push_back(
std::pair(std::move(kv.first),
nb::bytes(kv.second.data(), kv.second.size())));
}
return kvs;
},
nb::arg("key"))
.def(
"key_value_delete",
[](DistributedRuntimeClient& client, absl::string_view key) {
nb::gil_scoped_release gil_release;
return xla::ThrowIfError(client.KeyValueDelete(key));
},
nb::arg("key"));
m.def(
"get_distributed_runtime_service",
[](std::string address, int num_nodes,
std::optional<int> heartbeat_interval,
std::optional<int> max_missing_heartbeats,
std::optional<int> cluster_register_timeout,
std::optional<int> shutdown_timeout)
-> std::unique_ptr<DistributedRuntimeService> {
CoordinationServiceImpl::Options options;
options.num_nodes = num_nodes;
if (heartbeat_interval.has_value()) {
options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
}
if (max_missing_heartbeats.has_value()) {
options.max_missing_heartbeats = *max_missing_heartbeats;
}
if (cluster_register_timeout.has_value()) {
options.cluster_register_timeout =
absl::Seconds(*cluster_register_timeout);
}
if (shutdown_timeout.has_value()) {
options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
}
std::unique_ptr<DistributedRuntimeService> service =
xla::ValueOrThrow(GetDistributedRuntimeService(address, options));
return service;
},
nb::arg("address"), nb::arg("num_nodes"),
nb::arg("heartbeat_interval").none() = std::nullopt,
nb::arg("max_missing_heartbeats").none() = std::nullopt,
nb::arg("cluster_register_timeout").none() = std::nullopt,
nb::arg("shutdown_timeout").none() = std::nullopt);
m.def(
"get_distributed_runtime_client",
[](std::string address, int node_id, std::optional<int> rpc_timeout,
std::optional<int> init_timeout, std::optional<int> shutdown_timeout,
std::optional<int> heartbeat_interval,
std::optional<int> max_missing_heartbeats,
std::optional<std::function<void(absl::Status)>>
missed_heartbeat_callback,
std::optional<bool> shutdown_on_destruction,
std::optional<bool> use_compression)
-> std::shared_ptr<DistributedRuntimeClient> {
bool compression = use_compression.value_or(false);
DistributedRuntimeClient::Options options;
options.node_id = node_id;
if (rpc_timeout.has_value()) {
options.rpc_timeout = absl::Seconds(*rpc_timeout);
}
if (init_timeout.has_value()) {
options.init_timeout = absl::Seconds(*init_timeout);
}
if (shutdown_timeout.has_value()) {
options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
}
if (heartbeat_interval.has_value()) {
options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
}
if (max_missing_heartbeats.has_value()) {
options.max_missing_heartbeats = *max_missing_heartbeats;
}
if (missed_heartbeat_callback.has_value()) {
options.missed_heartbeat_callback =
std::move(*missed_heartbeat_callback);
}
if (shutdown_on_destruction.has_value()) {
options.shutdown_on_destruction = *shutdown_on_destruction;
}
return GetDistributedRuntimeClient(address, options, compression);
},
nb::arg("address"), nb::arg("node_id"),
nb::arg("rpc_timeout").none() = std::nullopt,
nb::arg("init_timeout").none() = std::nullopt,
nb::arg("shutdown_timeout").none() = std::nullopt,
nb::arg("heartbeat_interval").none() = std::nullopt,
nb::arg("max_missing_heartbeats").none() = std::nullopt,
nb::arg("missed_heartbeat_callback").none() = std::nullopt,
nb::arg("shutdown_on_destruction").none() = std::nullopt,
nb::arg("use_compression").none() = std::nullopt);
m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });
m.def("is_optimized_build", &IsOptimizedBuild);
m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile),
"Encodes the JSON representation of a pprof Profile into its binary "
"protocol buffer encoding.");
m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson),
"Decodes an uncompressed pprof Profile protocol buffer into a JSON "
"representation");
RegisterCompileOnlyClient(m);
nb::class_<ifrt::Topology>(m, "DeviceTopology")
.def("_make_compile_only_devices",
[](std::shared_ptr<ifrt::Topology> topology) {
if (!llvm::isa<ifrt::PjRtTopology>(*topology)) {
throw xla::XlaRuntimeError("Only PjRtTopologies are supported.");
}
return MakeCompileOnlyClient(
std::dynamic_pointer_cast<ifrt::PjRtTopology>(topology))
->Devices();
})
.def_prop_ro(
"platform",
[](ifrt::Topology& topology) { return topology.platform_name(); })
.def_prop_ro(
"platform_version",
[](ifrt::Topology& topology) { return topology.platform_version(); })
.def("serialize",
[](ifrt::Topology& topology) -> nb::bytes {
std::string serialized = ValueOrThrow(topology.Serialize());
return nb::bytes(serialized.data(), serialized.size());
})
.def("__getattr__",
[](ifrt::Topology& topology, absl::string_view name) -> nb::object {
const auto& attrs = topology.Attributes().map();
auto it = attrs.find(name);
if (it != attrs.end()) {
return std::visit([](auto&& v) { return nb::cast(v.value); },
it->second);
}
throw nb::attribute_error(
absl::StrCat("Unknown attribute ", name).c_str());
});
nb::class_<ifrt::Executable>(m, "Executable")
.def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules))
.def("get_output_memory_kinds",
xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds))
.def("get_output_shardings", &ifrt::Executable::GetOutputShardings)
.def("get_parameter_layouts",
ValueOrThrowWrapper(&ifrt::Executable::GetParameterLayouts))
.def("get_output_layouts",
xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputLayouts))
.def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings)
.def("get_compiled_memory_stats",
xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats))
.def("serialize",
[](const ifrt::Executable& exec) -> nb::bytes {
std::string serialized = ValueOrThrow(exec.Serialize());
return nb::bytes(serialized.data(), serialized.size());
})
.def("cost_analysis", [](const ifrt::Executable& exec) {
auto attrs = ValueOrThrow(exec.GetCostAnalysis());
return ifrt::ToPjRtAttributeMap(std::move(attrs));
});
m.def("is_asan", IsAsan);
m.def("is_msan", IsMsan);
m.def("is_tsan", IsTsan);
m.def("is_sanitized", IsSanitized);
m.def(
"batched_device_put",
[](nb::object aval, nb::object sharding, std::vector<nb::object> xs,
std::vector<const PyDevice*> dst_devices, bool committed,
bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object {
return ValueOrThrow(PyArray::BatchedDevicePut(
aval, sharding, std::move(xs), std::move(dst_devices), committed,
force_copy, host_buffer_semantics, jax::GetEnableX64()));
},
nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"),
nb::arg("committed") = true, nb::arg("force_copy") = false,
nb::arg("host_buffer_semantics") =
PjRtClient::HostBufferSemantics::kImmutableZeroCopy);
m.def(
"reorder_shards",
[](PyArray x, nb::object dst_sharding,
ifrt::ArrayCopySemantics array_copy_semantics) {
return ValueOrThrow(PyArray::ReorderShards(
std::move(x), std::move(dst_sharding), array_copy_semantics));
},
nb::arg("x"), nb::arg("dst_sharding"), nb::arg("array_copy_semantics"));
m.def("batched_block_until_ready", [](std::vector<nb::object> xs) {
ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs)));
});
m.def("check_and_canonicalize_memory_kind",
&jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(),
nb::arg("device_list"));
} // NOLINT(readability/fn_size)
} // namespace xla
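
A hedged Python sketch of exercising a few of the bindings defined in the module above (the import path is an assumption, and a plain CPU environment is assumed):

# Hypothetical import path for the extension built by this change.
from jaxlib.xla import xla_extension as xe

print(xe.is_optimized_build(), xe.is_sanitized())

# Keyword names follow the nb::arg() spellings in get_tfrt_cpu_client above;
# this builds a single-process, two-device CPU client.
client = xe.get_tfrt_cpu_client(asynchronous=True, node_id=0, num_nodes=1,
                                num_devices=2)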


@ -19,7 +19,7 @@ from __future__ import annotations
import atexit
from collections.abc import Mapping, Sequence
import contextlib
import enum # pylint: disable=g-bad-import-order
import enum
import gzip
import inspect
import logging

File diff suppressed because it is too large.


@ -0,0 +1,32 @@
# Copyright 2024 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, Generic, TypeVar
unset: object
_T = TypeVar('_T')
class Config(Generic[_T]):
def __init__(self, value: _T, include_in_jit_key: bool = False): ...
@property
def value(self) -> _T: ...
def get_local(self) -> Any: ...
def get_global(self) -> _T: ...
def set_local(self, value: Any) -> None: ...
def swap_local(self, value: Any) -> Any: ...
def set_global(self, value: _T) -> None: ...
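
As a usage illustration for the stub above, a hedged sketch (the submodule name config and its import path are assumptions; the local/global semantics are inferred from the method names):

from jaxlib.xla.xla_extension import config  # assumed module path

flag = config.Config(False, include_in_jit_key=True)
flag.set_global(True)
print(flag.get_global())   # True
flag.set_local(False)      # thread-local override
print(flag.get_local())    # False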


@ -0,0 +1,46 @@
# Copyright 2024 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, List, Optional
class TransferGuardLevel:
ALLOW: Any
LOG: Any
DISALLOW: Any
LOG_EXPLICIT: Any
DISALLOW_EXPLICIT: Any
class GarbageCollectionGuardLevel:
ALLOW: Any
LOG: Any
FATAL: Any
class GuardState:
host_to_device: Optional[TransferGuardLevel]
device_to_device: Optional[TransferGuardLevel]
device_to_host: Optional[TransferGuardLevel]
explicit_device_put: bool
explicit_device_get: bool
garbage_collect_array: Optional[GarbageCollectionGuardLevel]
def global_state() -> GuardState: ...
def thread_local_state() -> GuardState: ...
class _TestingScopedLogSink:
def __enter__(self) -> _TestingScopedLogSink: ...
def __exit__(self, *args, **kwargs) -> None: ...
def logs(self) -> List[str]: ...
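
A hedged sketch of using the interface typed above to disallow implicit host-to-device transfers (the submodule name guard_lib is an assumption; field semantics are inferred from their names):

from jaxlib.xla.xla_extension import guard_lib  # assumed module path

state = guard_lib.global_state()
state.host_to_device = guard_lib.TransferGuardLevel.DISALLOW
state.explicit_device_put = True  # keep explicit transfers allowed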


@ -0,0 +1,43 @@
# Copyright 2024 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, Sequence, Union
from jax.jaxlib.xla import xla_extension
class Program: ...
class CompileOptions: ...
def make_hlo_program(mlir_module: Union[str, bytes]) -> Program: ...
def make_colocated_python_program(
name: str,
pickled_function: bytes,
devices: Sequence[xla_extension.Device] | xla_extension.DeviceList,
input_avals: Sequence[Any],
output_avals: Sequence[Any],
) -> Program: ...
def make_plugin_program(data: Union[str, bytes]) -> Program: ...
def make_colocated_python_compile_options() -> CompileOptions: ...
def make_xla_compile_options(
compile_options: xla_extension.CompileOptions,
host_callbacks: Sequence[Any]
) -> CompileOptions: ...
def make_plugin_compile_options() -> CompileOptions: ...
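
A hedged sketch of the helpers typed above (the submodule name ifrt_programs is an assumption; the plugin payload is an opaque placeholder):

from jaxlib.xla.xla_extension import ifrt_programs  # assumed module path

program = ifrt_programs.make_plugin_program(b"<opaque plugin payload>")
options = ifrt_programs.make_plugin_compile_options()
# program and options would then be handed to an IFRT client's compile entry point.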


@ -0,0 +1,33 @@
# Copyright 2024 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, Optional, Callable
from jax.jaxlib.xla import xla_extension
_Status = Any
Client = xla_extension.Client
class ClientConnectionOptions:
on_disconnect: Optional[Callable[[_Status], None]] = None
on_connection_update: Optional[Callable[[str], None]] = None
connection_timeout_in_seconds: Optional[int] = None
def get_client(
proxy_server_address: str,
options: ClientConnectionOptions
) -> Client: ...
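
A hedged sketch of connecting through the proxy entry point typed above (the submodule name ifrt_proxy and the address format are assumptions; a running IFRT proxy server is assumed):

from jaxlib.xla.xla_extension import ifrt_proxy  # assumed module path

opts = ifrt_proxy.ClientConnectionOptions()
opts.connection_timeout_in_seconds = 60
opts.on_disconnect = lambda status: print("proxy disconnected:", status)
client = ifrt_proxy.get_client("grpc://localhost:8476", opts)  # placeholder address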


@ -0,0 +1,76 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, Callable, Optional, Sequence, Tuple
import numpy as np
from jax.jaxlib.xla import xla_extension
from . import pytree
Client = xla_extension.Client
Device = xla_extension.Device
class JitState:
disable_jit: Optional[bool]
enable_x64: Optional[bool]
default_device: Optional[Any]
extra_jit_context: Optional[Any]
post_hook: Optional[Callable[..., Any]]
def global_state() -> JitState: ...
def thread_local_state() -> JitState: ...
def get_enable_x64() -> bool: ...
def set_thread_local_state_initialization_callback(
function: Callable[[], None]): ...
def swap_thread_local_state_disable_jit(
value: Optional[bool]) -> Optional[bool]: ...
class ArgSignature:
dtype: np.dtype
shape: Tuple[int, ...]
weak_type: bool
def _ArgSignatureOfValue(
__arg: Any,
__jax_enable_x64: bool) -> ArgSignature: ...
def _is_float0(__arg: Any) -> bool: ...
class ArgumentSignature:
static_args: Sequence[Any]
static_arg_names: Sequence[str]
dynamic_arg_names: Sequence[str]
dynamic_arg_treedefs: Sequence[pytree.PyTreeDef]
def __eq__(self, value, /): ...
def __ne__(self, value, /): ...
def __hash__(self, /): ...
def __str__(self): ...
def __repr__(self): ...
def parse_arguments(
positional_args: Sequence[Any],
keyword_args: Sequence[Any],
kwnames: Tuple[str, ...],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
pytree_registry: pytree.PyTreeRegistry,
) -> tuple[ArgumentSignature, Sequence[Any]]: ...
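
A hedged sketch of the state accessors typed above (the submodule name jax_jit is an assumption; field semantics are inferred from their names):

from jaxlib.xla.xla_extension import jax_jit  # assumed module path

tls = jax_jit.thread_local_state()
tls.disable_jit = True             # thread-local override
print(jax_jit.get_enable_x64())    # reflects the current x64 setting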


@ -0,0 +1,34 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union
from . import XlaComputation
def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ...
def mlir_module_to_xla_computation(
mlir_module: Union[bytes, str],
use_tuple_args: bool = ...,
return_tuple: bool = ...,
) -> XlaComputation: ...
def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> bytes: ...
def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> bytes: ...
def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ...
def deserialize_portable_artifact(mlir_module: bytes) -> str: ...
def refine_polymorphic_shapes(
mlir_module: Union[bytes, str],
enable_shape_assertions: bool = ...,
validate_static_shapes: bool = ...,
enable_shardy: bool = ...,
) -> bytes: ...


@ -0,0 +1,465 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import enum
from typing import Any, Optional, Sequence, overload
from jax.jaxlib.xla import xla_extension
FftType = xla_extension.FftType
XlaBuilder = xla_extension.XlaBuilder
XlaComputation = xla_extension.XlaComputation
XlaOp = xla_extension.XlaOp
PrecisionConfig_Precision = xla_extension.PrecisionConfig_Precision
PrimitiveType = xla_extension.PrimitiveType
Shape = xla_extension.Shape
ShapeIndex = xla_extension.ShapeIndex
ResultAccuracy = xla_extension.ResultAccuracy
_ChannelHandle = Any
_ConvDimensionNumbers = Any
_DotDimensionNumbers = Any
_Layout = Any
_LiteralSlice = Any
_GatherDimensionNumbers = Any
_PaddingConfig = Any
_ReplicaGroup = Any
_ScatterDimensionNumbers = Any
class TriangularSolveOptions_Transpose(enum.IntEnum):
TRANSPOSE_INVALID: int
NO_TRANSPOSE: int
TRANSPOSE: int
ADJOINT: int
class RandomAlgorithm(enum.IntEnum):
RNG_DEFAULT: int
RNG_THREE_FRY: int
RNG_PHILOX: int
class CustomCallSchedule(enum.IntEnum):
SCHEDULE_NONE: int
SCHEDULE_LATEST: int
SCHEDULE_EARLIEST: int
# TODO(b/189822916): Remove this enum when all clients are migrated to the
# status-returning API.
class CustomCallApiVersion(enum.IntEnum):
API_VERSION_ORIGINAL: int
API_VERSION_STATUS_RETURNING: int
API_VERSION_STATUS_RETURNING_UNIFIED: int
API_VERSION_TYPED_FFI: int
def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ...
def AllGather(
operand: XlaOp,
all_gather_dimension: int,
shard_count: int,
replica_groups: Sequence[_ReplicaGroup] = ...,
channel_id: Optional[_ChannelHandle] = ...,
shape_with_layout: Optional[_Layout] = ...,
use_global_device_ids: Optional[bool] = ...) -> XlaOp: ...
def AllReduce(
operand: XlaOp,
computation: XlaComputation,
replica_groups: Sequence[_ReplicaGroup] = ...,
channel_id: Optional[_ChannelHandle] = ...,
shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ...
def ApproxTopK(
builder: XlaBuilder,
operands: Sequence[XlaOp],
init_values: Sequence[XlaOp],
top_k: int,
reduction_dim: int,
comparator: XlaComputation,
recall_target: Optional[float],
aggregate_to_topk: Optional[bool],
reduction_input_size_override: Optional[int]) -> XlaOp: ...
def ApproxTopKFallback(
builder: XlaBuilder,
operands: Sequence[XlaOp],
init_values: Sequence[XlaOp],
top_k: int,
reduction_dim: int,
comparator: XlaComputation,
recall_target: Optional[float],
aggregate_to_topk: Optional[bool],
reduction_input_size_override: Optional[int]) -> XlaOp: ...
def ApproxTopKReductionOutputSize(
input_size: int,
rank: int,
top_k: int,
recall_target: float,
aggregate_to_topk: Optional[bool] = ...,
input_size_override: Optional[int] = ...) -> tuple[int, int]: ...
def ReduceScatter(
operand: XlaOp,
computation: XlaComputation,
scatter_dimension: int,
shard_count: int,
replica_groups: Sequence[_ReplicaGroup] = ...,
channel_id: Optional[_ChannelHandle] = ...,
layout: Optional[_Layout] = ...,
use_global_device_ids: Optional[bool] = ...) -> XlaOp: ...
def AllToAll(
operand: XlaOp,
split_dimension: int,
concat_dimension: int,
split_count: int,
replica_groups: Sequence[_ReplicaGroup] = ...,
layout: Optional[_Layout] = ...,
channel_id: Optional[_ChannelHandle] = ...) -> XlaOp: ...
def BitcastConvertType(operand: XlaOp,
new_element_type: PrimitiveType) -> XlaOp: ...
def Broadcast(operand: XlaOp, sizes: Sequence[int]) -> XlaOp: ...
def BroadcastInDim(operand: XlaOp,
shape: Sequence[int],
broadcast_dimensions: Sequence[int]) -> XlaOp: ...
def Call(builder: XlaBuilder,
computation: XlaComputation,
operands: Sequence[XlaOp]) -> XlaOp: ...
def Cholesky(a: XlaOp, lower: bool = ...) -> XlaOp: ...
def Clamp(min: XlaOp, operand: XlaOp, max: XlaOp) -> XlaOp: ...
def Collapse(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ...
def CollectivePermute(
operand: XlaOp,
source_target_pairs: Sequence[tuple[int, int]],
channel_id: Optional[_ChannelHandle] = ...,
inplace: bool = ...) -> XlaOp: ...
def ConcatInDim(builder: XlaBuilder,
operands: Sequence[XlaOp],
dimension: int) -> XlaOp: ...
@overload
def Conditional(branch_index: XlaOp,
branch_computations: Sequence[XlaComputation],
branch_operands: Sequence[XlaOp]) -> XlaOp: ...
@overload
def Conditional(
predicate: XlaOp,
true_operand: XlaOp,
true_computation: XlaComputation,
false_operand: XlaOp,
false_computation: XlaComputation) -> XlaOp: ...
def Constant(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ...
def ConstantLiteral(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ...
def ConvGeneralDilated(
lhs: XlaOp,
rhs: XlaOp,
window_strides: Sequence[int],
padding: Sequence[tuple[int, int]],
lhs_dilation: Sequence[int],
rhs_dilation: Sequence[int],
dimension_numbers: _ConvDimensionNumbers,
feature_group_count: int = ...,
batch_group_count: int = ...,
precision_config: Optional[PrecisionConfig_Precision] = ...,
preferred_element_type: Optional[PrimitiveType] = ...,
window_reversal: Optional[Sequence[bool]] = ...) -> XlaOp: ...
def ConvertElementType(
operand: XlaOp,
new_element_type: PrimitiveType) -> XlaOp: ...
def CreateToken(builder: XlaBuilder) -> XlaOp: ...
def CrossReplicaSum(
operand: XlaOp,
replica_groups: Sequence[_ReplicaGroup] = ...) -> XlaOp: ...
def CustomCall(
builder: XlaBuilder,
call_target_name: bytes,
operands: Sequence[XlaOp],
shape: Shape,
opaque: bytes = ...,
has_side_effect: bool = ...,
schedule: CustomCallSchedule = ...,
api_version: CustomCallApiVersion = ...) -> XlaOp: ...
def CustomCallWithLayout(
builder: XlaBuilder,
call_target_name: bytes,
operands: Sequence[XlaOp],
shape_with_layout: Shape,
operand_shapes_with_layout: Sequence[Shape],
opaque: bytes = ...,
has_side_effect: bool = ...,
schedule: CustomCallSchedule = ...,
api_version: CustomCallApiVersion = ...) -> XlaOp: ...
def CustomCallWithAliasing(
builder: XlaBuilder,
call_target_name: bytes,
operands: Sequence[XlaOp],
shape_with_layout: Shape,
operand_shapes_with_layout: Sequence[Shape],
opaque: bytes = ...,
has_side_effect: bool = ...,
output_operand_aliasing: Sequence[tuple[ShapeIndex, tuple[int, ShapeIndex]]] = ...,
literal: _LiteralSlice = ...,
schedule: CustomCallSchedule = ...,
api_version: CustomCallApiVersion = ...) -> XlaOp: ...
def Dot(
lhs: XlaOp,
rhs: XlaOp,
precision_config: Optional[PrecisionConfig_Precision] = ...,
preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ...
def DotGeneral(
lhs: XlaOp,
rhs: XlaOp,
dimension_numbers: _DotDimensionNumbers,
precision_config: Optional[PrecisionConfig_Precision] = ...,
preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ...
def DynamicReshape(
operand: XlaOp,
dim_sizes: Sequence[XlaOp],
new_size_bounds: Sequence[int],
dims_are_dynamic: Sequence[bool]) -> XlaOp: ...
def DynamicSlice(
operand: XlaOp,
start_indices: Sequence[XlaOp],
slice_sizes: Sequence[int]) -> XlaOp: ...
def DynamicUpdateSlice(
operand: XlaOp,
update: XlaOp,
start_indices: Sequence[XlaOp]) -> XlaOp: ...
def Eigh(
a: XlaOp,
lower: bool = ...,
max_iter: int = ...,
epsilon: float = ...,
sort_eigenvalues: bool = ...) -> tuple[XlaOp, XlaOp]: ...
def Fft(
operand: XlaOp,
fft_type: FftType,
fft_length: Sequence[int]) -> XlaOp: ...
def Gather(
a: XlaOp,
start_indices: XlaOp,
dimension_numbers: _GatherDimensionNumbers,
slice_sizes: Sequence[int],
indices_are_sorted: bool = ...) -> XlaOp: ...
def GetDimensionSize(operand: XlaOp, index: int) -> XlaOp: ...
def GetTupleElement(tuple_data: XlaOp, index: int) -> XlaOp: ...
def InfeedWithToken(
token: XlaOp,
shape: Shape,
config: Optional[str] = ...) -> XlaOp: ...
@overload
def Iota(builder: XlaBuilder, shape: Shape, iota_dimension: int) -> XlaOp: ...
@overload
def Iota(builder: XlaBuilder, type: PrimitiveType, size: int) -> XlaOp: ...
def LU(a: XlaOp) -> tuple[XlaOp, XlaOp, XlaOp]: ...
def Map(
builder: XlaBuilder,
operands: Sequence[XlaOp],
computation: XlaComputation,
dimensions: Sequence[int],
static_operands: Sequence[XlaOp] = ...) -> XlaOp: ...
def MultiCollectivePermute(
operands: Sequence[XlaOp],
source_target_pairs: Sequence[tuple[int, int]],
channel_id: Optional[_ChannelHandle] = ...,
inplace: bool = ...) -> XlaOp: ...
def NextAfter(__from: XlaOp, to: XlaOp) -> XlaOp: ...
def OutfeedWithToken(
operand: XlaOp,
token: XlaOp,
shape_with_layout: Shape,
outfeed_config: Optional[str] = ...) -> XlaOp: ...
def Pad(
operand: XlaOp,
padding_value: XlaOp,
padding_config: _PaddingConfig) -> XlaOp: ...
def Parameter(
builder: XlaBuilder,
parameter_number: int,
shape: Shape,
name: str = ...,
replicated_at_leaf_buffers: Sequence[bool] = ...) -> XlaOp: ...
def ProductOfElementaryHouseholderReflectors(a: XlaOp, taus: XlaOp) -> XlaOp: ...
def QR(a: XlaOp, full_matrices: bool) -> tuple[XlaOp, XlaOp]: ...
def QrDecomposition(a: XlaOp) -> tuple[XlaOp, XlaOp]: ...
def Reduce(
builder: XlaBuilder,
operands: Sequence[XlaOp],
init_values: Sequence[XlaOp],
computation: XlaComputation,
dimensions_to_reduce: Sequence[int]) -> XlaOp: ...
def ReducePrecision(
operand: XlaOp,
exponent_bits: int,
mantissa_bits: int) -> XlaOp: ...
@overload
def ReduceWindowWithGeneralPadding(
operand: XlaOp,
init_value: XlaOp,
computation: XlaComputation,
window_dimensions: Sequence[int],
window_strides: Sequence[int],
base_dilations: Sequence[int],
window_dilations: Sequence[int],
padding: Sequence[tuple[int, int]]) -> XlaOp: ...
@overload
def ReduceWindowWithGeneralPadding(
operands: Sequence[XlaOp],
init_values: Sequence[XlaOp],
computation: XlaComputation,
window_dimensions: Sequence[int],
window_strides: Sequence[int],
base_dilations: Sequence[int],
window_dilations: Sequence[int],
padding: Sequence[tuple[int, int]]) -> XlaOp: ...
def ReplicaId(builder: XlaBuilder) -> XlaOp: ...
def Reshape(operand: XlaOp, new_sizes: Sequence[int]) -> XlaOp: ...
def Rev(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ...
def RngBitGenerator(
algorithm: RandomAlgorithm,
initial_state: XlaOp,
shape: Shape) -> XlaOp: ...
def RngNormal(mu: XlaOp, sigma: XlaOp, shape: Shape) -> XlaOp: ...
def RngUniform(a: XlaOp, b: XlaOp, shape: Shape) -> XlaOp: ...
@overload
def Scatter(
input: XlaOp,
scatter_indices: XlaOp,
updates: XlaOp,
update_computation: XlaComputation,
dimension_numbers: _ScatterDimensionNumbers,
indices_are_sorted: bool = ...,
unique_indices: bool = ...) -> XlaOp: ...
@overload
def Scatter(
inputs: Sequence[XlaOp],
scatter_indices: XlaOp,
updates: Sequence[XlaOp],
update_computation: XlaComputation,
dimension_numbers: _ScatterDimensionNumbers,
indices_are_sorted: bool = ...,
unique_indices: bool = ...) -> XlaOp: ...
def Select(pred: XlaOp, on_true: XlaOp, on_false: XlaOp) -> XlaOp: ...
def SelectAndScatterWithGeneralPadding(
operand: XlaOp,
select: XlaComputation,
window_dimensions: Sequence[int],
window_strides: Sequence[int],
padding: Sequence[tuple[int, int]],
source: XlaOp,
init_value: XlaOp,
scatter: XlaComputation) -> XlaOp: ...
def Slice(
operand: XlaOp,
start_indices: Sequence[int],
limit_indices: Sequence[int],
strides: Sequence[int]) -> XlaOp: ...
def SliceInDim(
operand: XlaOp,
start_index: int,
limit_index: int,
stride: int,
dimno: int) -> XlaOp: ...
def Sort(
builder: XlaBuilder,
operands: Sequence[XlaOp],
comparator: Optional[XlaComputation] = ...,
dimension: int = ...,
is_stable: bool = ...) -> XlaOp: ...
def SVD(
a: XlaOp,
max_iter: int = ...,
epsilon: float = ...) -> tuple[XlaOp, XlaOp, XlaOp]: ...
def TopK(input: XlaOp, k: int) -> XlaOp: ...
def Transpose(operand: XlaOp, permutation: Sequence[int]) -> XlaOp: ...
def TriangularSolve(
a: XlaOp,
b: XlaOp,
left_side: bool,
lower: bool,
unit_diagonal: bool,
transpose_a: TriangularSolveOptions_Transpose) -> XlaOp: ...
def Tuple(builder: XlaBuilder, elements: Sequence[XlaOp]) -> XlaOp: ...
def While(
condition: XlaComputation,
body: XlaComputation,
init: XlaOp) -> XlaOp: ...
def Igamma(a: XlaOp, x: XlaOp) -> XlaOp: ...
def Igammac(a: XlaOp, x: XlaOp) -> XlaOp: ...
def IgammaGradA(a: XlaOp, x: XlaOp) -> XlaOp: ...
def RandomGammaGrad(a: XlaOp, x: XlaOp) -> XlaOp: ...
def RegularizedIncompleteBeta(a: XlaOp, b: XlaOp, x: XlaOp) -> XlaOp: ...
def Zeta(a: XlaOp, q: XlaOp) -> XlaOp: ...
def Eq(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Ne(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Ge(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Gt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Lt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Le(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Add(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Sub(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Mul(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Div(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Rem(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Max(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Min(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def And(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Or(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Xor(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def ShiftLeft(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def ShiftRightArithmetic(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def ShiftRightLogical(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Atan2(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Pow(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Complex(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
def Not(__arg: XlaOp) -> XlaOp: ...
def PopulationCount(__arg: XlaOp) -> XlaOp: ...
def Clz(__arg: XlaOp) -> XlaOp: ...
def Abs(__arg: XlaOp) -> XlaOp: ...
def Exp(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Expm1(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Floor(__arg: XlaOp) -> XlaOp: ...
def Ceil(__arg: XlaOp) -> XlaOp: ...
def Round(__arg: XlaOp) -> XlaOp: ...
def Log(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Log1p(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Sign(__arg: XlaOp) -> XlaOp: ...
def Cos(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def OptimizationBarrier(__arg: XlaOp) -> XlaOp: ...
def Sin(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Tan(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Tanh(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def IsFinite(__arg: XlaOp) -> XlaOp: ...
def Neg(__arg: XlaOp) -> XlaOp: ...
def Sqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Rsqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Cbrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def Square(__arg: XlaOp) -> XlaOp: ...
def Reciprocal(__arg: XlaOp) -> XlaOp: ...
def Erfc(__arg: XlaOp) -> XlaOp: ...
def Erf(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
def ErfInv(__arg: XlaOp) -> XlaOp: ...
def Lgamma(__arg: XlaOp) -> XlaOp: ...
def Digamma(__arg: XlaOp) -> XlaOp: ...
def BesselI0e(__arg: XlaOp) -> XlaOp: ...
def BesselI1e(__arg: XlaOp) -> XlaOp: ...
def Acos(__arg: XlaOp) -> XlaOp: ...
def Asin(__arg: XlaOp) -> XlaOp: ...
def Atan(__arg: XlaOp) -> XlaOp: ...
def Acosh(__arg: XlaOp) -> XlaOp: ...
def Asinh(__arg: XlaOp) -> XlaOp: ...
def Atanh(__arg: XlaOp) -> XlaOp: ...
def Cosh(__arg: XlaOp) -> XlaOp: ...
def Sinh(__arg: XlaOp) -> XlaOp: ...
def Real(__arg: XlaOp) -> XlaOp: ...
def Imag(__arg: XlaOp) -> XlaOp: ...
def Conj(__arg: XlaOp) -> XlaOp: ...
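
The free functions above are the raw op bindings re-exported as xla_client.ops. A minimal, illustrative sketch of driving a few of them through jaxlib's xla_client wrapper; the installed-wheel import path is assumed, and finalizing/compiling the builder is omitted since that goes through xla_client's backend APIs.

import numpy as np
from jaxlib import xla_client

ops = xla_client.ops  # re-export of the op bindings declared above
builder = xla_client.XlaBuilder("add_one")
shape = xla_client.Shape.array_shape(np.dtype(np.float32), (4,))
x = ops.Parameter(builder, 0, shape)          # Parameter(...) above
one = ops.Constant(builder, np.float32(1.0))  # Constant(...) above
ops.Add(x, one)                               # Add(...) above
# Building the XlaComputation and compiling it is done via xla_client's
# backend APIs and is omitted here.

# ApproxTopKReductionOutputSize is a pure helper: for a length-1024 rank-1
# operand it returns the padded output size and reduction amount for an
# approximate top-10 at 0.95 recall (arguments passed positionally:
# input_size, rank, top_k, recall_target).
out_size, reduction_amount = ops.ApproxTopKReductionOutputSize(1024, 1, 10, 0.95)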

View File

@ -0,0 +1,83 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import inspect
from typing import Any, Callable, Sequence, Iterable, Tuple
from . import pytree
_AvalDimSharding = Any
_MeshDimAssignment = Any
class NoSharding:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
def __eq__(self, __other: Any) -> bool: ...
class Chunked:
@property
def chunks(self) -> Sequence[int]: ...
def __init__(self, __chunks: Sequence[int]) -> None: ...
def __repr__(self) -> str: ...
def __eq__(self, __other: Any) -> bool: ...
class Unstacked:
@property
def size(self) -> int: ...
def __init__(self, __sz: int) -> None: ...
def __repr__(self) -> str: ...
def __eq__(self, __other: Any) -> bool: ...
class ShardedAxis:
@property
def axis(self) -> int: ...
def __init__(self, __axis: int) -> None: ...
def __repr__(self) -> str: ...
def __eq__(self, __other: ShardedAxis) -> bool: ...
class Replicated:
@property
def replicas(self) -> int: ...
def __init__(self, __replicas: int) -> None: ...
def __repr__(self) -> str: ...
def __eq__(self, __other: Replicated) -> bool: ...
class ShardingSpec:
def __init__(self,
sharding: Iterable[_AvalDimSharding],
mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ...
@property
def sharding(self) -> Tuple[_AvalDimSharding, ...]: ...
@property
def mesh_mapping(self) -> Tuple[_MeshDimAssignment, ...]: ...
def __eq__(self, __other: ShardingSpec) -> bool: ...
def __hash__(self) -> int: ...
_HAS_DYNAMIC_ATTRIBUTES = True
class PmapFunction:
def __call__(self, *args, **kwargs) -> Any: ...
def __getstate__(self) -> Any: ...
def __setstate__(self, state: Any): ...
__signature__: inspect.Signature
def _cache_size(self) -> int: ...
def _cache_clear(self) -> None: ...
def _debug_cache_keys(self) -> str: ...
def pmap(fun: Callable[..., Any],
cache_miss: Callable[..., Any],
static_argnums: Sequence[int],
shard_arg_fallback: Callable[..., Any],
pytree_registry: pytree.PyTreeRegistry) -> PmapFunction: ...
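
These classes carry jax's pmap sharding metadata. A small illustrative construction follows; it assumes the installed wheel so that jaxlib.xla_extension resolves, and the keyword names simply follow the stub above.

from jaxlib import xla_extension as xe

pmap_lib = xe.pmap_lib
# Describe a rank-2 value whose leading axis is split into 4 chunks and whose
# second axis is unsharded, with one mesh axis assigned to sharded axis 0.
spec = pmap_lib.ShardingSpec(
    sharding=(pmap_lib.Chunked([4]), pmap_lib.NoSharding()),
    mesh_mapping=(pmap_lib.ShardedAxis(0),),
)
assert isinstance(spec.sharding[0], pmap_lib.Chunked)
print(spec.sharding, spec.mesh_mapping)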

View File

@ -0,0 +1,58 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from types import TracebackType
from typing import Any, Optional, Type, Union, List, Tuple
_Status = Any
class ProfilerServer: ...
def start_server(port: int) -> ProfilerServer: ...
def register_plugin_profiler(c_api: Any) -> None: ...
def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ...
def get_instructins_profile(tensorboard_dir: str) -> List[Tuple[str, float]]: ...
def get_fdo_profile(
xspace: bytes, as_textproto: bool = ...
) -> Union[bytes, str]: ...
class ProfilerSession:
def __init__(self, options: Optional[ProfileOptions] = ...) -> None: ...
def stop(self) -> bytes: ...
def export(self, xspace: bytes, tensorboard_dir: str) -> _Status: ...
class ProfileOptions:
include_dataset_ops: bool
host_tracer_level: int
python_tracer_level: int
enable_hlo_proto: bool
start_timestamp_ns: int
duration_ms: int
repository_path: str
def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ...
class TraceMe:
def __init__(self, name: str, **kwargs: Any) -> None: ...
def __enter__(self) -> TraceMe: ...
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Optional[bool]: ...
def set_metadata(self, **kwargs): ...
@staticmethod
def is_enabled() -> bool: ...
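
These bindings sit underneath jax.profiler. An illustrative sketch of annotating a region with TraceMe; the profiler submodule name and the shim import path are assumed from the stub above and the installed-wheel layout.

from jaxlib import xla_extension as xe

profiler = xe.profiler
with profiler.TraceMe("my_region", step=0):   # kwargs become trace metadata
    pass                                      # work attributed to "my_region"
print(profiler.TraceMe.is_enabled())          # False unless a trace is being captured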

View File

@ -0,0 +1,158 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import (
Any,
Callable,
Hashable,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
_T = TypeVar("_T")
version: int
class PyTreeRegistry:
def __init__(
self,
*,
enable_none: bool = ...,
enable_tuple: bool = ...,
enable_namedtuple: bool = ...,
enable_list: bool = ...,
enable_dict: bool = ...
): ...
def flatten(
self,
tree: Any,
leaf_predicate: Optional[Callable[[Any], bool]] = ...,
) -> Tuple[List[Any], PyTreeDef]: ...
def flatten_one_level(
self, tree: Any
) -> Optional[Tuple[Iterable[Any], Any]]: ...
def flatten_one_level_with_keys(
self, tree: Any
) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ...
def flatten_with_path(
self,
tree: Any,
leaf_predicate: Optional[Callable[[Any], bool]] = ...,
) -> Tuple[List[Tuple[_KeyPath, Any]], PyTreeDef]: ...
def register_node(
self,
__type: Type[_T],
to_iterable: Callable[[_T], Tuple[_Children, _AuxData]],
from_iterable: Callable[[_AuxData, _Children], _T],
to_iterable_with_keys: (
Callable[[_T], Tuple[_KeyLeafPairs, _AuxData]] | None
) = ...,
) -> Any: ...
def register_dataclass_node(
self, __type: Type[_T], meta_fields: List[str], data_fields: List[str]
) -> Any: ...
def default_registry() -> PyTreeRegistry: ...
def tuple(registry: PyTreeRegistry, arg0: Sequence[PyTreeDef]) -> PyTreeDef: ...
def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ...
class SequenceKey(Hashable):
idx: int
__match_args__: tuple = ...
def __init__(self, idx: int): ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __getstate__(self) -> Any: ...
def __setstate__(self, state: Any): ...
def __eq__(self, __other: Any) -> bool: ...
class DictKey(Hashable):
key: Hashable
__match_args__: tuple = ...
def __init__(self, key: Hashable): ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __getstate__(self) -> Any: ...
def __setstate__(self, state: Any): ...
def __eq__(self, __other: Any) -> bool: ...
class GetAttrKey(Hashable):
name: str
__match_args__: tuple = ...
def __init__(self, name: str): ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __getstate__(self) -> Any: ...
def __setstate__(self, state: Any): ...
def __eq__(self, __other: Any) -> bool: ...
class FlattenedIndexKey(Hashable):
key: int
__match_args__: tuple = ...
def __init__(self, key: int): ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __getstate__(self) -> Any: ...
def __setstate__(self, state: Any): ...
def __eq__(self, __other: Any) -> bool: ...
class PyTreeDef:
def unflatten(self, __leaves: Iterable[Any]) -> Any: ...
def flatten_up_to(self, __xs: Any) -> List[Any]: ...
def compose(self, __inner: PyTreeDef) -> PyTreeDef: ...
def walk(
self,
__f_node: Callable[[Any, Any], Any],
__f_leaf: Optional[Callable[[_T], Any]],
leaves: Iterable[Any],
) -> Any: ...
def from_iterable_tree(self, __xs: Any): ...
def node_data(self) -> Optional[Tuple[Type, Any]]: ...
def children(self) -> List[PyTreeDef]: ...
@staticmethod
def make_from_node_data_and_children(
registry: PyTreeRegistry,
node_data: Optional[Tuple[Type, Any]],
children: Iterable[PyTreeDef],
) -> PyTreeDef: ...
num_leaves: int
num_nodes: int
def __repr__(self) -> str: ...
def __eq__(self, __other: PyTreeDef) -> bool: ...
def __ne__(self, __other: PyTreeDef) -> bool: ...
def __hash__(self) -> int: ...
def __getstate__(self) -> Any: ...
def __setstate__(self, state: Any): ...
def serialize_using_proto(self) -> bytes: ...
@staticmethod
def deserialize_using_proto(
registry: PyTreeRegistry, data: bytes
) -> PyTreeDef: ...
_Children = TypeVar("_Children", bound=Iterable[Any])
_KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any])
_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[_KeyLeafPair])
_KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...])
_AuxData = TypeVar("_AuxData", bound=Hashable)
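
This registry is the machinery behind jax.tree_util. A minimal illustrative round trip through the default registry, assuming the installed wheel so the shim import resolves.

from jaxlib import xla_extension as xe

registry = xe.pytree.default_registry()
leaves, treedef = registry.flatten({"a": 1, "b": (2, 3)})
assert leaves == [1, 2, 3]            # dict keys are traversed in sorted order
rebuilt = treedef.unflatten([10 * x for x in leaves])
assert rebuilt == {"a": 10, "b": (20, 30)}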

View File

@ -0,0 +1,32 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from mlir import ir
def sdy_round_trip_export_pipeline(
module: ir.Module
) -> str: ...
def sdy_round_trip_import_shardings(
module: ir.Module
) -> str: ...
def get_mesh(
module: ir.Module
) -> tuple[tuple[str, int], ...]: ...
def lowered_with_shardy(
module: ir.Module
) -> bool: ...

View File

@ -0,0 +1,39 @@
# Copyright 2022 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, List, Optional
class TransferGuardLevel:
ALLOW: Any
LOG: Any
DISALLOW: Any
LOG_EXPLICIT: Any
DISALLOW_EXPLICIT: Any
class TransferGuardState:
host_to_device: Optional[TransferGuardLevel]
device_to_device: Optional[TransferGuardLevel]
device_to_host: Optional[TransferGuardLevel]
explicit_device_put: bool
explicit_device_get: bool
def global_state() -> TransferGuardState: ...
def thread_local_state() -> TransferGuardState: ...
class _TestingScopedLogSink:
def __enter__(self) -> _TestingScopedLogSink: ...
def __exit__(self, *args, **kwargs) -> None: ...
def logs(self) -> List[str]: ...
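
These bindings back jax's transfer-guard configuration. User code normally drives them through jax.transfer_guard rather than mutating TransferGuardState directly; the sketch below illustrates the intended behavior (under "disallow", an implicit transfer of a NumPy array is expected to be rejected).

import jax
import jax.numpy as jnp
import numpy as np

with jax.transfer_guard("disallow"):
    try:
        jnp.sin(np.arange(4.0))   # implicit host-to-device transfer of a NumPy array
    except Exception as err:      # expected to be rejected while the guard is active
        print("blocked:", err)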

jaxlib/xla_extension.py Normal file
View File

@ -0,0 +1,17 @@
# Copyright 2025 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jaxlib.xla.xla_extension import * # noqa: F403
from jaxlib.xla.xla_extension import sdy # noqa: F401
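
Illustrative only: the shim keeps the long-standing jaxlib.xla_extension import path working while the extension module itself now lives under jaxlib/xla. The check below assumes a source-tree layout where jaxlib.xla is importable; both names resolve to the same underlying bindings.

import jaxlib.xla_extension as old_path
from jaxlib.xla import xla_extension as new_path

assert old_path.XlaBuilder is new_path.XlaBuilder   # same underlying binding
assert old_path.sdy is new_path.sdy                 # explicitly re-exported above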