mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[JAX] [XLA:Python] Migrate xla_extension and its type stubs into jaxlib.
Future changes will migrate many of its dependent modules. PiperOrigin-RevId: 739361786
This commit is contained in:
parent
2692c5ff98
commit
55e408471c
@ -45,6 +45,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib:cpu_feature_guard",
|
||||
"//jaxlib:utils",
|
||||
"//jaxlib/xla:xla_client",
|
||||
"//jaxlib/xla:xla_extension",
|
||||
"//jaxlib/triton",
|
||||
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
|
||||
"//jaxlib/mlir:arithmetic_dialect",
|
||||
@ -61,6 +62,5 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:sparse_tensor_dialect",
|
||||
"//jaxlib/mlir:stablehlo_dialect",
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
# xla_extension
|
||||
]),
|
||||
)
|
||||
|
10
jaxlib/BUILD
10
jaxlib/BUILD
@ -29,13 +29,6 @@ package(
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
# This makes xla_extension module accessible from jax._src.lib.
|
||||
genrule(
|
||||
name = "xla_extension_py",
|
||||
outs = ["xla_extension.py"],
|
||||
cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@",
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "jaxlib",
|
||||
srcs = [
|
||||
@ -51,8 +44,8 @@ py_library_providing_imports_info(
|
||||
"lapack.py",
|
||||
"plugin_support.py",
|
||||
"xla_client.py",
|
||||
"xla_extension.py",
|
||||
":version",
|
||||
":xla_extension_py",
|
||||
],
|
||||
data = [":ffi_headers"],
|
||||
lib_rule = pytype_library,
|
||||
@ -82,6 +75,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mosaic",
|
||||
"//jaxlib/triton",
|
||||
"//jaxlib/xla:xla_client",
|
||||
"//jaxlib/xla:xla_extension",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -610,3 +610,12 @@ def jax_py_test(
|
||||
if "PYTHONWARNINGS" not in env:
|
||||
env["PYTHONWARNINGS"] = "error"
|
||||
py_test(name = name, env = env, **kwargs)
|
||||
|
||||
def if_oss(oss_value, google_value = []):
|
||||
"""Returns one of the arguments based on the non-configurable build env.
|
||||
|
||||
Specifically, it does not return a `select`, and can be used to e.g.
|
||||
compute elements of list attributes.
|
||||
"""
|
||||
_ = (google_value, oss_value) # buildifier: disable=unused-variable
|
||||
return oss_value
|
||||
|
@ -62,11 +62,11 @@ py_binary(
|
||||
"//jaxlib",
|
||||
"//jaxlib:README.md",
|
||||
"//jaxlib:setup.py",
|
||||
"//jaxlib/xla:xla_client.py",
|
||||
"//jaxlib/xla:xla_extension",
|
||||
"@xla//xla/ffi/api:api.h",
|
||||
"@xla//xla/ffi/api:c_api.h",
|
||||
"@xla//xla/ffi/api:ffi.h",
|
||||
"@xla//xla/python:xla_client.py",
|
||||
"@xla//xla/python:xla_extension",
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]),
|
||||
|
@ -110,7 +110,7 @@ def patch_copy_xla_extension_stubs(dst_dir):
|
||||
xla_extension_dir = os.path.join(dst_dir, "xla_extension")
|
||||
os.makedirs(xla_extension_dir)
|
||||
for stub_name in _XLA_EXTENSION_STUBS:
|
||||
stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name)
|
||||
stub_path = r.Rlocation("__main__/jaxlib/xla/xla_extension/" + stub_name)
|
||||
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
|
||||
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
|
||||
continue
|
||||
@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack():
|
||||
if not _is_mac():
|
||||
return
|
||||
nm = subprocess.run(
|
||||
["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")],
|
||||
["nm", "-g", r.Rlocation("__main/jaxlib/xla/xla_extension.so")],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
@ -198,7 +198,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
"__main__/jaxlib/plugin_support.py",
|
||||
"__main__/jaxlib/version.py",
|
||||
"__main__/jaxlib/xla/xla_client.py",
|
||||
f"xla/xla/python/xla_extension.{pyext}",
|
||||
f"__main__/jaxlib/xla/xla_extension.{pyext}",
|
||||
],
|
||||
)
|
||||
# This file is required by PEP-561. It marks jaxlib as package containing
|
||||
|
111
jaxlib/xla/BUILD
111
jaxlib/xla/BUILD
@ -14,6 +14,7 @@
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_oss",
|
||||
"nanobind_extension",
|
||||
"py_deps",
|
||||
"py_strict_library",
|
||||
@ -35,6 +36,114 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "xla_extension",
|
||||
srcs = ["xla.cc"],
|
||||
pytype_deps = py_deps(["numpy"]),
|
||||
pytype_srcs = glob(["xla_extension/*.pyi"]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/log:initialize",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@nanobind",
|
||||
"@tsl//tsl/platform",
|
||||
"@xla//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@xla//xla:literal",
|
||||
"@xla//xla:shape_util",
|
||||
"@xla//xla:types",
|
||||
"@xla//xla:util",
|
||||
"@xla//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"@xla//xla/ffi:ffi_api",
|
||||
"@xla//xla/pjrt:exceptions",
|
||||
"@xla//xla/pjrt:mlir_to_hlo",
|
||||
"@xla//xla/pjrt:pjrt_api",
|
||||
"@xla//xla/pjrt:pjrt_c_api_client",
|
||||
"@xla//xla/pjrt:pjrt_client",
|
||||
"@xla//xla/pjrt:pjrt_common",
|
||||
"@xla//xla/pjrt:pjrt_compiler",
|
||||
"@xla//xla/pjrt:pjrt_executable",
|
||||
"@xla//xla/pjrt:pjrt_layout",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/distributed",
|
||||
"@xla//xla/pjrt/distributed:client",
|
||||
"@xla//xla/pjrt/distributed:key_value_store_interface",
|
||||
"@xla//xla/pjrt/distributed:protocol_proto_cc",
|
||||
"@xla//xla/pjrt/distributed:service",
|
||||
"@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
|
||||
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
|
||||
"@xla//xla/python:config",
|
||||
"@xla//xla/python:custom_call_sharding",
|
||||
"@xla//xla/python:dlpack",
|
||||
"@xla//xla/python:guard_lib",
|
||||
"@xla//xla/python:jax_jit",
|
||||
"@xla//xla/python:logging",
|
||||
"@xla//xla/python:mlir",
|
||||
"@xla//xla/python:nb_absl_flat_hash_map",
|
||||
"@xla//xla/python:nb_absl_span",
|
||||
"@xla//xla/python:nb_class_ptr",
|
||||
"@xla//xla/python:ops",
|
||||
"@xla//xla/python:pjit",
|
||||
"@xla//xla/python:pmap_lib",
|
||||
"@xla//xla/python:pprof_profile_builder",
|
||||
"@xla//xla/python:profiler",
|
||||
"@xla//xla/python:py_client",
|
||||
"@xla//xla/python:python_ref_manager",
|
||||
"@xla//xla/python:pytree",
|
||||
"@xla//xla/python:refine_polymorphic_shapes",
|
||||
"@xla//xla/python:sdy",
|
||||
"@xla//xla/python:traceback",
|
||||
"@xla//xla/python:types",
|
||||
"@xla//xla/python:util",
|
||||
"@xla//xla/python:weakref_lru_cache",
|
||||
"@xla//xla/python:xla_compiler",
|
||||
"@xla//xla/python/ifrt",
|
||||
"@xla//xla/python/ifrt:plugin_program",
|
||||
"@xla//xla/python/ifrt:plugin_program_serdes",
|
||||
"@xla//xla/python/ifrt_proxy/client:py_module",
|
||||
"@xla//xla/python/pjrt_ifrt",
|
||||
"@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
|
||||
"@xla//xla/python/pjrt_ifrt:xla_ifrt",
|
||||
"@xla//xla/tsl/concurrency:ref_count",
|
||||
"@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager",
|
||||
"@xla//xla/tsl/platform:logging",
|
||||
"@xla//xla/tsl/platform:status",
|
||||
"@xla//xla/tsl/platform:statusor",
|
||||
"@xla//xla/tsl/platform/cloud:gcs_file_system",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
] + select({
|
||||
# gloo tcp transport only builds on linux
|
||||
"@xla//xla/tsl:macos": [
|
||||
"@gloo//:transport_uv",
|
||||
"@xla//xla/backends/cpu/collectives:gloo_collectives",
|
||||
"@xla//xla/backends/cpu/collectives:gloo_kv_store",
|
||||
],
|
||||
"@xla//xla/tsl:windows": [],
|
||||
"//conditions:default": [
|
||||
"@gloo//:transport_tcp",
|
||||
"@xla//xla/backends/cpu/collectives:gloo_collectives",
|
||||
"@xla//xla/backends/cpu/collectives:gloo_kv_store",
|
||||
"@xla//xla/python/transfer:py_socket_transfer",
|
||||
],
|
||||
}) + select({
|
||||
# mpitrampoline does not build on windows
|
||||
"@xla//xla/tsl:windows": [],
|
||||
# we support MPI collectives only in OSS builds
|
||||
"//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]),
|
||||
}),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "xla_client",
|
||||
srcs = ["xla_client.py"],
|
||||
@ -43,7 +152,7 @@ pytype_strict_library(
|
||||
deps = py_deps([
|
||||
"numpy",
|
||||
"ml_dtypes",
|
||||
]) + ["@xla//xla/python:xla_extension"],
|
||||
]) + [":xla_extension"],
|
||||
)
|
||||
|
||||
py_strict_test(
|
||||
|
965
jaxlib/xla/xla.cc
Normal file
965
jaxlib/xla/xla.cc
Normal file
@ -0,0 +1,965 @@
|
||||
/* Copyright 2019 The JAX Authors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/hash/hash.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "nanobind/nb_defs.h"
|
||||
#include "nanobind/stl/function.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/optional.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/pair.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/set.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/string.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/variant.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/vector.h" // IWYU pragma: keep
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||
#include "xla/pjrt/distributed/client.h"
|
||||
#include "xla/pjrt/distributed/distributed.h"
|
||||
#include "xla/pjrt/distributed/protocol.pb.h"
|
||||
#include "xla/pjrt/distributed/service.h"
|
||||
#include "xla/pjrt/pjrt_compiler.h"
|
||||
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
|
||||
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
|
||||
#include "xla/pjrt/status_casters.h"
|
||||
#include "xla/python/ifrt/array.h"
|
||||
#include "xla/python/ifrt/device_list.h"
|
||||
#include "xla/python/ifrt/executable.h"
|
||||
#include "xla/python/ifrt/topology.h"
|
||||
#include "xla/python/ifrt_proxy/client/py_module.h"
|
||||
#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h"
|
||||
#include "xla/python/py_client.h"
|
||||
#include "xla/python/py_program.h"
|
||||
#include "xla/python/sdy.h"
|
||||
#include "xla/tsl/concurrency/ref_count.h"
|
||||
#include "xla/tsl/python/lib/core/numpy.h" // NOLINT
|
||||
|
||||
#if defined(__linux__)
|
||||
#include "gloo/transport/tcp/attr.h"
|
||||
#include "gloo/transport/tcp/device.h"
|
||||
#include "xla/backends/cpu/collectives/gloo_collectives.h"
|
||||
#include "xla/backends/cpu/collectives/gloo_kv_store.h"
|
||||
#include "xla/python/transfer/py_socket_transfer.h"
|
||||
#elif defined(__APPLE__)
|
||||
#include "gloo/transport/uv/device.h"
|
||||
#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT
|
||||
#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT
|
||||
#endif // defined(__linux__)
|
||||
|
||||
#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE)
|
||||
#include "xla/backends/cpu/collectives/mpi_collectives.h"
|
||||
#endif // !_WIN32 && !PLATFORM_GOOGLE
|
||||
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/pjrt/exceptions.h"
|
||||
#include "xla/pjrt/pjrt_api.h"
|
||||
#include "xla/pjrt/pjrt_c_api_client.h"
|
||||
#include "xla/pjrt/pjrt_client.h"
|
||||
#include "xla/pjrt/pjrt_common.h"
|
||||
#include "xla/pjrt/pjrt_executable.h"
|
||||
#include "xla/pjrt/pjrt_layout.h"
|
||||
#include "xla/python/config.h"
|
||||
#include "xla/python/custom_call_sharding.h"
|
||||
#include "xla/python/dlpack.h"
|
||||
#include "xla/python/guard_lib.h"
|
||||
#include "xla/python/jax_jit.h"
|
||||
#include "xla/python/logging.h" // IWYU pragma: keep
|
||||
#include "xla/python/mlir.h"
|
||||
#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep
|
||||
#include "xla/python/nb_absl_span.h" // IWYU pragma: keep
|
||||
#include "xla/python/nb_class_ptr.h"
|
||||
#include "xla/python/ops.h"
|
||||
#include "xla/python/pjit.h"
|
||||
#include "xla/python/pjrt_ifrt/pjrt_client.h"
|
||||
#include "xla/python/pjrt_ifrt/pjrt_executable.h"
|
||||
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
|
||||
#include "xla/python/pmap_lib.h"
|
||||
#include "xla/python/pprof_profile_builder.h"
|
||||
#include "xla/python/profiler.h"
|
||||
#include "xla/python/py_array.h"
|
||||
#include "xla/python/py_compile_only_client.h"
|
||||
#include "xla/python/py_device.h"
|
||||
#include "xla/python/py_device_list.h"
|
||||
#include "xla/python/py_executable.h"
|
||||
#include "xla/python/py_memory_space.h"
|
||||
#include "xla/python/python_ref_manager.h"
|
||||
#include "xla/python/pytree.h"
|
||||
#include "xla/python/sharding.h"
|
||||
#include "xla/python/traceback.h"
|
||||
#include "xla/python/weakref_lru_cache.h"
|
||||
#include "xla/python/xla_compiler.h"
|
||||
#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h"
|
||||
#include "xla/tsl/platform/status.h"
|
||||
#include "tsl/platform/platform.h"
|
||||
|
||||
// TODO(phawkins): remove host_id properties after JAX is update to avoid them.
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
bool IsOptimizedBuild() {
|
||||
#if NDEBUG
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif // NDEBUG
|
||||
}
|
||||
|
||||
// Is*san reports whether the build is under that particular sanitizer.
|
||||
bool IsAsan() {
|
||||
#if defined(ADDRESS_SANITIZER)
|
||||
return true;
|
||||
#else // defined(ADDRESS_SANITIZER)
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool IsMsan() {
|
||||
#if defined(MEMORY_SANITIZER)
|
||||
return true;
|
||||
#else // defined(MEMORY_SANITIZER)
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool IsTsan() {
|
||||
#if defined(THREAD_SANITIZER)
|
||||
return true;
|
||||
#else // defined(THREAD_SANITIZER)
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
// IsSanitized reports whether the build is under any sanitizer.
|
||||
bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); }
|
||||
|
||||
} // namespace
|
||||
|
||||
NB_MODULE(xla_extension, m) {
|
||||
// Initialize ABSL logging because code within XLA uses it.
|
||||
#ifndef PLATFORM_GOOGLE
|
||||
InitializeAbslLogging();
|
||||
#endif // PLATFORM_GOOGLE
|
||||
|
||||
// We seem to get a fair number of leak warnings from nanobind. It's unclear
|
||||
// whether these are false positives or not.
|
||||
nb::set_leak_warnings(false);
|
||||
|
||||
tsl::ImportNumpy();
|
||||
|
||||
// Exceptions
|
||||
nb::exception<XlaRuntimeError> xla_runtime_error(m, "XlaRuntimeError",
|
||||
PyExc_RuntimeError);
|
||||
xla_runtime_error.attr("__doc__") = nb::str(
|
||||
"Runtime errors thrown by the JAX runtime. While the JAX runtime may "
|
||||
"raise other exceptions as well, most exceptions thrown by the runtime "
|
||||
"are instances of this class.");
|
||||
|
||||
// Types
|
||||
nb::enum_<PrimitiveType>(m, "PrimitiveType", nb::is_arithmetic())
|
||||
.value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
|
||||
.value("PRED", PRED)
|
||||
.value("S4", S4)
|
||||
.value("S8", S8)
|
||||
.value("S16", S16)
|
||||
.value("S32", S32)
|
||||
.value("S64", S64)
|
||||
.value("U4", U4)
|
||||
.value("U8", U8)
|
||||
.value("U16", U16)
|
||||
.value("U32", U32)
|
||||
.value("U64", U64)
|
||||
.value("F16", F16)
|
||||
.value("F4E2M1FN", F4E2M1FN)
|
||||
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
|
||||
// .value("F8E3M4", F8E3M4)
|
||||
// .value("F8E4M3", F8E4M3)
|
||||
.value("F8E8M0FNU", F8E8M0FNU)
|
||||
.value("F8E4M3FN", F8E4M3FN)
|
||||
.value("F8E4M3B11FNUZ", F8E4M3B11FNUZ)
|
||||
.value("F8E4M3FNUZ", F8E4M3FNUZ)
|
||||
.value("F8E5M2", F8E5M2)
|
||||
.value("F8E5M2FNUZ", F8E5M2FNUZ)
|
||||
.value("BF16", BF16)
|
||||
.value("F32", F32)
|
||||
.value("F64", F64)
|
||||
.value("C64", C64)
|
||||
.value("C128", C128)
|
||||
.value("TUPLE", TUPLE)
|
||||
.value("OPAQUE_TYPE", OPAQUE_TYPE)
|
||||
.value("TOKEN", TOKEN);
|
||||
|
||||
// Must be before PyClient.compile.
|
||||
BuildXlaCompilerSubmodule(m);
|
||||
|
||||
PyDevice::RegisterPythonType(m);
|
||||
PyMemorySpace::RegisterPythonType(m);
|
||||
PyClient::RegisterPythonTypes(m);
|
||||
|
||||
nb::enum_<ifrt::ArrayCopySemantics>(m, "ArrayCopySemantics",
|
||||
nb::is_arithmetic())
|
||||
.value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy)
|
||||
.value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput)
|
||||
.value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput);
|
||||
|
||||
nb::class_<PjRtLayout>(m, "PjRtLayout")
|
||||
.def("__str__", &PjRtLayout::ToString)
|
||||
.def("__eq__", [](const PjRtLayout& layout,
|
||||
const PjRtLayout& other) { return layout == other; })
|
||||
.def("__hash__",
|
||||
[](const PjRtLayout& layout) { return absl::HashOf(layout); })
|
||||
.def("_xla_layout", &PjRtLayout::xla_layout)
|
||||
.def("__getstate__",
|
||||
[](const PjRtLayout& layout) -> nb::tuple {
|
||||
absl::StatusOr<std::string> serialized = layout.Serialize();
|
||||
ThrowIfError(serialized.status());
|
||||
return nb::make_tuple(
|
||||
nb::bytes(serialized->data(), serialized->size()));
|
||||
})
|
||||
.def("__setstate__", [](PjRtLayout* self, nb::tuple t) {
|
||||
nb::bytes serialized = nb::cast<nb::bytes>(t[0]);
|
||||
absl::StatusOr<std::shared_ptr<const PjRtLayout>> layout =
|
||||
PjRtLayout::Deserialize(
|
||||
absl::string_view(serialized.c_str(), serialized.size()));
|
||||
ThrowIfError(layout.status());
|
||||
new (self) PjRtLayout((*layout)->xla_layout());
|
||||
});
|
||||
|
||||
jax::BuildWeakrefLRUCacheAPI(m);
|
||||
|
||||
nb::class_<xla::cpu::CpuCollectives> cpu_collectives(m, "CpuCollectives");
|
||||
|
||||
m.def(
|
||||
"make_gloo_tcp_collectives",
|
||||
[](std::shared_ptr<DistributedRuntimeClient> distributed_client,
|
||||
|
||||
std::optional<std::string> hostname,
|
||||
std::optional<std::string> interface)
|
||||
-> std::shared_ptr<xla::cpu::CpuCollectives> {
|
||||
#if defined(__linux__)
|
||||
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
|
||||
if (distributed_client != nullptr) {
|
||||
kv_store = GetDistributedKeyValueStore(distributed_client,
|
||||
/*key_prefix=*/"cpu:");
|
||||
}
|
||||
auto gloo_kv_store = std::make_unique<cpu::GlooKeyValueStore>(kv_store);
|
||||
auto tcp_attrs = gloo::transport::tcp::attr();
|
||||
if (hostname) {
|
||||
tcp_attrs.hostname = *hostname;
|
||||
}
|
||||
if (interface) {
|
||||
tcp_attrs.iface = *interface;
|
||||
}
|
||||
auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs);
|
||||
return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
|
||||
std::move(tcp_device));
|
||||
#elif defined(__APPLE__)
|
||||
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
|
||||
if (distributed_client != nullptr) {
|
||||
kv_store = GetDistributedKeyValueStore(distributed_client,
|
||||
/*key_prefix=*/"cpu:");
|
||||
}
|
||||
auto gloo_kv_store = std::make_unique<cpu::GlooKeyValueStore>(kv_store);
|
||||
auto uv_attrs = gloo::transport::uv::attr();
|
||||
if (hostname) {
|
||||
uv_attrs.hostname = *hostname;
|
||||
}
|
||||
if (interface) {
|
||||
uv_attrs.iface = *interface;
|
||||
}
|
||||
auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs);
|
||||
return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
|
||||
std::move(uv_device));
|
||||
#else // defined(__linux__)
|
||||
throw xla::XlaRuntimeError(
|
||||
"make_gloo_tcp_collectives only implemented for linux and macos");
|
||||
#endif // defined(__linux__)
|
||||
},
|
||||
nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt,
|
||||
nb::arg("interface").none() = std::nullopt);
|
||||
|
||||
#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE)
|
||||
nb::class_<cpu::MpiCollectives> mpi_collectives(m, "MpiCollectives",
|
||||
cpu_collectives);
|
||||
mpi_collectives.def("Init", &cpu::MpiCollectives::Init);
|
||||
mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize);
|
||||
m.def("make_mpi_collectives", []() -> std::shared_ptr<cpu::MpiCollectives> {
|
||||
return std::make_shared<cpu::MpiCollectives>();
|
||||
});
|
||||
#else // !_WIN32 && !PLATFORM_GOOGLE
|
||||
m.def("make_mpi_collectives",
|
||||
[]() -> std::shared_ptr<xla::cpu::CpuCollectives> {
|
||||
throw xla::XlaRuntimeError(
|
||||
"make_mpi_collectives is not implemented for Windows");
|
||||
});
|
||||
#endif // !_WIN32 && !PLATFORM_GOOGLE
|
||||
|
||||
m.def(
|
||||
"get_tfrt_cpu_client",
|
||||
[](bool asynchronous,
|
||||
std::shared_ptr<DistributedRuntimeClient> distributed_client,
|
||||
int node_id, int num_nodes,
|
||||
std::shared_ptr<xla::cpu::CpuCollectives> collectives,
|
||||
std::optional<int> num_devices) -> nb_class_ptr<PyClient> {
|
||||
std::unique_ptr<ifrt::PjRtClient> ifrt_client;
|
||||
{
|
||||
nb::gil_scoped_release gil_release;
|
||||
xla::CpuClientOptions options;
|
||||
|
||||
options.asynchronous = asynchronous;
|
||||
options.collectives = std::move(collectives);
|
||||
options.process_id = node_id;
|
||||
options.cpu_device_count = num_devices;
|
||||
std::unique_ptr<PjRtClient> client =
|
||||
xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options)));
|
||||
ifrt::PjRtClient::CreateOptions ifrt_options;
|
||||
ifrt_options.pjrt_client =
|
||||
std::shared_ptr<PjRtClient>(std::move(client));
|
||||
if (distributed_client != nullptr) {
|
||||
ifrt_options.kv_store =
|
||||
GetDistributedKeyValueStore(distributed_client,
|
||||
/*key_prefix=*/"cpu:");
|
||||
ifrt_options.process_id = node_id;
|
||||
ifrt_options.num_processes = num_nodes;
|
||||
}
|
||||
ifrt_client =
|
||||
ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options)));
|
||||
}
|
||||
return PyClient::Make(std::move(ifrt_client));
|
||||
},
|
||||
nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr,
|
||||
nb::arg("node_id") = 0, nb::arg("num_nodes") = 1,
|
||||
nb::arg("collectives").none() =
|
||||
std::shared_ptr<xla::cpu::CpuCollectives>(),
|
||||
nb::arg("num_devices").none() = std::nullopt);
|
||||
m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool {
|
||||
absl::StatusOr<const PJRT_Api*> pjrt_api = pjrt::PjrtApi(platform_name);
|
||||
return pjrt_api.ok();
|
||||
});
|
||||
m.def(
|
||||
"load_pjrt_plugin",
|
||||
[](std::string platform_name, std::optional<std::string> library_path,
|
||||
std::optional<nb::capsule> c_api) -> nb::capsule {
|
||||
if (library_path.has_value()) {
|
||||
const PJRT_Api* api = xla::ValueOrThrow(
|
||||
pjrt::LoadPjrtPlugin(platform_name, *library_path));
|
||||
return nb::capsule(absl::bit_cast<void*>(api), "pjrt_c_api");
|
||||
}
|
||||
if (absl::string_view(c_api->name()) != "pjrt_c_api") {
|
||||
throw nb::value_error(
|
||||
"c_api argument to load_pjrt_plugin is not a pjrt_c_api "
|
||||
"capsule.");
|
||||
}
|
||||
xla::ThrowIfError(pjrt::SetPjrtApi(
|
||||
platform_name, static_cast<const PJRT_Api*>(c_api->data())));
|
||||
return *c_api;
|
||||
},
|
||||
nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt,
|
||||
nb::arg("c_api").none() = std::nullopt);
|
||||
m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool {
|
||||
return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name));
|
||||
});
|
||||
m.def("initialize_pjrt_plugin", [](std::string platform_name) {
|
||||
return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name));
|
||||
});
|
||||
|
||||
m.def(
|
||||
"get_c_api_client",
|
||||
[](std::string platform_name,
|
||||
const absl::flat_hash_map<std::string, PjRtValueType>& options,
|
||||
std::shared_ptr<DistributedRuntimeClient> distributed_client)
|
||||
-> nb_class_ptr<PyClient> {
|
||||
std::unique_ptr<ifrt::PjRtClient> ifrt_client;
|
||||
{
|
||||
nb::gil_scoped_release gil_release;
|
||||
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
|
||||
if (distributed_client != nullptr) {
|
||||
kv_store = GetDistributedKeyValueStore(
|
||||
distributed_client,
|
||||
/*key_prefix=*/absl::StrCat(platform_name, ":"));
|
||||
}
|
||||
std::unique_ptr<PjRtClient> c_api_client = xla::ValueOrThrow(
|
||||
GetCApiClient(platform_name, options, kv_store));
|
||||
ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client));
|
||||
}
|
||||
return PyClient::Make(std::move(ifrt_client));
|
||||
},
|
||||
nb::arg("platform_name"),
|
||||
nb::arg("options") = absl::flat_hash_map<std::string, PjRtValueType>(),
|
||||
nb::arg("distributed_client").none() = nullptr);
|
||||
// TODO(b/322357665): Delete this method after TPU plugin changes to use the
|
||||
// standard registration.
|
||||
m.def("get_default_c_api_topology",
|
||||
[](std::string platform_name, std::string topology_name,
|
||||
const absl::flat_hash_map<std::string, PjRtValueType>& options)
|
||||
-> std::shared_ptr<ifrt::Topology> {
|
||||
return std::make_shared<ifrt::PjRtTopology>(xla::ValueOrThrow(
|
||||
GetCApiTopology(platform_name, topology_name, options)));
|
||||
});
|
||||
m.def("get_c_api_topology",
|
||||
[](nb::capsule c_api, std::string topology_name,
|
||||
const absl::flat_hash_map<std::string, PjRtValueType>& options)
|
||||
-> std::shared_ptr<ifrt::Topology> {
|
||||
if (absl::string_view(c_api.name()) != "pjrt_c_api") {
|
||||
throw nb::value_error(
|
||||
"Argument to get_c_api_topology was not a pjrt_c_api capsule.");
|
||||
}
|
||||
return std::make_shared<ifrt::PjRtTopology>(xla::ValueOrThrow(
|
||||
GetCApiTopology(static_cast<const PJRT_Api*>(c_api.data()),
|
||||
topology_name, options)));
|
||||
});
|
||||
m.def("get_topology_for_devices",
|
||||
[](const std::vector<nb_class_ptr<PyDevice>>& py_devices) {
|
||||
if (py_devices.empty()) {
|
||||
throw nb::value_error(
|
||||
"get_topology_for_devices requires >= 1 devices.");
|
||||
}
|
||||
auto client = py_devices[0]->client();
|
||||
absl::InlinedVector<ifrt::Device*, 1> ifrt_devices;
|
||||
ifrt_devices.reserve(py_devices.size());
|
||||
for (const auto& py_device : py_devices) {
|
||||
if (py_device->client().get() != client.get()) {
|
||||
throw nb::value_error(
|
||||
"devices passed to get_topology_for_devices come from "
|
||||
"different clients.");
|
||||
}
|
||||
ifrt_devices.push_back(py_device->device());
|
||||
}
|
||||
ifrt::DeviceListRef device_list =
|
||||
client->ifrt_client()->MakeDeviceList(ifrt_devices);
|
||||
return xla::ValueOrThrow(
|
||||
client->ifrt_client()->GetTopologyForDevices(device_list));
|
||||
});
|
||||
|
||||
TF_CHECK_OK(PyArray::RegisterTypes(m));
|
||||
jax::PyDeviceList::Register(m);
|
||||
jax::RegisterSharding(m);
|
||||
|
||||
nb::class_<CompiledMemoryStats>(m, "CompiledMemoryStats")
|
||||
.def_rw("generated_code_size_in_bytes",
|
||||
&CompiledMemoryStats::generated_code_size_in_bytes)
|
||||
.def_rw("argument_size_in_bytes",
|
||||
&CompiledMemoryStats::argument_size_in_bytes)
|
||||
.def_rw("output_size_in_bytes",
|
||||
&CompiledMemoryStats::output_size_in_bytes)
|
||||
.def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes)
|
||||
.def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes)
|
||||
.def_rw("host_generated_code_size_in_bytes",
|
||||
&CompiledMemoryStats::host_generated_code_size_in_bytes)
|
||||
.def_rw("host_argument_size_in_bytes",
|
||||
&CompiledMemoryStats::host_argument_size_in_bytes)
|
||||
.def_rw("host_output_size_in_bytes",
|
||||
&CompiledMemoryStats::host_output_size_in_bytes)
|
||||
.def_rw("host_alias_size_in_bytes",
|
||||
&CompiledMemoryStats::host_alias_size_in_bytes)
|
||||
.def_rw("host_temp_size_in_bytes",
|
||||
&CompiledMemoryStats::host_temp_size_in_bytes)
|
||||
.def_prop_ro("serialized_hlo_proto",
|
||||
[](const CompiledMemoryStats& cms) -> nb::bytes {
|
||||
return nb::bytes(cms.serialized_hlo_proto.data(),
|
||||
cms.serialized_hlo_proto.size());
|
||||
})
|
||||
.def("__str__", &CompiledMemoryStats::DebugString);
|
||||
|
||||
nb::class_<PyExecuteResults>(m, "ExecuteResults")
|
||||
.def("__len__", [](PyExecuteResults& results) { return results.Size(); })
|
||||
.def("disassemble_into_single_device_arrays",
|
||||
&PyExecuteResults::DisassembleIntoSingleDeviceArrays)
|
||||
.def("disassemble_prefix_into_single_device_arrays",
|
||||
&PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays)
|
||||
.def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers)
|
||||
.def("consume_token", &PyExecuteResults::ConsumeToken);
|
||||
|
||||
nb::class_<PyLoadedExecutable>(m, "LoadedExecutable")
|
||||
.def_prop_ro("client", &PyLoadedExecutable::client)
|
||||
.def("local_devices", &PyLoadedExecutable::AddressableDevices)
|
||||
.def("size_of_generated_code_in_bytes",
|
||||
&PyLoadedExecutable::SizeOfGeneratedCodeInBytes)
|
||||
.def(
|
||||
"get_compiled_memory_stats",
|
||||
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats))
|
||||
.def("delete", &PyLoadedExecutable::Delete)
|
||||
.def("execute_sharded_on_local_devices",
|
||||
xla::ValueOrThrowWrapper(
|
||||
&PyLoadedExecutable::ExecuteShardedOnLocalDevices),
|
||||
nb::arg("arguments"))
|
||||
.def("execute_sharded_on_local_devices_with_tokens",
|
||||
xla::ValueOrThrowWrapper(
|
||||
&PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens),
|
||||
nb::arg("arguments"))
|
||||
// TODO(parkers): Switch execute_sharded_on_local_devices* to this.
|
||||
.def("execute_sharded",
|
||||
xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded),
|
||||
nb::arg("arguments"), nb::arg("with_tokens") = false)
|
||||
.def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules))
|
||||
.def("get_output_memory_kinds",
|
||||
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds))
|
||||
.def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings)
|
||||
.def("get_parameter_layouts",
|
||||
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts))
|
||||
.def("get_output_layouts",
|
||||
xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts))
|
||||
.def("get_parameter_shardings",
|
||||
&PyLoadedExecutable::GetParameterShardings)
|
||||
.def("keep_alive", &PyLoadedExecutable::KeepAlive)
|
||||
.def("cost_analysis",
|
||||
[](const PyLoadedExecutable& self) {
|
||||
auto map = ValueOrThrow(self.GetCostAnalysis());
|
||||
return ifrt::ToPjRtAttributeMap(std::move(map));
|
||||
})
|
||||
.def_prop_ro("traceback", &PyLoadedExecutable::traceback)
|
||||
.def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object {
|
||||
if (exec->fingerprint().has_value()) {
|
||||
return nb::bytes(exec->fingerprint()->data(),
|
||||
exec->fingerprint()->size());
|
||||
} else {
|
||||
return nb::none();
|
||||
}
|
||||
});
|
||||
nb::class_<PyToken> token(m, "Token");
|
||||
token.def("block_until_ready",
|
||||
[](PyToken& self) { xla::ThrowIfError(self.Await()); });
|
||||
|
||||
nb::class_<PyShardedToken> sharded_token(m, "ShardedToken");
|
||||
sharded_token.def("block_until_ready", [](PyShardedToken& self) {
|
||||
xla::ThrowIfError(self.Await());
|
||||
});
|
||||
sharded_token.def("get_token", &PyShardedToken::GetPyToken);
|
||||
|
||||
m.def("buffer_to_dlpack_managed_tensor",
|
||||
xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor),
|
||||
nb::arg("buffer"), nb::arg("stream").none() = nb::none());
|
||||
m.def(
|
||||
"dlpack_managed_tensor_to_buffer",
|
||||
[](const nb::capsule& tensor, nb_class_ptr<PyDevice> device,
|
||||
std::optional<std::intptr_t> stream) {
|
||||
return xla::ValueOrThrow(DLPackManagedTensorToBuffer(
|
||||
tensor, device->device(), device->client(), stream));
|
||||
},
|
||||
nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none());
|
||||
// Legacy overload
|
||||
m.def(
|
||||
"dlpack_managed_tensor_to_buffer",
|
||||
[](const nb::capsule& tensor,
|
||||
std::optional<nb_class_ptr<PyClient>> cpu_client,
|
||||
std::optional<nb_class_ptr<PyClient>> gpu_client) {
|
||||
return xla::ValueOrThrow(DLPackManagedTensorToBuffer(
|
||||
tensor, std::move(cpu_client), std::move(gpu_client)));
|
||||
},
|
||||
nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(),
|
||||
nb::arg("gpu_backend").none() = nb::none());
|
||||
m.def("cuda_array_interface_to_buffer",
|
||||
xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"),
|
||||
nb::arg("gpu_backend").none() = nb::none(),
|
||||
nb::arg("device_id").none() = nb::none());
|
||||
|
||||
jax::BuildConfigSubmodule(m);
|
||||
BuildIfrtProgramsSubmodule(m);
|
||||
BuildProfilerSubmodule(m);
|
||||
BuildOpsSubmodule(m);
|
||||
BuildPytreeSubmodule(m);
|
||||
jax::BuildGuardSubmodule(m);
|
||||
jax::BuildJaxjitSubmodule(m);
|
||||
jax::BuildPmapSubmodule(m);
|
||||
jax::BuildPjitSubmodule(m);
|
||||
BuildTracebackSubmodule(m);
|
||||
BuildMlirSubmodule(m);
|
||||
BuildSdySubmodule(m);
|
||||
BuildCustomCallShardingPybindAPI(m);
|
||||
#if defined(__linux__)
|
||||
aux::RegisterTransferServerTypes(m);
|
||||
#endif // defined(__linux__)
|
||||
|
||||
// The following uses python bindings for PyClient defined above using
|
||||
// pybind11, and hence needs pybind11::module_ (not just nanobind::module_).
|
||||
xla::ifrt::proxy::BuildIfrtProxySubmodule(m);
|
||||
|
||||
nb::class_<tsl::PreemptionSyncManager> preemption_sync_manager(
|
||||
m, "PreemptionSyncManager");
|
||||
preemption_sync_manager
|
||||
.def(
|
||||
"initialize",
|
||||
[](tsl::PreemptionSyncManager& manager,
|
||||
DistributedRuntimeClient* client) {
|
||||
tsl::CoordinationServiceAgent* agent =
|
||||
xla::ValueOrThrow(client->GetCoordinationServiceAgent());
|
||||
xla::ThrowIfError(manager.Initialize(agent));
|
||||
},
|
||||
nb::arg("distributed_client"))
|
||||
.def("reached_sync_point",
|
||||
[](tsl::PreemptionSyncManager& manager, int step_counter) {
|
||||
return manager.ReachedSyncPoint(step_counter);
|
||||
});
|
||||
m.def("create_preemption_sync_manager",
|
||||
[]() { return tsl::CreatePreemptionSyncManager(); });
|
||||
|
||||
nb::class_<DistributedRuntimeService> distributed_runtime_service(
|
||||
m, "DistributedRuntimeService");
|
||||
distributed_runtime_service.def("shutdown",
|
||||
&DistributedRuntimeService::Shutdown,
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
nb::class_<DistributedRuntimeClient> distributed_runtime_client(
|
||||
m, "DistributedRuntimeClient");
|
||||
distributed_runtime_client
|
||||
.def("connect",
|
||||
[](DistributedRuntimeClient& self) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
xla::ThrowIfError(self.Connect());
|
||||
})
|
||||
.def("shutdown",
|
||||
[](DistributedRuntimeClient& self) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
xla::ThrowIfError(self.Shutdown());
|
||||
})
|
||||
// This method assumes that the value is a Python string. Use
|
||||
// `blocking_key_value_get_bytes()` if key_value_set() was called with a
|
||||
// Python bytes object as its value.
|
||||
.def(
|
||||
"blocking_key_value_get",
|
||||
[](DistributedRuntimeClient& client, std::string key,
|
||||
int64_t timeout_in_ms) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
return xla::ValueOrThrow(client.BlockingKeyValueGet(
|
||||
key, absl::Milliseconds(timeout_in_ms)));
|
||||
},
|
||||
nb::arg("key"), nb::arg("timeout_in_ms"))
|
||||
// Same as `blocking_key_value_get()`, but retrieves the raw Python byte
|
||||
// values explicitly.
|
||||
.def(
|
||||
"blocking_key_value_get_bytes",
|
||||
[](DistributedRuntimeClient& client, std::string key,
|
||||
int64_t timeout_in_ms) -> nb::bytes {
|
||||
std::string result;
|
||||
{
|
||||
nb::gil_scoped_release gil_release;
|
||||
result = xla::ValueOrThrow(client.BlockingKeyValueGet(
|
||||
key, absl::Milliseconds(timeout_in_ms)));
|
||||
}
|
||||
return nb::bytes(result.data(), result.size());
|
||||
},
|
||||
nb::arg("key"), nb::arg("timeout_in_ms"))
|
||||
.def(
|
||||
"key_value_try_get",
|
||||
[](DistributedRuntimeClient& client, std::string key) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
return xla::ValueOrThrow(client.KeyValueTryGet(key));
|
||||
},
|
||||
nb::arg("key"))
|
||||
.def(
|
||||
"key_value_try_get_bytes",
|
||||
[](DistributedRuntimeClient& client, std::string key) -> nb::bytes {
|
||||
std::string result;
|
||||
{
|
||||
nb::gil_scoped_release gil_release;
|
||||
result = xla::ValueOrThrow(client.KeyValueTryGet(key));
|
||||
}
|
||||
return nb::bytes(result.data(), result.size());
|
||||
},
|
||||
nb::arg("key"))
|
||||
.def(
|
||||
"wait_at_barrier",
|
||||
[](DistributedRuntimeClient& client, std::string barrier_id,
|
||||
int64_t timeout_in_ms,
|
||||
std::optional<std::vector<int32_t>> process_ids) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
xla::ThrowIfError(client.WaitAtBarrier(
|
||||
barrier_id, absl::Milliseconds(timeout_in_ms), process_ids));
|
||||
},
|
||||
nb::arg("barrier_id"), nb::arg("timeout_in_ms"),
|
||||
nb::arg("process_ids") = std::nullopt)
|
||||
.def(
|
||||
"get_live_nodes",
|
||||
[](DistributedRuntimeClient& client,
|
||||
std::vector<int32_t> process_ids) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
return xla::ValueOrThrow(client.GetLiveNodes(process_ids));
|
||||
},
|
||||
nb::arg("process_ids"))
|
||||
// The key must be a string, but the value can either be a Python string
|
||||
// or bytes object.
|
||||
// With Python string values, use `key_value_set()` and
|
||||
// `blocking_key_value_get()`.
|
||||
// With Python byte object values, use `key_value_set()` and
|
||||
// `blocking_key_value_get_bytes()`.
|
||||
.def(
|
||||
"key_value_set",
|
||||
[](DistributedRuntimeClient& client, absl::string_view key,
|
||||
absl::string_view value, bool allow_overwrite) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite));
|
||||
},
|
||||
nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false)
|
||||
// The key must be a string, but the value must a
|
||||
// Python bytes object.
|
||||
// Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`.
|
||||
.def(
|
||||
"key_value_set_bytes",
|
||||
[](DistributedRuntimeClient& client, absl::string_view key,
|
||||
nb::bytes value, bool allow_overwrite) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
xla::ThrowIfError(client.KeyValueSet(
|
||||
key, absl::string_view(value.c_str(), value.size()),
|
||||
allow_overwrite));
|
||||
},
|
||||
nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false)
|
||||
// Assumes that all values in the directory are Python strings.
|
||||
.def(
|
||||
"key_value_dir_get",
|
||||
[](DistributedRuntimeClient& client, absl::string_view key) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
return xla::ValueOrThrow(client.KeyValueDirGet(key));
|
||||
},
|
||||
nb::arg("key"))
|
||||
// Assumes that all values in the directory are Python byte objects.
|
||||
// Same as `key_value_dir_get()`, but retrieves Python byte values
|
||||
// explicitly.
|
||||
.def(
|
||||
"key_value_dir_get_bytes",
|
||||
[](DistributedRuntimeClient& client, absl::string_view key)
|
||||
-> std::vector<std::pair<std::string, nb::bytes>> {
|
||||
std::vector<std::pair<std::string, std::string>> result;
|
||||
{
|
||||
nb::gil_scoped_release gil_release;
|
||||
result = xla::ValueOrThrow(client.KeyValueDirGet(key));
|
||||
}
|
||||
// Convert std::string values to nb::bytes.
|
||||
std::vector<std::pair<std::string, nb::bytes>> kvs;
|
||||
kvs.reserve(result.size());
|
||||
for (auto& kv : result) {
|
||||
kvs.push_back(
|
||||
std::pair(std::move(kv.first),
|
||||
nb::bytes(kv.second.data(), kv.second.size())));
|
||||
}
|
||||
return kvs;
|
||||
},
|
||||
nb::arg("key"))
|
||||
.def(
|
||||
"key_value_delete",
|
||||
[](DistributedRuntimeClient& client, absl::string_view key) {
|
||||
nb::gil_scoped_release gil_release;
|
||||
return xla::ThrowIfError(client.KeyValueDelete(key));
|
||||
},
|
||||
nb::arg("key"));
|
||||
|
||||
m.def(
|
||||
"get_distributed_runtime_service",
|
||||
[](std::string address, int num_nodes,
|
||||
std::optional<int> heartbeat_interval,
|
||||
std::optional<int> max_missing_heartbeats,
|
||||
std::optional<int> cluster_register_timeout,
|
||||
std::optional<int> shutdown_timeout)
|
||||
-> std::unique_ptr<DistributedRuntimeService> {
|
||||
CoordinationServiceImpl::Options options;
|
||||
options.num_nodes = num_nodes;
|
||||
if (heartbeat_interval.has_value()) {
|
||||
options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
|
||||
}
|
||||
if (max_missing_heartbeats.has_value()) {
|
||||
options.max_missing_heartbeats = *max_missing_heartbeats;
|
||||
}
|
||||
if (cluster_register_timeout.has_value()) {
|
||||
options.cluster_register_timeout =
|
||||
absl::Seconds(*cluster_register_timeout);
|
||||
}
|
||||
if (shutdown_timeout.has_value()) {
|
||||
options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
|
||||
}
|
||||
std::unique_ptr<DistributedRuntimeService> service =
|
||||
xla::ValueOrThrow(GetDistributedRuntimeService(address, options));
|
||||
return service;
|
||||
},
|
||||
nb::arg("address"), nb::arg("num_nodes"),
|
||||
nb::arg("heartbeat_interval").none() = std::nullopt,
|
||||
nb::arg("max_missing_heartbeats").none() = std::nullopt,
|
||||
nb::arg("cluster_register_timeout").none() = std::nullopt,
|
||||
nb::arg("shutdown_timeout").none() = std::nullopt);
|
||||
|
||||
m.def(
|
||||
"get_distributed_runtime_client",
|
||||
[](std::string address, int node_id, std::optional<int> rpc_timeout,
|
||||
std::optional<int> init_timeout, std::optional<int> shutdown_timeout,
|
||||
std::optional<int> heartbeat_interval,
|
||||
std::optional<int> max_missing_heartbeats,
|
||||
std::optional<std::function<void(absl::Status)>>
|
||||
missed_heartbeat_callback,
|
||||
std::optional<bool> shutdown_on_destruction,
|
||||
std::optional<bool> use_compression)
|
||||
-> std::shared_ptr<DistributedRuntimeClient> {
|
||||
bool compression = use_compression.value_or(false);
|
||||
DistributedRuntimeClient::Options options;
|
||||
options.node_id = node_id;
|
||||
if (rpc_timeout.has_value()) {
|
||||
options.rpc_timeout = absl::Seconds(*rpc_timeout);
|
||||
}
|
||||
if (init_timeout.has_value()) {
|
||||
options.init_timeout = absl::Seconds(*init_timeout);
|
||||
}
|
||||
if (shutdown_timeout.has_value()) {
|
||||
options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
|
||||
}
|
||||
if (heartbeat_interval.has_value()) {
|
||||
options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
|
||||
}
|
||||
if (max_missing_heartbeats.has_value()) {
|
||||
options.max_missing_heartbeats = *max_missing_heartbeats;
|
||||
}
|
||||
if (missed_heartbeat_callback.has_value()) {
|
||||
options.missed_heartbeat_callback =
|
||||
std::move(*missed_heartbeat_callback);
|
||||
}
|
||||
if (shutdown_on_destruction.has_value()) {
|
||||
options.shutdown_on_destruction = *shutdown_on_destruction;
|
||||
}
|
||||
return GetDistributedRuntimeClient(address, options, compression);
|
||||
},
|
||||
nb::arg("address"), nb::arg("node_id"),
|
||||
nb::arg("rpc_timeout").none() = std::nullopt,
|
||||
nb::arg("init_timeout").none() = std::nullopt,
|
||||
nb::arg("shutdown_timeout").none() = std::nullopt,
|
||||
nb::arg("heartbeat_interval").none() = std::nullopt,
|
||||
nb::arg("max_missing_heartbeats").none() = std::nullopt,
|
||||
nb::arg("missed_heartbeat_callback").none() = std::nullopt,
|
||||
nb::arg("shutdown_on_destruction").none() = std::nullopt,
|
||||
nb::arg("use_compression").none() = std::nullopt);
|
||||
|
||||
m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });
|
||||
|
||||
m.def("is_optimized_build", &IsOptimizedBuild);
|
||||
|
||||
m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile),
|
||||
"Encodes the JSON representation of a pprof Profile into its binary "
|
||||
"protocol buffer encoding.");
|
||||
m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson),
|
||||
"Decodes an uncompressed pprof Profile protocol buffer into a JSON "
|
||||
"representation");
|
||||
|
||||
RegisterCompileOnlyClient(m);
|
||||
nb::class_<ifrt::Topology>(m, "DeviceTopology")
|
||||
.def("_make_compile_only_devices",
|
||||
[](std::shared_ptr<ifrt::Topology> topology) {
|
||||
if (!llvm::isa<ifrt::PjRtTopology>(*topology)) {
|
||||
throw xla::XlaRuntimeError("Only PjRtTopologies are supported.");
|
||||
}
|
||||
return MakeCompileOnlyClient(
|
||||
std::dynamic_pointer_cast<ifrt::PjRtTopology>(topology))
|
||||
->Devices();
|
||||
})
|
||||
.def_prop_ro(
|
||||
"platform",
|
||||
[](ifrt::Topology& topology) { return topology.platform_name(); })
|
||||
.def_prop_ro(
|
||||
"platform_version",
|
||||
[](ifrt::Topology& topology) { return topology.platform_version(); })
|
||||
.def("serialize",
|
||||
[](ifrt::Topology& topology) -> nb::bytes {
|
||||
std::string serialized = ValueOrThrow(topology.Serialize());
|
||||
return nb::bytes(serialized.data(), serialized.size());
|
||||
})
|
||||
.def("__getattr__",
|
||||
[](ifrt::Topology& topology, absl::string_view name) -> nb::object {
|
||||
const auto& attrs = topology.Attributes().map();
|
||||
auto it = attrs.find(name);
|
||||
if (it != attrs.end()) {
|
||||
return std::visit([](auto&& v) { return nb::cast(v.value); },
|
||||
it->second);
|
||||
}
|
||||
throw nb::attribute_error(
|
||||
absl::StrCat("Unknown attribute ", name).c_str());
|
||||
});
|
||||
|
||||
nb::class_<ifrt::Executable>(m, "Executable")
|
||||
.def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules))
|
||||
.def("get_output_memory_kinds",
|
||||
xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds))
|
||||
.def("get_output_shardings", &ifrt::Executable::GetOutputShardings)
|
||||
.def("get_parameter_layouts",
|
||||
ValueOrThrowWrapper(&ifrt::Executable::GetParameterLayouts))
|
||||
.def("get_output_layouts",
|
||||
xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputLayouts))
|
||||
.def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings)
|
||||
.def("get_compiled_memory_stats",
|
||||
xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats))
|
||||
.def("serialize",
|
||||
[](const ifrt::Executable& exec) -> nb::bytes {
|
||||
std::string serialized = ValueOrThrow(exec.Serialize());
|
||||
return nb::bytes(serialized.data(), serialized.size());
|
||||
})
|
||||
.def("cost_analysis", [](const ifrt::Executable& exec) {
|
||||
auto attrs = ValueOrThrow(exec.GetCostAnalysis());
|
||||
return ifrt::ToPjRtAttributeMap(std::move(attrs));
|
||||
});
|
||||
|
||||
m.def("is_asan", IsAsan);
|
||||
m.def("is_msan", IsMsan);
|
||||
m.def("is_tsan", IsTsan);
|
||||
m.def("is_sanitized", IsSanitized);
|
||||
|
||||
m.def(
|
||||
"batched_device_put",
|
||||
[](nb::object aval, nb::object sharding, std::vector<nb::object> xs,
|
||||
std::vector<const PyDevice*> dst_devices, bool committed,
|
||||
bool force_copy,
|
||||
PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object {
|
||||
return ValueOrThrow(PyArray::BatchedDevicePut(
|
||||
aval, sharding, std::move(xs), std::move(dst_devices), committed,
|
||||
force_copy, host_buffer_semantics, jax::GetEnableX64()));
|
||||
},
|
||||
nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"),
|
||||
nb::arg("committed") = true, nb::arg("force_copy") = false,
|
||||
nb::arg("host_buffer_semantics") =
|
||||
PjRtClient::HostBufferSemantics::kImmutableZeroCopy);
|
||||
m.def(
|
||||
"reorder_shards",
|
||||
[](PyArray x, nb::object dst_sharding,
|
||||
ifrt::ArrayCopySemantics array_copy_semantics) {
|
||||
return ValueOrThrow(PyArray::ReorderShards(
|
||||
std::move(x), std::move(dst_sharding), array_copy_semantics));
|
||||
},
|
||||
nb::arg("x"), nb::arg("dst_sharding"), nb::arg("array_copy_semantics"));
|
||||
|
||||
m.def("batched_block_until_ready", [](std::vector<nb::object> xs) {
|
||||
ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs)));
|
||||
});
|
||||
|
||||
m.def("check_and_canonicalize_memory_kind",
|
||||
&jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(),
|
||||
nb::arg("device_list"));
|
||||
} // NOLINT(readability/fn_size)
|
||||
|
||||
} // namespace xla
|
@ -19,7 +19,7 @@ from __future__ import annotations
|
||||
import atexit
|
||||
from collections.abc import Mapping, Sequence
|
||||
import contextlib
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
import enum
|
||||
import gzip
|
||||
import inspect
|
||||
import logging
|
||||
|
1059
jaxlib/xla/xla_extension/__init__.pyi
Normal file
1059
jaxlib/xla/xla_extension/__init__.pyi
Normal file
File diff suppressed because it is too large
Load Diff
32
jaxlib/xla/xla_extension/config.pyi
Normal file
32
jaxlib/xla/xla_extension/config.pyi
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2024 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
unset: object
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
class Config(Generic[_T]):
|
||||
def __init__(self, value: _T, include_in_jit_key: bool = False): ...
|
||||
|
||||
@property
|
||||
def value(self) -> _T: ...
|
||||
|
||||
def get_local(self) -> Any: ...
|
||||
def get_global(self) -> _T: ...
|
||||
def set_local(self, value: Any) -> None: ...
|
||||
def swap_local(self, value: Any) -> Any: ...
|
||||
def set_global(self, value: _T) -> None: ...
|
46
jaxlib/xla/xla_extension/guard_lib.pyi
Normal file
46
jaxlib/xla/xla_extension/guard_lib.pyi
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2024 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
class TransferGuardLevel:
|
||||
ALLOW: Any
|
||||
LOG: Any
|
||||
DISALLOW: Any
|
||||
LOG_EXPLICIT: Any
|
||||
DISALLOW_EXPLICIT: Any
|
||||
|
||||
class GarbageCollectionGuardLevel:
|
||||
ALLOW: Any
|
||||
LOG: Any
|
||||
FATAL: Any
|
||||
|
||||
class GuardState:
|
||||
host_to_device: Optional[TransferGuardLevel]
|
||||
device_to_device: Optional[TransferGuardLevel]
|
||||
device_to_host: Optional[TransferGuardLevel]
|
||||
|
||||
explicit_device_put: bool
|
||||
explicit_device_get: bool
|
||||
|
||||
garbage_collect_array: Optional[GarbageCollectionGuardLevel]
|
||||
|
||||
def global_state() -> GuardState: ...
|
||||
def thread_local_state() -> GuardState: ...
|
||||
|
||||
class _TestingScopedLogSink:
|
||||
def __enter__(self) -> _TestingScopedLogSink: ...
|
||||
def __exit__(self, *args, **kwargs) -> None: ...
|
||||
def logs(self) -> List[str]: ...
|
43
jaxlib/xla/xla_extension/ifrt_programs.pyi
Normal file
43
jaxlib/xla/xla_extension/ifrt_programs.pyi
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2024 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Any, Sequence, Union
|
||||
|
||||
from jax.jaxlib.xla import xla_extension
|
||||
|
||||
class Program: ...
|
||||
|
||||
class CompileOptions: ...
|
||||
|
||||
def make_hlo_program(mlir_module: Union[str, bytes]) -> Program: ...
|
||||
|
||||
def make_colocated_python_program(
|
||||
name : str,
|
||||
picked_function: bytes,
|
||||
devices: Sequence[xla_extension.Device] | xla_extension.DeviceList,
|
||||
input_avals: Sequence[Any],
|
||||
output_avals: Sequence[Any],
|
||||
) -> Program: ...
|
||||
|
||||
def make_plugin_program(data: Union[str, bytes]) -> Program: ...
|
||||
|
||||
def make_colocated_python_compile_options() -> CompileOptions: ...
|
||||
|
||||
def make_xla_compile_options(
|
||||
compile_options: xla_extension.CompileOptions,
|
||||
host_callbacks: Sequence[Any]
|
||||
) -> CompileOptions: ...
|
||||
|
||||
def make_plugin_compile_options() -> CompileOptions: ...
|
33
jaxlib/xla/xla_extension/ifrt_proxy.pyi
Normal file
33
jaxlib/xla/xla_extension/ifrt_proxy.pyi
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright 2024 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Any, Optional, Callable
|
||||
|
||||
from jax.jaxlib.xla import xla_extension
|
||||
|
||||
_Status = Any
|
||||
Client = xla_extension.Client
|
||||
|
||||
|
||||
class ClientConnectionOptions:
|
||||
on_disconnect: Optional[Callable[[_Status], None]] = None
|
||||
on_connection_update: Optional[Callable[[str], None]] = None
|
||||
connection_timeout_in_seconds: Optional[int] = None
|
||||
|
||||
|
||||
def get_client(
|
||||
proxy_server_address: str,
|
||||
options: ClientConnectionOptions
|
||||
) -> Client: ...
|
76
jaxlib/xla/xla_extension/jax_jit.pyi
Normal file
76
jaxlib/xla/xla_extension/jax_jit.pyi
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
from jax.jaxlib.xla import xla_extension
|
||||
|
||||
from . import pytree
|
||||
|
||||
Client = xla_extension.Client
|
||||
Device = xla_extension.Device
|
||||
|
||||
|
||||
class JitState:
|
||||
disable_jit: Optional[bool]
|
||||
enable_x64: Optional[bool]
|
||||
default_device: Optional[Any]
|
||||
extra_jit_context: Optional[Any]
|
||||
post_hook: Optional[Callable[..., Any]]
|
||||
|
||||
def global_state() -> JitState: ...
|
||||
def thread_local_state() -> JitState: ...
|
||||
|
||||
def get_enable_x64() -> bool: ...
|
||||
def set_thread_local_state_initialization_callback(
|
||||
function: Callable[[], None]): ...
|
||||
|
||||
def swap_thread_local_state_disable_jit(
|
||||
value: Optional[bool]) -> Optional[bool]: ...
|
||||
|
||||
class ArgSignature:
|
||||
dtype: np.dtype
|
||||
shape: Tuple[int, ...]
|
||||
weak_type: bool
|
||||
|
||||
def _ArgSignatureOfValue(
|
||||
__arg: Any,
|
||||
__jax_enable_x64: bool) -> ArgSignature: ...
|
||||
|
||||
def _is_float0(__arg: Any) -> bool: ...
|
||||
|
||||
|
||||
class ArgumentSignature:
|
||||
static_args: Sequence[Any]
|
||||
static_arg_names: Sequence[str]
|
||||
dynamic_arg_names: Sequence[str]
|
||||
dynamic_arg_treedefs: Sequence[pytree.PyTreeDef]
|
||||
|
||||
def __eq__(self, value, /): ...
|
||||
def __ne__(self, value, /): ...
|
||||
def __hash__(self, /): ...
|
||||
def __str__(self): ...
|
||||
def __repr__(self): ...
|
||||
|
||||
|
||||
def parse_arguments(
|
||||
positional_args: Sequence[Any],
|
||||
keyword_args: Sequence[Any],
|
||||
kwnames: Tuple[str, ...],
|
||||
static_argnums: Sequence[int],
|
||||
static_argnames: Sequence[str],
|
||||
pytree_registry: pytree.PyTreeRegistry,
|
||||
) -> tuple[ArgumentSignature, Sequence[Any]]: ...
|
34
jaxlib/xla/xla_extension/mlir.pyi
Normal file
34
jaxlib/xla/xla_extension/mlir.pyi
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Union
|
||||
from . import XlaComputation
|
||||
|
||||
def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ...
|
||||
def mlir_module_to_xla_computation(
|
||||
mlir_module: Union[bytes, str],
|
||||
use_tuple_args: bool = ...,
|
||||
return_tuple: bool = ...,
|
||||
) -> XlaComputation: ...
|
||||
def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> bytes: ...
|
||||
def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> bytes: ...
|
||||
def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ...
|
||||
def deserialize_portable_artifact(mlir_module: bytes) -> str: ...
|
||||
def refine_polymorphic_shapes(
|
||||
mlir_module: Union[bytes, str],
|
||||
enable_shape_assertions: bool = ...,
|
||||
validate_static_shapes: bool = ...,
|
||||
enable_shardy: bool = ...,
|
||||
) -> bytes: ...
|
465
jaxlib/xla/xla_extension/ops.pyi
Normal file
465
jaxlib/xla/xla_extension/ops.pyi
Normal file
@ -0,0 +1,465 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import enum
|
||||
from typing import Any, Optional, Sequence, overload
|
||||
|
||||
from jax.jaxlib.xla import xla_extension
|
||||
|
||||
FftType = xla_extension.FftType
|
||||
XlaBuilder = xla_extension.XlaBuilder
|
||||
XlaComputation = xla_extension.XlaComputation
|
||||
XlaOp = xla_extension.XlaOp
|
||||
PrecisionConfig_Precision = xla_extension.PrecisionConfig_Precision
|
||||
PrimitiveType = xla_extension.PrimitiveType
|
||||
Shape = xla_extension.Shape
|
||||
ShapeIndex = xla_extension.ShapeIndex
|
||||
ResultAccuracy = xla_extension.ResultAccuracy
|
||||
|
||||
_ChannelHandle = Any
|
||||
_ConvDimensionNumbers = Any
|
||||
_DotDimensionNumbers = Any
|
||||
_Layout = Any
|
||||
_LiteralSlice = Any
|
||||
_GatherDimensionNumbers = Any
|
||||
_PaddingConfig = Any
|
||||
_ReplicaGroup = Any
|
||||
_ScatterDimensionNumbers = Any
|
||||
|
||||
class TriangularSolveOptions_Transpose(enum.IntEnum):
|
||||
TRANSPOSE_INVALID: int
|
||||
NO_TRANSPOSE: int
|
||||
TRANSPOSE: int
|
||||
ADJOINT: int
|
||||
|
||||
class RandomAlgorithm(enum.IntEnum):
|
||||
RNG_DEFAULT: int
|
||||
RNG_THREE_FRY: int
|
||||
RNG_PHILOX: int
|
||||
|
||||
class CustomCallSchedule(enum.IntEnum):
|
||||
SCHEDULE_NONE: int
|
||||
SCHEDULE_LATEST: int
|
||||
SCHEDULE_EARLIEST: int
|
||||
|
||||
# TODO(b/189822916): Remove this enum when all clients are migrated to the
|
||||
# status-returning API.
|
||||
class CustomCallApiVersion(enum.IntEnum):
|
||||
API_VERSION_ORIGINAL: int
|
||||
API_VERSION_STATUS_RETURNING: int
|
||||
API_VERSION_STATUS_RETURNING_UNIFIED: int
|
||||
API_VERSION_TYPED_FFI: int
|
||||
|
||||
def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ...
|
||||
def AllGather(
|
||||
operand: XlaOp,
|
||||
all_gather_dimension: int,
|
||||
shard_count: int,
|
||||
replica_groups: Sequence[_ReplicaGroup] = ...,
|
||||
channel_id: Optional[_ChannelHandle] = ...,
|
||||
shape_with_layout: Optional[_Layout] = ...,
|
||||
use_global_device_ids: Optional[bool] = ...) -> XlaOp: ...
|
||||
def AllReduce(
|
||||
operand: XlaOp,
|
||||
computation: XlaComputation,
|
||||
replica_groups: Sequence[_ReplicaGroup] = ...,
|
||||
channel_id: Optional[_ChannelHandle] = ...,
|
||||
shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ...
|
||||
def ApproxTopK(
|
||||
builder: XlaBuilder,
|
||||
operands: Sequence[XlaOp],
|
||||
init_values: Sequence[XlaOp],
|
||||
top_k: int,
|
||||
reduction_dim: int,
|
||||
comparator: XlaComputation,
|
||||
recall_target: Optional[float],
|
||||
aggregate_to_topk: Optional[bool],
|
||||
reduction_input_size_override: Optional[int]) -> XlaOp: ...
|
||||
def ApproxTopKFallback(
|
||||
builder: XlaBuilder,
|
||||
operands: Sequence[XlaOp],
|
||||
init_values: Sequence[XlaOp],
|
||||
top_k: int,
|
||||
reduction_dim: int,
|
||||
comparator: XlaComputation,
|
||||
recall_target: Optional[float],
|
||||
aggregate_to_topk: Optional[bool],
|
||||
reduction_input_size_override: Optional[int]) -> XlaOp: ...
|
||||
def ApproxTopKReductionOutputSize(
|
||||
input_size: int,
|
||||
rank: int,
|
||||
top_k: int,
|
||||
recall_target: float,
|
||||
aggregate_to_topk: Optional[bool] = ...,
|
||||
input_size_override: Optional[int] = ...) -> tuple[int, int]: ...
|
||||
def ReduceScatter(
|
||||
operand: XlaOp,
|
||||
computation: XlaComputation,
|
||||
scatter_dimension: int,
|
||||
shard_count: int,
|
||||
replica_groups: Sequence[_ReplicaGroup] = ...,
|
||||
channel_id: Optional[_ChannelHandle] = ...,
|
||||
layout: Optional[_Layout] = ...,
|
||||
use_global_device_ids: Optional[bool] = ...) -> XlaOp: ...
|
||||
def AllToAll(
|
||||
operand: XlaOp,
|
||||
split_dimension: int,
|
||||
concat_dimension: int,
|
||||
split_count: int,
|
||||
replica_groups: Sequence[_ReplicaGroup] = ...,
|
||||
layout: Optional[_Layout] = ...,
|
||||
channel_id: Optional[_ChannelHandle] = ...) -> XlaOp: ...
|
||||
def BitcastConvertType(operand: XlaOp,
|
||||
new_element_type: PrimitiveType) -> XlaOp: ...
|
||||
def Broadcast(operand: XlaOp, sizes: Sequence[int]) -> XlaOp: ...
|
||||
def BroadcastInDim(operand: XlaOp,
|
||||
shape: Sequence[int],
|
||||
broadcast_dimensions: Sequence[int]) -> XlaOp: ...
|
||||
def Call(builder: XlaBuilder,
|
||||
computation: XlaComputation,
|
||||
operands: Sequence[XlaOp]) -> XlaOp: ...
|
||||
def Cholesky(a: XlaOp, lower: bool = ...) -> XlaOp: ...
|
||||
def Clamp(min: XlaOp, operand: XlaOp, max: XlaOp) -> XlaOp: ...
|
||||
def Collapse(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ...
|
||||
def CollectivePermute(
|
||||
operand: XlaOp,
|
||||
source_target_pairs: Sequence[tuple[int, int]],
|
||||
channel_id: Optional[_ChannelHandle] = ...,
|
||||
inplace: bool = ...) -> XlaOp: ...
|
||||
def ConcatInDim(builder: XlaBuilder,
|
||||
operands: Sequence[XlaOp],
|
||||
dimension: int) -> XlaOp: ...
|
||||
@overload
|
||||
def Conditional(branch_index: XlaOp,
|
||||
branch_computations: Sequence[XlaComputation],
|
||||
branch_operands: Sequence[XlaOp]) -> XlaOp: ...
|
||||
@overload
|
||||
def Conditional(
|
||||
predicate: XlaOp,
|
||||
true_operand: XlaOp,
|
||||
true_computation: XlaComputation,
|
||||
false_operand: XlaOp,
|
||||
false_computation: XlaComputation) -> XlaOp: ...
|
||||
|
||||
def Constant(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ...
|
||||
def ConstantLiteral(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ...
|
||||
def ConvGeneralDilated(
|
||||
lhs: XlaOp,
|
||||
rhs: XlaOp,
|
||||
window_strides: Sequence[int],
|
||||
padding: Sequence[tuple[int, int]],
|
||||
lhs_dilation: Sequence[int],
|
||||
rhs_dilation: Sequence[int],
|
||||
dimension_numbers: _ConvDimensionNumbers,
|
||||
feature_group_count: int = ...,
|
||||
batch_group_count: int = ...,
|
||||
precision_config: Optional[PrecisionConfig_Precision] = ...,
|
||||
preferred_element_type: Optional[PrimitiveType] = ...,
|
||||
window_reversal: Optional[Sequence[bool]] = ...) -> XlaOp: ...
|
||||
def ConvertElementType(
|
||||
operand: XlaOp,
|
||||
new_element_type: PrimitiveType) -> XlaOp: ...
|
||||
def CreateToken(builder: XlaBuilder) -> XlaOp: ...
|
||||
def CrossReplicaSum(
|
||||
operand: XlaOp,
|
||||
replica_groups: Sequence[_ReplicaGroup] = ...) -> XlaOp: ...
|
||||
def CustomCall(
|
||||
builder: XlaBuilder,
|
||||
call_target_name: bytes,
|
||||
operands: Sequence[XlaOp],
|
||||
shape: Shape,
|
||||
opaque: bytes = ...,
|
||||
has_side_effect: bool = ...,
|
||||
schedule: CustomCallSchedule = ...,
|
||||
api_version: CustomCallApiVersion = ...) -> XlaOp: ...
|
||||
def CustomCallWithLayout(
|
||||
builder: XlaBuilder,
|
||||
call_target_name: bytes,
|
||||
operands: Sequence[XlaOp],
|
||||
shape_with_layout: Shape,
|
||||
operand_shapes_with_layout: Sequence[Shape],
|
||||
opaque: bytes = ...,
|
||||
has_side_effect: bool = ...,
|
||||
schedule: CustomCallSchedule = ...,
|
||||
api_version: CustomCallApiVersion = ...) -> XlaOp: ...
|
||||
def CustomCallWithAliasing(
|
||||
builder: XlaBuilder,
|
||||
call_target_name: bytes,
|
||||
operands: Sequence[XlaOp],
|
||||
shape_with_layout: Shape,
|
||||
operand_shapes_with_layout: Sequence[Shape],
|
||||
opaque: bytes = ...,
|
||||
has_side_effect: bool = ...,
|
||||
output_operand_aliasing: Sequence[tuple[ShapeIndex, tuple[int, ShapeIndex]]] = ...,
|
||||
literal: _LiteralSlice = ...,
|
||||
schedule: CustomCallSchedule = ...,
|
||||
api_version: CustomCallApiVersion = ...) -> XlaOp: ...
|
||||
def Dot(
|
||||
lhs: XlaOp,
|
||||
rhs: XlaOp,
|
||||
precision_config: Optional[PrecisionConfig_Precision] = ...,
|
||||
preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ...
|
||||
def DotGeneral(
|
||||
lhs: XlaOp,
|
||||
rhs: XlaOp,
|
||||
dimensions_numbers: _DotDimensionNumbers,
|
||||
precision_config: Optional[PrecisionConfig_Precision] = ...,
|
||||
preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ...
|
||||
def DynamicReshape(
|
||||
operand: XlaOp,
|
||||
dim_sizes: Sequence[XlaOp],
|
||||
new_size_bounds: Sequence[int],
|
||||
dims_are_dynamic: Sequence[bool]) -> XlaOp: ...
|
||||
def DynamicSlice(
|
||||
operand: XlaOp,
|
||||
start_indices: Sequence[XlaOp],
|
||||
slice_sizes: Sequence[int]) -> XlaOp: ...
|
||||
def DynamicUpdateSlice(
|
||||
operand: XlaOp,
|
||||
update: XlaOp,
|
||||
start_indices: Sequence[XlaOp]) -> XlaOp: ...
|
||||
def Eigh(
|
||||
a: XlaOp,
|
||||
lower: bool = ...,
|
||||
max_iter: int = ...,
|
||||
epsilon: float = ...,
|
||||
sort_eigenvalues: bool = ...) -> tuple[XlaOp, XlaOp]: ...
|
||||
def Fft(
|
||||
operand: XlaOp,
|
||||
fft_type: FftType,
|
||||
fft_length: Sequence[int]) -> XlaOp: ...
|
||||
def Gather(
|
||||
a: XlaOp,
|
||||
start_indices: XlaOp,
|
||||
dimension_numbers: _GatherDimensionNumbers,
|
||||
slice_sizes: Sequence[int],
|
||||
indices_are_sorted: bool = ...) -> XlaOp: ...
|
||||
def GetDimensionSize(operand: XlaOp, index: int) -> XlaOp: ...
|
||||
def GetTupleElement(tuple_data: XlaOp, index: int) -> XlaOp: ...
|
||||
def InfeedWithToken(
|
||||
token: XlaOp,
|
||||
shape: Shape,
|
||||
config: Optional[str] = ...) -> XlaOp: ...
|
||||
@overload
|
||||
def Iota(builder: XlaBuilder, shape: Shape, iota_dimension: int) -> XlaOp: ...
|
||||
@overload
|
||||
def Iota(builder: XlaBuilder, type: PrimitiveType, size: int) -> XlaOp: ...
|
||||
def LU(a: XlaOp) -> tuple[XlaOp, XlaOp, XlaOp]: ...
|
||||
def Map(
|
||||
builder: XlaBuilder,
|
||||
operands: Sequence[XlaOp],
|
||||
computation: XlaComputation,
|
||||
dimensions: Sequence[int],
|
||||
static_operands: Sequence[XlaOp] = ...) -> XlaOp: ...
|
||||
def MultiCollectivePermute(
|
||||
operands: Sequence[XlaOp],
|
||||
source_target_pairs: Sequence[tuple[int, int]],
|
||||
channel_id: Optional[_ChannelHandle] = ...,
|
||||
inplace: bool = ...) -> XlaOp: ...
|
||||
def NextAfter(__from: XlaOp, to: XlaOp) -> XlaOp: ...
|
||||
def OutfeedWithToken(
|
||||
operand: XlaOp,
|
||||
token: XlaOp,
|
||||
shape_with_layout: Shape,
|
||||
outfeed_config: Optional[str] = ...) -> XlaOp: ...
|
||||
def Pad(
|
||||
operand: XlaOp,
|
||||
padding_value: XlaOp,
|
||||
padding_config: _PaddingConfig) -> XlaOp: ...
|
||||
def Parameter(
|
||||
builder: XlaBuilder,
|
||||
parameter_number: int,
|
||||
shape: Shape,
|
||||
name: str = ...,
|
||||
replicated_at_leaf_buffers: Sequence[bool] = ...) -> XlaOp: ...
|
||||
def ProductOfElementaryHouseholderReflectors(a: XlaOp, taus: XlaOp) -> XlaOp: ...
|
||||
def QR(a: XlaOp, full_matrices: bool) -> tuple[XlaOp, XlaOp]: ...
|
||||
def QrDecomposition(a: XlaOp) -> tuple[XlaOp, XlaOp]: ...
|
||||
def Reduce(
|
||||
builder: XlaBuilder,
|
||||
operands: Sequence[XlaOp],
|
||||
init_values: Sequence[XlaOp],
|
||||
computation: XlaComputation,
|
||||
dimensions_to_reduce: Sequence[int]) -> XlaOp: ...
|
||||
def ReducePrecision(
|
||||
operand: XlaOp,
|
||||
exponent_bits: int,
|
||||
mantissa_bits: int) -> XlaOp: ...
|
||||
@overload
|
||||
def ReduceWindowWithGeneralPadding(
|
||||
operand: XlaOp,
|
||||
init_value: XlaOp,
|
||||
computation: XlaComputation,
|
||||
window_dimensions: Sequence[int],
|
||||
window_strides: Sequence[int],
|
||||
base_dilations: Sequence[int],
|
||||
window_dilations: Sequence[int],
|
||||
padding: Sequence[tuple[int, int]]) -> XlaOp: ...
|
||||
@overload
|
||||
def ReduceWindowWithGeneralPadding(
|
||||
operands: Sequence[XlaOp],
|
||||
init_values: Sequence[XlaOp],
|
||||
computation: XlaComputation,
|
||||
window_dimensions: Sequence[int],
|
||||
window_strides: Sequence[int],
|
||||
base_dilations: Sequence[int],
|
||||
window_dilations: Sequence[int],
|
||||
padding: Sequence[tuple[int, int]]) -> XlaOp: ...
|
||||
def ReplicaId(builder: XlaBuilder) -> XlaOp: ...
|
||||
def Reshape(operand: XlaOp, new_sizes: Sequence[int]) -> XlaOp: ...
|
||||
def Rev(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ...
|
||||
def RngBitGenerator(
|
||||
algorithm: RandomAlgorithm,
|
||||
initial_state: XlaOp,
|
||||
shape: Shape) -> XlaOp: ...
|
||||
def RngNormal(mu: XlaOp, sigma: XlaOp, shape: Shape) -> XlaOp: ...
|
||||
def RngUniform(a: XlaOp, b: XlaOp, shape: Shape) -> XlaOp: ...
|
||||
@overload
|
||||
def Scatter(
|
||||
input: XlaOp,
|
||||
scatter_indices: XlaOp,
|
||||
updates: XlaOp,
|
||||
update_computation: XlaComputation,
|
||||
dimension_numbers: _ScatterDimensionNumbers,
|
||||
indices_are_sorted: bool = ...,
|
||||
unique_indices: bool = ...) -> XlaOp: ...
|
||||
@overload
|
||||
def Scatter(
|
||||
inputs: Sequence[XlaOp],
|
||||
scatter_indices: XlaOp,
|
||||
updates: Sequence[XlaOp],
|
||||
update_computation: XlaComputation,
|
||||
dimension_numbers: _ScatterDimensionNumbers,
|
||||
indices_are_sorted: bool = ...,
|
||||
unique_indices: bool = ...) -> XlaOp: ...
|
||||
def Select(pred: XlaOp, on_true: XlaOp, on_false: XlaOp) -> XlaOp: ...
|
||||
def SelectAndScatterWithGeneralPadding(
|
||||
operand: XlaOp,
|
||||
select: XlaComputation,
|
||||
window_dimensions: Sequence[int],
|
||||
window_strides: Sequence[int],
|
||||
padding: Sequence[tuple[int, int]],
|
||||
source: XlaOp,
|
||||
init_value: XlaOp,
|
||||
scatter: XlaComputation) -> XlaOp: ...
|
||||
def Slice(
|
||||
operand: XlaOp,
|
||||
start_indices: Sequence[int],
|
||||
limit_indices: Sequence[int],
|
||||
strides: Sequence[int]) -> XlaOp: ...
|
||||
def SliceInDim(
|
||||
operand: XlaOp,
|
||||
start_index: int,
|
||||
limit_index: int,
|
||||
stride: int,
|
||||
dimno: int) -> XlaOp: ...
|
||||
def Sort(
|
||||
builder: XlaBuilder,
|
||||
operands: Sequence[XlaOp],
|
||||
comparator: Optional[XlaComputation] = ...,
|
||||
dimension: int = ...,
|
||||
is_stable: bool = ...) -> XlaOp: ...
|
||||
def SVD(
|
||||
a: XlaOp,
|
||||
max_iter: int = ...,
|
||||
epsilon: float = ...) -> tuple[XlaOp, XlaOp, XlaOp]: ...
|
||||
def TopK(input: XlaOp, k: int) -> XlaOp: ...
|
||||
def Transpose(operand: XlaOp, permutation: Sequence[int]) -> XlaOp: ...
|
||||
def TriangularSolve(
|
||||
a: XlaOp,
|
||||
b: XlaOp,
|
||||
left_side: bool,
|
||||
lower: bool,
|
||||
unit_diagonal: bool,
|
||||
transpose_a: TriangularSolveOptions_Transpose) -> XlaOp: ...
|
||||
def Tuple(builder: XlaBuilder, elements: Sequence[XlaOp]) -> XlaOp: ...
|
||||
def While(
|
||||
condition: XlaComputation,
|
||||
body: XlaComputation,
|
||||
init: XlaOp) -> XlaOp: ...
|
||||
|
||||
|
||||
def Igamma(a: XlaOp, x: XlaOp) -> XlaOp: ...
|
||||
def Igammac(a: XlaOp, x: XlaOp) -> XlaOp: ...
|
||||
def IgammaGradA(a: XlaOp, x: XlaOp) -> XlaOp: ...
|
||||
def RandomGammaGrad(a: XlaOp, x: XlaOp) -> XlaOp: ...
|
||||
def RegularizedIncompleteBeta(a: XlaOp, b: XlaOp, x: XlaOp) -> XlaOp: ...
|
||||
def Zeta(a: XlaOp, q: XlaOp) -> XlaOp: ...
|
||||
|
||||
def Eq(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Ne(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Ge(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Gt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Lt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Le(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Add(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Sub(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Mul(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Div(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Rem(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Max(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Min(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def And(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Or(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Xor(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def ShiftLeft(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def ShiftRightArithmetic(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def ShiftRightLogical(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Atan2(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Pow(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
def Complex(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ...
|
||||
|
||||
def Not(__arg: XlaOp) -> XlaOp: ...
|
||||
def PopulationCount(__arg: XlaOp) -> XlaOp: ...
|
||||
def Clz(__arg: XlaOp) -> XlaOp: ...
|
||||
def Abs(__arg: XlaOp) -> XlaOp: ...
|
||||
def Exp(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Expm1(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Floor(__arg: XlaOp) -> XlaOp: ...
|
||||
def Ceil(__arg: XlaOp) -> XlaOp: ...
|
||||
def Round(__arg: XlaOp) -> XlaOp: ...
|
||||
def Log(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Log1p(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Sign(__arg: XlaOp) -> XlaOp: ...
|
||||
def Cos(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def OptimizationBarrier(__arg: XlaOp) -> XlaOp: ...
|
||||
def Sin(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Tan(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Tanh(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def IsFinite(__arg: XlaOp) -> XlaOp: ...
|
||||
def Neg(__arg: XlaOp) -> XlaOp: ...
|
||||
def Sqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Rsqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Cbrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def Square(__arg: XlaOp) -> XlaOp: ...
|
||||
def Reciprocal(__arg: XlaOp) -> XlaOp: ...
|
||||
def Erfc(__arg: XlaOp) -> XlaOp: ...
|
||||
def Erf(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ...
|
||||
def ErfInv(__arg: XlaOp) -> XlaOp: ...
|
||||
def Lgamma(__arg: XlaOp) -> XlaOp: ...
|
||||
def Digamma(__arg: XlaOp) -> XlaOp: ...
|
||||
def BesselI0e(__arg: XlaOp) -> XlaOp: ...
|
||||
def BesselI1e(__arg: XlaOp) -> XlaOp: ...
|
||||
def Acos(__arg: XlaOp) -> XlaOp: ...
|
||||
def Asin(__arg: XlaOp) -> XlaOp: ...
|
||||
def Atan(__arg: XlaOp) -> XlaOp: ...
|
||||
def Acosh(__arg: XlaOp) -> XlaOp: ...
|
||||
def Asinh(__arg: XlaOp) -> XlaOp: ...
|
||||
def Atanh(__arg: XlaOp) -> XlaOp: ...
|
||||
def Cosh(__arg: XlaOp) -> XlaOp: ...
|
||||
def Sinh(__arg: XlaOp) -> XlaOp: ...
|
||||
def Real(__arg: XlaOp) -> XlaOp: ...
|
||||
def Imag(__arg: XlaOp) -> XlaOp: ...
|
||||
def Conj(__arg: XlaOp) -> XlaOp: ...
|
83
jaxlib/xla/xla_extension/pmap_lib.pyi
Normal file
83
jaxlib/xla/xla_extension/pmap_lib.pyi
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Sequence, Iterable, Tuple
|
||||
|
||||
from . import pytree
|
||||
|
||||
_AvalDimSharding = Any
|
||||
_MeshDimAssignment = Any
|
||||
|
||||
class NoSharding:
|
||||
def __init__(self) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, __other: Any) -> bool: ...
|
||||
|
||||
class Chunked:
|
||||
@property
|
||||
def chunks(self) -> Sequence[int]: ...
|
||||
def __init__(self, __chunks: Sequence[int]) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, __other: Any) -> bool: ...
|
||||
|
||||
class Unstacked:
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
def __init__(self, __sz: int) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, __other: Any) -> bool: ...
|
||||
|
||||
class ShardedAxis:
|
||||
@property
|
||||
def axis(self) -> int: ...
|
||||
def __init__(self, __axis: int) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, __other: ShardedAxis) -> bool: ...
|
||||
|
||||
class Replicated:
|
||||
@property
|
||||
def replicas(self) -> int: ...
|
||||
def __init__(self, __replicas: int) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, __other: Replicated) -> bool: ...
|
||||
|
||||
class ShardingSpec:
|
||||
def __init__(self,
|
||||
sharding: Iterable[_AvalDimSharding],
|
||||
mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ...
|
||||
@property
|
||||
def sharding(self) -> Tuple[_AvalDimSharding, ...]: ...
|
||||
@property
|
||||
def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ...
|
||||
def __eq__(self, __other: ShardingSpec) -> bool: ...
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
|
||||
class PmapFunction:
|
||||
def __call__(self, *args, **kwargs) -> Any: ...
|
||||
def __getstate__(self) -> Any: ...
|
||||
def __setstate__(self, Any): ...
|
||||
__signature__: inspect.Signature
|
||||
def _cache_size(self) -> int: ...
|
||||
def _cache_clear(self) -> None: ...
|
||||
def _debug_cache_keys(self) -> str: ...
|
||||
|
||||
def pmap(fun: Callable[..., Any],
|
||||
cache_miss: Callable[..., Any],
|
||||
static_argnums: Sequence[int],
|
||||
shard_arg_fallback: Callable[..., Any],
|
||||
pytree_registry: pytree.PyTreeRegistry) -> PmapFunction: ...
|
58
jaxlib/xla/xla_extension/profiler.pyi
Normal file
58
jaxlib/xla/xla_extension/profiler.pyi
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, Type, Union, List, Tuple
|
||||
|
||||
_Status = Any
|
||||
|
||||
class ProfilerServer: ...
|
||||
def start_server(port: int) -> ProfilerServer: ...
|
||||
|
||||
def register_plugin_profiler(c_api: Any) -> None: ...
|
||||
|
||||
def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ...
|
||||
def get_instructins_profile(tensorboard_dir: str) -> List[Tuple[str, float]]: ...
|
||||
def get_fdo_profile(
|
||||
xspace: bytes, as_textproto: bool = ...
|
||||
) -> Union[bytes, str]: ...
|
||||
|
||||
class ProfilerSession:
|
||||
def __init__(self, options: Optional[ProfileOptions] = ...) -> None: ...
|
||||
def stop(self) -> bytes: ...
|
||||
def export(self, xspace: bytes, tensorboard_dir: str) -> _Status:...
|
||||
|
||||
class ProfileOptions:
|
||||
include_dataset_ops: bool
|
||||
host_tracer_level: int
|
||||
python_tracer_level: int
|
||||
enable_hlo_proto: bool
|
||||
start_timestamp_ns: int
|
||||
duration_ms: int
|
||||
repository_path: str
|
||||
|
||||
def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ...
|
||||
|
||||
class TraceMe:
|
||||
def __init__(self, name: str, **kwargs: Any) -> None: ...
|
||||
def __enter__(self) -> TraceMe: ...
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType]) -> Optional[bool]:...
|
||||
def set_metadata(self, **kwargs): ...
|
||||
@staticmethod
|
||||
def is_enabled() -> bool: ...
|
158
jaxlib/xla/xla_extension/pytree.pyi
Normal file
158
jaxlib/xla/xla_extension/pytree.pyi
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Hashable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
version: int
|
||||
|
||||
class PyTreeRegistry:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
enable_none: bool = ...,
|
||||
enable_tuple: bool = ...,
|
||||
enable_namedtuple: bool = ...,
|
||||
enable_list: bool = ...,
|
||||
enable_dict: bool = ...
|
||||
): ...
|
||||
def flatten(
|
||||
self,
|
||||
tree: Any,
|
||||
leaf_predicate: Optional[Callable[[Any], bool]] = ...,
|
||||
) -> Tuple[List[Any], PyTreeDef]: ...
|
||||
def flatten_one_level(
|
||||
self, tree: Any
|
||||
) -> Optional[Tuple[Iterable[Any], Any]]: ...
|
||||
def flatten_one_level_with_keys(
|
||||
self, tree: Any
|
||||
) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ...
|
||||
def flatten_with_path(
|
||||
self,
|
||||
tree: Any,
|
||||
leaf_predicate: Optional[Callable[[Any], bool]] = ...,
|
||||
) -> Tuple[List[Tuple[_KeyPath, Any]], PyTreeDef]: ...
|
||||
def register_node(
|
||||
self,
|
||||
__type: Type[_T],
|
||||
to_iterable: Callable[[_T], Tuple[_Children, _AuxData]],
|
||||
from_iterable: Callable[[_AuxData, _Children], _T],
|
||||
to_iterable_with_keys: (
|
||||
Callable[[_T], Tuple[_KeyLeafPairs, _AuxData]] | None
|
||||
) = ...,
|
||||
) -> Any: ...
|
||||
def register_dataclass_node(
|
||||
self, __type: Type[_T], meta_fields: List[str], data_fields: List[str]
|
||||
) -> Any: ...
|
||||
|
||||
def default_registry() -> PyTreeRegistry: ...
|
||||
def tuple(registry: PyTreeRegistry, arg0: Sequence[PyTreeDef]) -> PyTreeDef: ...
|
||||
def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ...
|
||||
|
||||
class SequenceKey(Hashable):
|
||||
idx: int
|
||||
__match_args__: tuple = ...
|
||||
def __init__(self, idx: int): ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __getstate__(self) -> Any: ...
|
||||
def __setstate__(self, state: Any): ...
|
||||
def __eq__(self, __other: Any) -> bool: ...
|
||||
|
||||
class DictKey(Hashable):
|
||||
key: Hashable
|
||||
__match_args__: tuple = ...
|
||||
def __init__(self, key: Hashable): ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __getstate__(self) -> Any: ...
|
||||
def __setstate__(self, state: Any): ...
|
||||
def __eq__(self, __other: Any) -> bool: ...
|
||||
|
||||
class GetAttrKey(Hashable):
|
||||
name: str
|
||||
__match_args__: tuple = ...
|
||||
def __init__(self, name: str): ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __getstate__(self) -> Any: ...
|
||||
def __setstate__(self, state: Any): ...
|
||||
def __eq__(self, __other: Any) -> bool: ...
|
||||
|
||||
class FlattenedIndexKey(Hashable):
|
||||
key: int
|
||||
__match_args__: tuple = ...
|
||||
def __init__(self, key: int): ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __getstate__(self) -> Any: ...
|
||||
def __setstate__(self, state: Any): ...
|
||||
def __eq__(self, __other: Any) -> bool: ...
|
||||
|
||||
class PyTreeDef:
|
||||
def unflatten(self, __leaves: Iterable[Any]) -> Any: ...
|
||||
def flatten_up_to(self, __xs: Any) -> List[Any]: ...
|
||||
def compose(self, __inner: PyTreeDef) -> PyTreeDef: ...
|
||||
def walk(
|
||||
self,
|
||||
__f_node: Callable[[Any, Any], Any],
|
||||
__f_leaf: Optional[Callable[[_T], Any]],
|
||||
leaves: Iterable[Any],
|
||||
) -> Any: ...
|
||||
def from_iterable_tree(self, __xs: Any): ...
|
||||
def node_data(self) -> Optional[Tuple[Type, Any]]: ...
|
||||
def children(self) -> List[PyTreeDef]: ...
|
||||
@staticmethod
|
||||
def make_from_node_data_and_children(
|
||||
registry: PyTreeRegistry,
|
||||
node_data: Optional[Tuple[Type, Any]],
|
||||
children: Iterable[PyTreeDef],
|
||||
) -> PyTreeDef: ...
|
||||
|
||||
num_leaves: int
|
||||
num_nodes: int
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, __other: PyTreeDef) -> bool: ...
|
||||
def __ne__(self, __other: PyTreeDef) -> bool: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __getstate__(self) -> Any: ...
|
||||
def __setstate__(self, state: Any): ...
|
||||
def serialize_using_proto(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def deserialize_using_proto(
|
||||
registry: PyTreeRegistry, data: bytes
|
||||
) -> PyTreeDef: ...
|
||||
|
||||
_Children = TypeVar("_Children", bound=Iterable[Any])
|
||||
_KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any])
|
||||
_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[_KeyLeafPair])
|
||||
_KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...])
|
||||
_AuxData = TypeVar("_AuxData", bound=Hashable)
|
32
jaxlib/xla/xla_extension/sdy.pyi
Normal file
32
jaxlib/xla/xla_extension/sdy.pyi
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from mlir import ir
|
||||
|
||||
def sdy_round_trip_export_pipeline(
|
||||
module: ir.module
|
||||
) -> str: ...
|
||||
|
||||
def sdy_round_trip_import_shardings(
|
||||
module: ir.module
|
||||
) -> str: ...
|
||||
|
||||
def get_mesh(
|
||||
module: ir.module
|
||||
) -> tuple[tuple[str, int], ...]: ...
|
||||
|
||||
def lowered_with_shardy(
|
||||
module: ir.module
|
||||
) -> bool: ...
|
39
jaxlib/xla/xla_extension/transfer_guard_lib.pyi
Normal file
39
jaxlib/xla/xla_extension/transfer_guard_lib.pyi
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2022 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
class TransferGuardLevel:
|
||||
ALLOW: Any
|
||||
LOG: Any
|
||||
DISALLOW: Any
|
||||
LOG_EXPLICIT: Any
|
||||
DISALLOW_EXPLICIT: Any
|
||||
|
||||
class TransferGuardState:
|
||||
host_to_device: Optional[TransferGuardLevel]
|
||||
device_to_device: Optional[TransferGuardLevel]
|
||||
device_to_host: Optional[TransferGuardLevel]
|
||||
|
||||
explicit_device_put: bool
|
||||
explicit_device_get: bool
|
||||
|
||||
def global_state() -> TransferGuardState: ...
|
||||
def thread_local_state() -> TransferGuardState: ...
|
||||
|
||||
class _TestingScopedLogSink:
|
||||
def __enter__(self) -> _TestingScopedLogSink: ...
|
||||
def __exit__(self, *args, **kwargs) -> None: ...
|
||||
def logs(self) -> List[str]: ...
|
17
jaxlib/xla_extension.py
Normal file
17
jaxlib/xla_extension.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from jaxlib.xla.xla_extension import * # noqa: F403
|
||||
from jaxlib.xla.xla_extension import sdy # noqa: F401
|
Loading…
x
Reference in New Issue
Block a user