mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Migrate xla_client and its Python tests out of XLA into JAX.
This change copies targets into jaxlib, and a subsequent change will delete them from XLA. We separate these into two phases because we cannot atomically change both JAX and XLA. Future changes will migrate more of the C++ pieces of XLA:Python. PiperOrigin-RevId: 739158120
This commit is contained in:
parent
f6cbab70fd
commit
0d5458316b
@ -44,6 +44,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mosaic/python:tpu_dialect",
|
||||
"//jaxlib:cpu_feature_guard",
|
||||
"//jaxlib:utils",
|
||||
"//jaxlib/xla:xla_client",
|
||||
"//jaxlib/triton",
|
||||
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
|
||||
"//jaxlib/mlir:arithmetic_dialect",
|
||||
@ -60,6 +61,6 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:sparse_tensor_dialect",
|
||||
"//jaxlib/mlir:stablehlo_dialect",
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
# xla_client
|
||||
# xla_extension
|
||||
]),
|
||||
)
|
||||
|
@ -40,7 +40,7 @@ except Exception as err:
|
||||
raise ImportError(msg) from err
|
||||
|
||||
|
||||
# Checks the jaxlib version before importing anything else from jaxlib.
|
||||
# Checks the jaxlib version before importing anything else.
|
||||
# Returns the jaxlib version string.
|
||||
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
|
||||
minimum_jaxlib_version: str) -> tuple[int, ...]:
|
||||
@ -77,20 +77,23 @@ version = check_jaxlib_version(
|
||||
jaxlib_version=jaxlib.version.__version__,
|
||||
minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
|
||||
|
||||
# Before importing any C compiled modules from jaxlib, first import the CPU
|
||||
# Before importing any C compiled modules, first import the CPU
|
||||
# feature guard module to verify that jaxlib was compiled in a way that only
|
||||
# uses instructions that are present on this machine.
|
||||
import jaxlib.cpu_feature_guard as cpu_feature_guard
|
||||
cpu_feature_guard.check_cpu_features()
|
||||
|
||||
import jaxlib.utils as utils # noqa: F401
|
||||
import jaxlib.xla_client as xla_client
|
||||
import jaxlib.lapack as lapack # noqa: F401
|
||||
import jaxlib.utils as utils # noqa: F401
|
||||
import jaxlib.xla_extension as xla_extension # noqa: F401
|
||||
from jaxlib.xla_extension import guard_lib as guard_lib # noqa: F401
|
||||
from jaxlib.xla_extension import jax_jit as jax_jit # noqa: F401
|
||||
from jaxlib.xla_extension import pmap_lib as pmap_lib # noqa: F401
|
||||
from jaxlib.xla_extension import pytree as pytree # noqa: F401
|
||||
import jaxlib.xla_client as xla_client # noqa: F401
|
||||
|
||||
from jaxlib.xla_extension import Device as Device # noqa: F401
|
||||
|
||||
xla_extension = xla_client._xla
|
||||
pytree = xla_client._xla.pytree
|
||||
jax_jit = xla_client._xla.jax_jit
|
||||
pmap_lib = xla_client._xla.pmap_lib
|
||||
|
||||
# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882
|
||||
def _xla_gc_callback(*args):
|
||||
@ -167,6 +170,3 @@ def _cuda_path() -> str | None:
|
||||
return None
|
||||
|
||||
cuda_path = _cuda_path()
|
||||
|
||||
guard_lib = xla_client._xla.guard_lib
|
||||
Device = xla_client._xla.Device
|
||||
|
@ -81,7 +81,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
"//jaxlib/mosaic",
|
||||
"//jaxlib/triton",
|
||||
"@xla//xla/python:xla_extension",
|
||||
"//jaxlib/xla:xla_client",
|
||||
],
|
||||
)
|
||||
|
||||
@ -94,7 +94,7 @@ symlink_files(
|
||||
|
||||
symlink_files(
|
||||
name = "xla_client",
|
||||
srcs = ["@xla//xla/python:xla_client"],
|
||||
srcs = ["//jaxlib/xla:xla_client"],
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
@ -132,6 +132,9 @@ def pytype_strict_library(name, pytype_srcs = [], **kwargs):
|
||||
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
|
||||
native.py_library(name = name, data = data, **new_kwargs)
|
||||
|
||||
py_strict_library = native.py_library
|
||||
py_strict_test = native.py_test
|
||||
|
||||
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
|
||||
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
|
||||
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
|
||||
|
162
jaxlib/xla/BUILD
Normal file
162
jaxlib/xla/BUILD
Normal file
@ -0,0 +1,162 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"nanobind_extension",
|
||||
"py_deps",
|
||||
"py_strict_library",
|
||||
"py_strict_test",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "xla_python",
|
||||
includes = [
|
||||
"//jax:internal",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "xla_client",
|
||||
srcs = ["xla_client.py"],
|
||||
pytype_srcs = ["xla_client.pyi"],
|
||||
visibility = [":xla_python"],
|
||||
deps = py_deps([
|
||||
"numpy",
|
||||
"ml_dtypes",
|
||||
]) + ["@xla//xla/python:xla_extension"],
|
||||
)
|
||||
|
||||
py_strict_test(
|
||||
name = "xla_client_backend_independent_test",
|
||||
srcs = ["xla_client_backend_independent_test.py"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
] + py_deps([
|
||||
"absl/testing",
|
||||
"numpy",
|
||||
"portpicker",
|
||||
]),
|
||||
)
|
||||
|
||||
py_strict_library(
|
||||
name = "xla_client_test",
|
||||
testonly = 1,
|
||||
srcs = ["xla_client_test.py"],
|
||||
visibility = [":xla_python"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
"//jaxlib",
|
||||
] + py_deps([
|
||||
"absl/flags",
|
||||
"absl/logging",
|
||||
"absl/testing",
|
||||
"ml_dtypes",
|
||||
"numpy",
|
||||
]),
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "custom_calls_testlib",
|
||||
testonly = 1,
|
||||
srcs = ["custom_calls_testlib.cc"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/status",
|
||||
"@nanobind",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
],
|
||||
)
|
||||
|
||||
py_strict_test(
|
||||
name = "xla_client_test_cpu",
|
||||
srcs = ["xla_client_test.py"],
|
||||
args = ["--backend=cpu"],
|
||||
env = {
|
||||
"XLA_FLAGS": "--xla_force_host_platform_device_count=4",
|
||||
},
|
||||
main = "xla_client_test.py",
|
||||
deps = [
|
||||
":custom_calls_testlib",
|
||||
":xla_client",
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
"//jaxlib",
|
||||
] + py_deps([
|
||||
"absl/flags",
|
||||
"absl/logging",
|
||||
"absl/testing",
|
||||
"ml_dtypes",
|
||||
"numpy",
|
||||
]),
|
||||
)
|
||||
|
||||
py_strict_test(
|
||||
name = "weakref_lru_cache_test",
|
||||
srcs = ["weakref_lru_cache_test.py"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
] + py_deps([
|
||||
"absl/flags",
|
||||
"absl/logging",
|
||||
"absl/testing",
|
||||
]),
|
||||
)
|
||||
|
||||
py_strict_test(
|
||||
name = "pytree_test",
|
||||
srcs = ["pytree_test.py"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
] + py_deps([
|
||||
"absl/flags",
|
||||
"absl/logging",
|
||||
"absl/testing",
|
||||
]),
|
||||
)
|
||||
|
||||
py_strict_test(
|
||||
name = "config_test",
|
||||
srcs = ["config_test.py"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
] + py_deps([
|
||||
"absl/flags",
|
||||
"absl/logging",
|
||||
"absl/testing",
|
||||
]),
|
||||
)
|
||||
|
||||
py_strict_test(
|
||||
name = "jax_jit_test",
|
||||
srcs = ["jax_jit_test.py"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
] + py_deps([
|
||||
"absl/flags",
|
||||
"absl/logging",
|
||||
"absl/testing",
|
||||
"numpy",
|
||||
]),
|
||||
)
|
71
jaxlib/xla/config_test.py
Normal file
71
jaxlib/xla/config_test.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2024 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import threading
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax.jaxlib.xla import xla_client
|
||||
|
||||
config = xla_client._xla.config
|
||||
|
||||
|
||||
class ConfigTest(absltest.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
c = config.Config(1)
|
||||
self.assertEqual(c.value, 1)
|
||||
self.assertEqual(c.get_global(), 1)
|
||||
self.assertEqual(c.get_local(), config.unset)
|
||||
|
||||
c.set_global(2)
|
||||
self.assertEqual(c.value, 2)
|
||||
self.assertEqual(c.get_global(), 2)
|
||||
self.assertEqual(c.get_local(), config.unset)
|
||||
|
||||
c.set_local(3)
|
||||
self.assertEqual(c.value, 3)
|
||||
self.assertEqual(c.get_global(), 2)
|
||||
self.assertEqual(c.get_local(), 3)
|
||||
|
||||
c.set_global(4)
|
||||
self.assertEqual(c.value, 3)
|
||||
self.assertEqual(c.get_global(), 4)
|
||||
self.assertEqual(c.get_local(), 3)
|
||||
|
||||
c.set_local(config.unset)
|
||||
self.assertEqual(c.value, 4)
|
||||
self.assertEqual(c.get_global(), 4)
|
||||
self.assertEqual(c.get_local(), config.unset)
|
||||
|
||||
def testThreading(self):
|
||||
c = config.Config(1)
|
||||
|
||||
def Body():
|
||||
for i in range(100):
|
||||
c.set_local(i)
|
||||
self.assertEqual(c.get_local(), i)
|
||||
self.assertEqual(c.get_global(), 1)
|
||||
self.assertEqual(c.value, i)
|
||||
|
||||
threads = [threading.Thread(target=Body) for _ in range(4)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
128
jaxlib/xla/custom_calls_testlib.cc
Normal file
128
jaxlib/xla/custom_calls_testlib.cc
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2024 The JAX Authors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "third_party/nanobind/include/nanobind/nanobind.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
namespace nb = ::nanobind;
|
||||
|
||||
// Implement custom calls as static functions with XLA FFI types in the function
|
||||
// signature that gives access to the arguments and results buffers together
|
||||
// with their types and dimensions. See `ffi/api/ffi_test.cc` for more XLA FFI
|
||||
// examples and features (e.g. binding attributes, custom user-defined structs
|
||||
// and arbitrary execution context).
|
||||
|
||||
static Error AlwaysFail(Result<AnyBuffer>) {
|
||||
return Error(XLA_FFI_Error_Code_INTERNAL, "Failed intentionally");
|
||||
}
|
||||
|
||||
static Error AlwaysSucceed(Result<AnyBuffer>) { return Error::Success(); }
|
||||
|
||||
static Error Subtract(BufferR0<DataType::F32> a, BufferR0<DataType::F32> b,
|
||||
Result<BufferR0<DataType::F32>> out) {
|
||||
*out->typed_data() = *a.typed_data() - *b.typed_data();
|
||||
return Error::Success();
|
||||
}
|
||||
|
||||
static Error SubtractCst(BufferR0<DataType::F32> a,
|
||||
Result<BufferR0<DataType::F32>> out, float cst) {
|
||||
*out->typed_data() = *a.typed_data() - cst;
|
||||
return Error::Success();
|
||||
}
|
||||
|
||||
// Define XLA FFI handlers from the implementations defined above using explicit
|
||||
// XLA FFI binding API to describe type signatures of custom calls.
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, Ffi::Bind().Ret<AnyBuffer>());
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed,
|
||||
Ffi::Bind().Ret<AnyBuffer>());
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER(kSubtract, Subtract,
|
||||
Ffi::Bind()
|
||||
.Arg<BufferR0<DataType::F32>>()
|
||||
.Arg<BufferR0<DataType::F32>>()
|
||||
.Ret<BufferR0<DataType::F32>>());
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER(kSubtractCst, SubtractCst,
|
||||
Ffi::Bind()
|
||||
.Arg<BufferR0<DataType::F32>>()
|
||||
.Ret<BufferR0<DataType::F32>>()
|
||||
.Attr<float>("cst"));
|
||||
|
||||
// XLA FFI calls can also be stateful.
|
||||
struct TestFfiState {
|
||||
static TypeId id;
|
||||
explicit TestFfiState(int32_t value) : value(value) {}
|
||||
int32_t value;
|
||||
};
|
||||
TypeId TestFfiState::id = {};
|
||||
|
||||
static ErrorOr<std::unique_ptr<TestFfiState>> StateInstantiate() {
|
||||
return std::make_unique<TestFfiState>(42);
|
||||
}
|
||||
|
||||
static Error StateExecute(TestFfiState* state,
|
||||
Result<BufferR0<DataType::S32>> out) {
|
||||
*out->typed_data() = state->value;
|
||||
return Error::Success();
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
|
||||
Ffi::BindInstantiate());
|
||||
XLA_FFI_DEFINE_HANDLER(
|
||||
kStateExecute, StateExecute,
|
||||
Ffi::Bind().Ctx<State<TestFfiState>>().Ret<BufferR0<DataType::S32>>());
|
||||
|
||||
template <typename T>
|
||||
static auto BindFunction(T* fn) {
|
||||
return nb::capsule(reinterpret_cast<void*>(fn));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static auto BindTypeId(T* typeId) {
|
||||
return nb::capsule(reinterpret_cast<void*>(typeId));
|
||||
}
|
||||
|
||||
// Custom calls registration library that exports function pointers to XLA FFI
|
||||
// handlers to the python users.
|
||||
NB_MODULE(custom_calls_testlib, m) {
|
||||
m.def("registrations", []() {
|
||||
nb::dict dict;
|
||||
dict["always_fail"] = BindFunction(kAlwaysFail);
|
||||
dict["always_succeed"] = BindFunction(kAlwaysSucceed);
|
||||
dict["subtract_f32"] = BindFunction(kSubtract);
|
||||
dict["subtract_f32_cst"] = BindFunction(kSubtractCst);
|
||||
|
||||
nb::dict bundle;
|
||||
bundle["instantiate"] = BindFunction(kStateInstantiate);
|
||||
bundle["execute"] = BindFunction(kStateExecute);
|
||||
dict["stateful"] = bundle;
|
||||
|
||||
return dict;
|
||||
});
|
||||
m.def("type_ids", []() {
|
||||
nb::dict type_ids;
|
||||
type_ids["test_ffi_state"] = BindTypeId(&TestFfiState::id);
|
||||
return type_ids;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace xla::ffi
|
47
jaxlib/xla/jax_jit_test.py
Normal file
47
jaxlib/xla/jax_jit_test.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2024 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for jax_jit helper functions."""
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax.jaxlib.xla import xla_client
|
||||
|
||||
jax_jit = xla_client._xla.jax_jit
|
||||
pytree = xla_client._xla.pytree
|
||||
|
||||
pytree_registry = pytree.default_registry()
|
||||
|
||||
|
||||
class JaxJitTest(absltest.TestCase):
|
||||
|
||||
def testParseArguments(self):
|
||||
sig, args = jax_jit.parse_arguments(
|
||||
positional_args=[1, 2, 3],
|
||||
keyword_args=[4, 5],
|
||||
kwnames=("a", "b"),
|
||||
static_argnums=[0, 2],
|
||||
static_argnames=["a"],
|
||||
pytree_registry=pytree_registry,
|
||||
)
|
||||
self.assertEqual(args, [2, 5])
|
||||
self.assertEqual(sig.static_args, [1, 3, 4])
|
||||
self.assertEqual(sig.static_arg_names, ["a"])
|
||||
_, leaf = pytree_registry.flatten(0)
|
||||
self.assertEqual(sig.dynamic_arg_names, ["b"])
|
||||
self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
144
jaxlib/xla/pytree_test.py
Normal file
144
jaxlib/xla/pytree_test.py
Normal file
@ -0,0 +1,144 @@
|
||||
# Copyright 2023 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import collections
|
||||
import dataclasses
|
||||
import gc
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax.jaxlib.xla import xla_client
|
||||
|
||||
pytree = xla_client._xla.pytree
|
||||
|
||||
|
||||
ExampleType = collections.namedtuple("ExampleType", "field0 field1")
|
||||
|
||||
registry = pytree.PyTreeRegistry()
|
||||
|
||||
|
||||
class ExampleType2:
|
||||
|
||||
def __init__(self, field0, field1):
|
||||
self.field0 = field0
|
||||
self.field1 = field1
|
||||
|
||||
def to_iterable(self):
|
||||
return [self.field0, self.field1], (None,)
|
||||
|
||||
|
||||
def from_iterable(state, values):
|
||||
del state
|
||||
return ExampleType2(field0=values[0], field1=values[1])
|
||||
|
||||
|
||||
registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Custom:
|
||||
a: int
|
||||
b: str
|
||||
|
||||
|
||||
registry.register_dataclass_node(Custom, ["a"], ["b"])
|
||||
|
||||
|
||||
class PyTreeTest(absltest.TestCase):
|
||||
|
||||
def roundtrip(self, example):
|
||||
original = registry.flatten(example)[1]
|
||||
self.assertEqual(
|
||||
pytree.PyTreeDef.deserialize_using_proto(
|
||||
registry, original.serialize_using_proto()
|
||||
),
|
||||
original,
|
||||
)
|
||||
|
||||
def testSerializeDeserializeNoPickle(self):
|
||||
o = object()
|
||||
self.roundtrip(({"a": o, "b": o}, [o, (o, o), None]))
|
||||
|
||||
def testSerializeWithFallback(self):
|
||||
o = object()
|
||||
with self.assertRaises(ValueError):
|
||||
self.roundtrip({"a": ExampleType(field0=o, field1=o)})
|
||||
|
||||
def testRegisteredType(self):
|
||||
o = object()
|
||||
with self.assertRaises(ValueError):
|
||||
self.roundtrip({"a": ExampleType2(field0=o, field1=o)})
|
||||
|
||||
def roundtrip_node_data(self, example):
|
||||
original = registry.flatten(example)[1]
|
||||
restored = pytree.PyTreeDef.make_from_node_data_and_children(
|
||||
registry, original.node_data(), original.children()
|
||||
)
|
||||
self.assertEqual(restored, original)
|
||||
|
||||
def testRoundtripNodeData(self):
|
||||
o = object()
|
||||
self.roundtrip_node_data([o, o, o])
|
||||
self.roundtrip_node_data((o, o, o))
|
||||
self.roundtrip_node_data({"a": o, "b": o})
|
||||
self.roundtrip_node_data({22: o, 88: o})
|
||||
self.roundtrip_node_data(None)
|
||||
self.roundtrip_node_data(o)
|
||||
self.roundtrip_node_data(ExampleType(field0=o, field1=o))
|
||||
self.roundtrip_node_data(ExampleType2(field0=o, field1=o))
|
||||
|
||||
def testCompose(self):
|
||||
x = registry.flatten(0)[1]
|
||||
y = registry.flatten((0, 0))[1]
|
||||
self.assertEqual((x.compose(y)).num_leaves, 2)
|
||||
|
||||
def testDataclassMakeFromNodeData(self):
|
||||
c = Custom(1, "a")
|
||||
c_leafs, c_tree = registry.flatten(c)
|
||||
c_tree2 = c_tree.make_from_node_data_and_children(
|
||||
registry, c_tree.node_data(), c_tree.children()
|
||||
)
|
||||
self.assertEqual(c_tree2.unflatten(c_leafs), c)
|
||||
self.assertEqual(str(c_tree2), str(c_tree))
|
||||
|
||||
def testTpTraverse(self):
|
||||
self.assertContainsSubset(
|
||||
[
|
||||
pytree.PyTreeRegistry,
|
||||
ExampleType2,
|
||||
ExampleType2.to_iterable,
|
||||
from_iterable,
|
||||
],
|
||||
gc.get_referents(registry),
|
||||
)
|
||||
k1 = "k1"
|
||||
k2 = "k2"
|
||||
|
||||
t = ExampleType("a", "b")
|
||||
_, treedef = registry.flatten([1, {k1: 2, k2: t}, 5, t])
|
||||
|
||||
self.assertContainsSubset(
|
||||
[
|
||||
pytree.PyTreeDef,
|
||||
registry,
|
||||
k1,
|
||||
k2,
|
||||
ExampleType,
|
||||
],
|
||||
gc.get_referents(treedef),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
257
jaxlib/xla/weakref_lru_cache_test.py
Normal file
257
jaxlib/xla/weakref_lru_cache_test.py
Normal file
@ -0,0 +1,257 @@
|
||||
# Copyright 2023 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import gc
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax.jaxlib.xla import xla_client
|
||||
|
||||
|
||||
class WeakrefLRUCacheTest(absltest.TestCase):
|
||||
|
||||
def testMultiThreaded(self):
|
||||
insert_evs = [threading.Event() for _ in range(2)]
|
||||
insert_evs_i = 0
|
||||
|
||||
class WRKey:
|
||||
pass
|
||||
|
||||
class ClashingKey:
|
||||
|
||||
def __eq__(self, other):
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return 333 # induce maximal caching problems.
|
||||
|
||||
class GilReleasingCacheKey:
|
||||
|
||||
def __eq__(self, other):
|
||||
nonlocal insert_evs_i
|
||||
if isinstance(other, GilReleasingCacheKey) and insert_evs_i < len(
|
||||
insert_evs
|
||||
):
|
||||
insert_evs[insert_evs_i].set()
|
||||
insert_evs_i += 1
|
||||
time.sleep(0.01)
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return 333 # induce maximal caching problems.
|
||||
|
||||
def CacheFn(obj, gil_releasing_cache_key):
|
||||
del obj
|
||||
del gil_releasing_cache_key
|
||||
return None
|
||||
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 2048)
|
||||
|
||||
wrkey = WRKey()
|
||||
|
||||
def Body():
|
||||
for insert_ev in insert_evs:
|
||||
insert_ev.wait()
|
||||
for _ in range(20):
|
||||
cache(wrkey, ClashingKey())
|
||||
|
||||
t = threading.Thread(target=Body)
|
||||
t.start()
|
||||
for _ in range(3):
|
||||
cache(wrkey, GilReleasingCacheKey())
|
||||
t.join()
|
||||
|
||||
def testAnotherMultiThreaded(self):
|
||||
num_workers = 5
|
||||
barrier = threading.Barrier(num_workers)
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
|
||||
|
||||
class WRKey:
|
||||
pass
|
||||
|
||||
def WorkerAddToCache():
|
||||
barrier.wait()
|
||||
wrkey = WRKey()
|
||||
for i in range(10):
|
||||
cache(wrkey, i)
|
||||
|
||||
def WorkerCleanCache():
|
||||
barrier.wait()
|
||||
for _ in range(10):
|
||||
cache.cache_clear()
|
||||
|
||||
workers = [
|
||||
threading.Thread(target=WorkerAddToCache)
|
||||
for _ in range(num_workers - 1)
|
||||
] + [threading.Thread(target=WorkerCleanCache)]
|
||||
|
||||
for t in workers:
|
||||
t.start()
|
||||
|
||||
for t in workers:
|
||||
t.join()
|
||||
|
||||
def testKwargsDictOrder(self):
|
||||
miss_id = 0
|
||||
|
||||
class WRKey:
|
||||
pass
|
||||
|
||||
def CacheFn(obj, kwkey1, kwkey2):
|
||||
del obj, kwkey1, kwkey2
|
||||
nonlocal miss_id
|
||||
miss_id += 1
|
||||
return miss_id
|
||||
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4)
|
||||
|
||||
wrkey = WRKey()
|
||||
|
||||
self.assertEqual(cache(wrkey, kwkey1="a", kwkey2="b"), 1)
|
||||
self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2)
|
||||
self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1)
|
||||
|
||||
def testGetKeys(self):
|
||||
def CacheFn(obj, arg):
|
||||
del obj
|
||||
return arg + "extra"
|
||||
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4)
|
||||
|
||||
class WRKey:
|
||||
pass
|
||||
|
||||
wrkey = WRKey()
|
||||
|
||||
self.assertEmpty(cache.cache_keys())
|
||||
cache(wrkey, "arg1")
|
||||
cache(wrkey, "arg2")
|
||||
self.assertLen(cache.cache_keys(), 2)
|
||||
|
||||
def testNonWeakreferenceableKey(self):
|
||||
class NonWRKey:
|
||||
__slots__ = ()
|
||||
|
||||
non_wr_key = NonWRKey()
|
||||
with self.assertRaises(TypeError):
|
||||
weakref.ref(non_wr_key)
|
||||
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, lambda x: 2048)
|
||||
for _ in range(100):
|
||||
with self.assertRaises(TypeError):
|
||||
cache(non_wr_key)
|
||||
|
||||
def testCrashingKey(self):
|
||||
class WRKey:
|
||||
pass
|
||||
|
||||
class CrashingKey:
|
||||
# A key that raises exceptions if eq or hash is called.
|
||||
|
||||
def __eq__(self, other):
|
||||
raise ValueError("eq")
|
||||
|
||||
def __hash__(self):
|
||||
raise ValueError("hash")
|
||||
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
|
||||
wrkey = WRKey()
|
||||
with self.assertRaises(ValueError):
|
||||
for _ in range(100):
|
||||
cache(wrkey, CrashingKey())
|
||||
|
||||
def testPrintingStats(self):
|
||||
class WRKey:
|
||||
pass
|
||||
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
|
||||
wrkey = WRKey()
|
||||
for i in range(10):
|
||||
cache(wrkey, i)
|
||||
for i in range(5):
|
||||
cache(wrkey, i)
|
||||
|
||||
self.assertEqual(
|
||||
repr(cache.cache_info()),
|
||||
"WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)",
|
||||
)
|
||||
|
||||
def testGCKeys(self):
|
||||
class WRKey:
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.x == other.x
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.x)
|
||||
|
||||
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
|
||||
keys = [WRKey(i) for i in range(10)]
|
||||
for i in range(10):
|
||||
cache(keys[i], i)
|
||||
|
||||
# Delete some keys, to exercise the weakref callback behavior.
|
||||
del keys[::2]
|
||||
|
||||
for key in keys:
|
||||
cache(key, 7)
|
||||
|
||||
def testTpTraverse(self):
|
||||
class WRKey:
|
||||
pass
|
||||
|
||||
def CacheContextFn():
|
||||
return None
|
||||
|
||||
def CallFn(x, y, *args, **kwargs):
|
||||
del x, args, kwargs
|
||||
return y
|
||||
|
||||
cache = xla_client.weakref_lru_cache(CacheContextFn, CallFn, 2048)
|
||||
|
||||
keys = [WRKey() for _ in range(10)]
|
||||
values = [str(i) for i in range(10)]
|
||||
args = [str(i) for i in range(10)]
|
||||
kwargs = {"a": "b"}
|
||||
|
||||
for key, value in zip(keys, values):
|
||||
cache(key, value, *args, **kwargs)
|
||||
|
||||
expected_refs = (
|
||||
[
|
||||
CacheContextFn,
|
||||
CallFn,
|
||||
xla_client._xla.WeakrefLRUCache,
|
||||
kwargs,
|
||||
]
|
||||
+ [weakref.getweakrefs(key)[0] for key in keys]
|
||||
+ values
|
||||
+ args
|
||||
)
|
||||
|
||||
# Can't use assertContainsSubset because it doesn't support kwargs since
|
||||
# dicts aren't hashable.
|
||||
for ref in expected_refs:
|
||||
self.assertIn(ref, gc.get_referents(cache))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
1044
jaxlib/xla/xla_client.py
Normal file
1044
jaxlib/xla/xla_client.py
Normal file
File diff suppressed because it is too large
Load Diff
322
jaxlib/xla/xla_client.pyi
Normal file
322
jaxlib/xla/xla_client.pyi
Normal file
@ -0,0 +1,322 @@
|
||||
# Copyright 2021 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
import enum
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy
|
||||
|
||||
from jaxlib import xla_extension as _xla
|
||||
from jaxlib.xla_extension import ArrayImpl as ArrayImpl
|
||||
from jaxlib.xla_extension import AutotuneCacheMode as AutotuneCacheMode
|
||||
from jaxlib.xla_extension import Client as Client
|
||||
from jaxlib.xla_extension import CompileOptions as CompileOptions
|
||||
from jaxlib.xla_extension import Device as Device
|
||||
from jaxlib.xla_extension import DeviceAssignment as DeviceAssignment
|
||||
from jaxlib.xla_extension import DeviceList as DeviceList
|
||||
from jaxlib.xla_extension import DeviceTopology as DeviceTopology
|
||||
from jaxlib.xla_extension import DistributedRuntimeClient as DistributedRuntimeClient
|
||||
from jaxlib.xla_extension import FftType as FftType
|
||||
from jaxlib.xla_extension import Frame as Frame
|
||||
from jaxlib.xla_extension import GSPMDSharding as GSPMDSharding
|
||||
from jaxlib.xla_extension import HloSharding as HloSharding
|
||||
from jaxlib.xla_extension import HostBufferSemantics as HostBufferSemantics
|
||||
from jaxlib.xla_extension import ifrt_programs as ifrt_programs
|
||||
from jaxlib.xla_extension import Layout as Layout
|
||||
from jaxlib.xla_extension import LoadedExecutable as LoadedExecutable
|
||||
from jaxlib.xla_extension import Memory as Memory
|
||||
from jaxlib.xla_extension import NamedSharding as NamedSharding
|
||||
from jaxlib.xla_extension import ops as ops
|
||||
from jaxlib.xla_extension import OpSharding as OpSharding
|
||||
from jaxlib.xla_extension import PjRtLayout as PjRtLayout
|
||||
from jaxlib.xla_extension import PmapSharding as PmapSharding
|
||||
from jaxlib.xla_extension import PrimitiveType as PrimitiveType
|
||||
from jaxlib.xla_extension import ArrayCopySemantics as ArrayCopySemantics
|
||||
from jaxlib.xla_extension import profiler as profiler
|
||||
from jaxlib.xla_extension import Shape as Shape
|
||||
from jaxlib.xla_extension import Sharding as Sharding
|
||||
from jaxlib.xla_extension import SingleDeviceSharding as SingleDeviceSharding
|
||||
from jaxlib.xla_extension import Traceback as Traceback
|
||||
from jaxlib.xla_extension import XlaBuilder as XlaBuilder
|
||||
from jaxlib.xla_extension import XlaComputation as XlaComputation
|
||||
from jaxlib.xla_extension import XlaOp as XlaOp
|
||||
|
||||
_version: int
|
||||
|
||||
mlir_api_version: int
|
||||
|
||||
bfloat16: type[numpy.generic]
|
||||
# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
|
||||
# float4_e2m1fn: type[numpy.generic]
|
||||
# float8_e3m4: type[numpy.generic]
|
||||
# float8_e4m3: type[numpy.generic]
|
||||
# float8_e8m0fnu: type[numpy.generic]
|
||||
float8_e4m3fn: type[numpy.generic]
|
||||
float8_e4m3b11fnuz: type[numpy.generic]
|
||||
float8_e4m3fnuz: type[numpy.generic]
|
||||
float8_e5m2: type[numpy.generic]
|
||||
float8_e5m2fnuz: type[numpy.generic]
|
||||
XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype]
|
||||
|
||||
_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]
|
||||
|
||||
def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType:
|
||||
...
|
||||
|
||||
def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ...
|
||||
|
||||
def heap_profile(client: Client) -> bytes:
|
||||
...
|
||||
|
||||
XlaRuntimeError = _xla.XlaRuntimeError
|
||||
|
||||
def make_cpu_client(
|
||||
asynchronous: bool = ...,
|
||||
distributed_client: DistributedRuntimeClient | None = ...,
|
||||
node_id: int = ...,
|
||||
num_nodes: int = ...,
|
||||
collectives: _xla.CpuCollectives | None = ...,
|
||||
num_devices: int | None = ...,
|
||||
) -> Client:
|
||||
...
|
||||
|
||||
def make_gpu_client(
|
||||
distributed_client: DistributedRuntimeClient | None = ...,
|
||||
node_id: int = ...,
|
||||
num_nodes: int = ...,
|
||||
platform_name: str | None = ...,
|
||||
allowed_devices: set[int] | None = ...,
|
||||
mock: bool | None = ...,
|
||||
mock_gpu_topology: str | None = ...,
|
||||
) -> Client:
|
||||
...
|
||||
|
||||
def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None) -> Client:
|
||||
...
|
||||
|
||||
def make_tfrt_tpu_c_api_device_topology(
|
||||
topology_name: str | None = None, **kwargs
|
||||
) -> DeviceTopology:
|
||||
...
|
||||
|
||||
def make_c_api_device_topology(c_api: Any, topology_name: str = '', **kwargs) -> DeviceTopology:
|
||||
...
|
||||
|
||||
def get_topology_for_devices(devices: list[Device]) -> DeviceTopology:
|
||||
...
|
||||
|
||||
def make_tpu_client(
|
||||
library_path: str | None, options: _NameValueMapping | None = None
|
||||
) -> Client:
|
||||
...
|
||||
|
||||
def make_c_api_client(
|
||||
plugin_name: str,
|
||||
options: _NameValueMapping | None = None,
|
||||
distributed_client: DistributedRuntimeClient | None = None,
|
||||
) -> Client:
|
||||
...
|
||||
|
||||
def pjrt_plugin_loaded(plugin_name: str) -> bool:
|
||||
...
|
||||
|
||||
def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
|
||||
...
|
||||
|
||||
def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
|
||||
...
|
||||
|
||||
def pjrt_plugin_initialized(plugin_name: str) -> bool:
|
||||
...
|
||||
|
||||
def initialize_pjrt_plugin(plugin_name: str) -> None:
|
||||
...
|
||||
|
||||
def generate_pjrt_gpu_plugin_options() -> _NameValueMapping:
|
||||
...
|
||||
|
||||
class OpMetadata:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
op_type: str | None = ...,
|
||||
op_name: str | None = ...,
|
||||
source_file: str | None = ...,
|
||||
source_line: int | None = ...,
|
||||
):
|
||||
...
|
||||
op_type: str | None
|
||||
op_name: str | None
|
||||
source_file: str | None
|
||||
source_line: int | None
|
||||
|
||||
class PaddingConfigDimension:
|
||||
edge_padding_low: int
|
||||
edge_padding_high: int
|
||||
interior_padding: int
|
||||
|
||||
class PaddingConfig:
|
||||
dimensions: list[PaddingConfigDimension]
|
||||
|
||||
def make_padding_config(
|
||||
padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]],
|
||||
) -> PaddingConfig:
|
||||
...
|
||||
|
||||
class PaddingType(enum.Enum):
|
||||
VALID = 1
|
||||
SAME = 2
|
||||
|
||||
class DotDimensionNumbers:
|
||||
lhs_contracting_dimensions: list[int]
|
||||
rhs_contracting_dimensions: list[int]
|
||||
lhs_batch_dimensions: list[int]
|
||||
rhs_batch_dimensions: list[int]
|
||||
|
||||
def make_dot_dimension_numbers(
|
||||
dimension_numbers: Union[
|
||||
DotDimensionNumbers,
|
||||
tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]],
|
||||
],
|
||||
) -> DotDimensionNumbers:
|
||||
...
|
||||
|
||||
class ConvolutionDimensionNumbers:
|
||||
input_batch_dimension: int
|
||||
input_feature_dimension: int
|
||||
input_spatial_dimensions: list[int]
|
||||
kernel_input_feature_dimension: int
|
||||
kernel_output_feature_dimension: int
|
||||
kernel_spatial_dimensions: list[int]
|
||||
output_batch_dimension: int
|
||||
output_feature_dimension: int
|
||||
output_spatial_dimensions: list[int]
|
||||
|
||||
def make_convolution_dimension_numbers(
|
||||
dimension_numbers: Union[
|
||||
None, ConvolutionDimensionNumbers, tuple[str, str, str]
|
||||
],
|
||||
num_spatial_dimensions: int,
|
||||
) -> ConvolutionDimensionNumbers:
|
||||
...
|
||||
|
||||
class PrecisionConfig:
|
||||
Precision = _xla.PrecisionConfig_Precision
|
||||
operand_precision: list[_xla.PrecisionConfig_Precision]
|
||||
|
||||
class ResultAccuracy:
|
||||
mode: _xla.ResultAccuracy_Mode
|
||||
atol: float
|
||||
rtol: float
|
||||
ulps: int
|
||||
|
||||
class GatherDimensionNumbers:
|
||||
offset_dims: list[int]
|
||||
collapsed_slice_dims: list[int]
|
||||
start_index_map: list[int]
|
||||
index_vector_dim: int
|
||||
operand_batching_dims: list[int]
|
||||
start_indices_batching_dims: list[int]
|
||||
|
||||
class ScatterDimensionNumbers:
|
||||
update_window_dims: list[int]
|
||||
inserted_window_dims: list[int]
|
||||
scatter_dims_to_operand_dims: list[int]
|
||||
index_vector_dim: int
|
||||
input_batching_dims: list[int]
|
||||
scatter_indices_batching_dims: list[int]
|
||||
|
||||
class ReplicaGroup:
|
||||
replica_ids: list[int]
|
||||
|
||||
def make_replica_groups(
|
||||
replica_groups: Sequence[Sequence[int]] | None,
|
||||
) -> list[ReplicaGroup]:
|
||||
...
|
||||
|
||||
def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...) -> _xla.WeakrefLRUCache:
|
||||
...
|
||||
|
||||
def batched_copy_array_to_devices_with_sharding(
|
||||
arrays: Sequence[ArrayImpl],
|
||||
devices: Sequence[list[Device]],
|
||||
sharding: Sequence[Any],
|
||||
array_copy_semantics: Sequence[ArrayCopySemantics],
|
||||
) -> Sequence[ArrayImpl]: ...
|
||||
|
||||
def batched_device_put(
|
||||
aval: Any,
|
||||
sharding: Any,
|
||||
shards: Sequence[Any],
|
||||
devices: list[Device],
|
||||
committed: bool = ...,
|
||||
force_copy: bool = ...,
|
||||
host_buffer_semantics: Any = ...,
|
||||
) -> ArrayImpl: ...
|
||||
|
||||
def reorder_shards(
|
||||
x: ArrayImpl,
|
||||
dst_sharding: Any,
|
||||
array_copy_semantics: ArrayCopySemantics,
|
||||
) -> ArrayImpl: ...
|
||||
|
||||
def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ...
|
||||
|
||||
def check_and_canonicalize_memory_kind(
|
||||
memory_kind: str | None, device_list: DeviceList
|
||||
) -> str | None: ...
|
||||
|
||||
def array_result_handler(
|
||||
aval: Any,
|
||||
sharding: Any,
|
||||
committed: bool,
|
||||
_skip_checks: bool = ...) -> Callable:
|
||||
...
|
||||
|
||||
class CustomCallTargetTraits(enum.IntFlag):
|
||||
DEFAULT = 0
|
||||
COMMAND_BUFFER_COMPATIBLE = 1
|
||||
|
||||
def register_custom_call_target(
|
||||
name: str,
|
||||
fn: Any,
|
||||
platform: str = ...,
|
||||
api_version: int = ...,
|
||||
traits: CustomCallTargetTraits = ...,
|
||||
) -> None: ...
|
||||
|
||||
def register_custom_call_handler(
|
||||
xla_platform_name: str, handler: Any
|
||||
) -> None: ...
|
||||
|
||||
def custom_call_targets(platform: str) -> dict[str, Any]: ...
|
||||
|
||||
def register_custom_type_id(
|
||||
type_name: str,
|
||||
type_id: Any,
|
||||
platform: str = ...,
|
||||
) -> None: ...
|
||||
|
||||
def register_custom_type_id_handler(platform: str, handler: Any) -> None: ...
|
||||
|
||||
def encode_inspect_sharding_callback(handler: Any) -> bytes: ...
|
||||
|
||||
register_custom_call_partitioner = _xla.register_custom_call_partitioner
|
||||
register_custom_call_as_batch_partitionable = (
|
||||
_xla.register_custom_call_as_batch_partitionable
|
||||
)
|
195
jaxlib/xla/xla_client_backend_independent_test.py
Normal file
195
jaxlib/xla/xla_client_backend_independent_test.py
Normal file
@ -0,0 +1,195 @@
|
||||
# Copyright 2017 The JAX Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Backend-independent tests for the Python XLA client."""
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
from jax.jaxlib.xla import xla_client
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
import portpicker
|
||||
except ImportError:
|
||||
portpicker = None
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
ops = xla_client.ops
|
||||
|
||||
|
||||
class ShapeTest(absltest.TestCase):
|
||||
|
||||
def testInvalidShapes(self):
|
||||
with self.assertRaisesRegex(xla_client.XlaRuntimeError, "invalid shape"):
|
||||
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "layout minor_to_major field contains 1 element.*"):
|
||||
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "layout minor_to_major field has out-of-bounds value.*"):
|
||||
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4],
|
||||
[1, -1])
|
||||
|
||||
|
||||
class ComputationPrinting(absltest.TestCase):
|
||||
|
||||
def ExampleComputation(self):
|
||||
builder = xla_client.XlaBuilder("acomputation")
|
||||
p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
|
||||
p1 = ops.Parameter(builder, 1,
|
||||
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
|
||||
x = ops.Mul(p0, p1)
|
||||
ops.Add(x, x)
|
||||
return builder.build()
|
||||
|
||||
def testComputationToHloText(self):
|
||||
computation = self.ExampleComputation()
|
||||
hlo_text = computation.as_hlo_text()
|
||||
self.assertTrue(hlo_text.startswith("HloModule acomputation"))
|
||||
|
||||
def testComputationToHloGraph(self):
|
||||
computation = self.ExampleComputation()
|
||||
hlo_dot_graph = computation.as_hlo_dot_graph()
|
||||
self.assertTrue(hlo_dot_graph.startswith("digraph "))
|
||||
|
||||
def testHloModuleToHloText(self):
|
||||
computation = self.ExampleComputation()
|
||||
hlo_text = computation.as_hlo_module().to_string()
|
||||
self.assertTrue(hlo_text.startswith("HloModule acomputation"))
|
||||
|
||||
def testHloModuleFromText(self):
|
||||
hlo_module_text = """HloModule test
|
||||
add {
|
||||
x = f32[] parameter(0)
|
||||
y = f32[] parameter(1)
|
||||
ROOT add = f32[] add(x, y)
|
||||
}
|
||||
ENTRY entry {
|
||||
p0 = f32[2,3] parameter(0)
|
||||
start = f32[2,3] all-reduce-start(p0), to_apply=add
|
||||
ROOT done = f32[2,3] all-reduce-done(start)
|
||||
}"""
|
||||
hlo_module = xla_client._xla.hlo_module_from_text(hlo_module_text)
|
||||
hlo_text = hlo_module.to_string()
|
||||
self.assertTrue(hlo_text.startswith("HloModule test"))
|
||||
|
||||
def testHloModuleToHloGraph(self):
|
||||
computation = self.ExampleComputation()
|
||||
hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph(
|
||||
computation.as_hlo_module())
|
||||
self.assertTrue(hlo_dot_graph.startswith("digraph "))
|
||||
|
||||
|
||||
class ComputationHashTest(absltest.TestCase):
|
||||
|
||||
def testHash(self):
|
||||
builder0 = xla_client.XlaBuilder("computation0")
|
||||
p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0)))
|
||||
p1 = ops.Parameter(builder0, 1,
|
||||
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
|
||||
ops.Mul(p0, p1)
|
||||
computation0 = builder0.build()
|
||||
|
||||
builder1 = xla_client.XlaBuilder("computation1")
|
||||
p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0)))
|
||||
p1 = ops.Parameter(builder1, 1,
|
||||
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
|
||||
ops.Mul(p0, p1)
|
||||
computation1 = builder1.build()
|
||||
|
||||
self.assertEqual(computation0.hash(), computation1.hash())
|
||||
|
||||
|
||||
class AliasTest(absltest.TestCase):
|
||||
|
||||
def testSetUpAlias(self):
|
||||
c = xla_client.XlaBuilder(self.id())
|
||||
p1 = ops.Parameter(
|
||||
c, 0,
|
||||
xla_client.shape_from_pyval(np.array(
|
||||
1.0, np.float32)).with_major_to_minor_layout_if_absent())
|
||||
p2 = ops.Parameter(
|
||||
c, 1,
|
||||
xla_client.shape_from_pyval(np.array(
|
||||
1.0, np.float32)).with_major_to_minor_layout_if_absent())
|
||||
out = ops.Add(p1, p2)
|
||||
c.setup_alias([], 0, [])
|
||||
c.build(out)
|
||||
|
||||
|
||||
class ProfilerTest(absltest.TestCase):
|
||||
|
||||
def testTraceMe(self):
|
||||
# TODO(phawkins): These tests just check that the TraceMe context manager
|
||||
# acts like a context manager and doesn't explode. Ideally we'd check that
|
||||
# the profiler saw the traceme too.
|
||||
with xla_client.profiler.TraceMe("test1"):
|
||||
pass
|
||||
with xla_client.profiler.TraceMe("test2", foo=123):
|
||||
pass
|
||||
with self.assertRaises(ValueError):
|
||||
with xla_client.profiler.TraceMe("test3"):
|
||||
raise ValueError("test")
|
||||
|
||||
@unittest.skipIf(portpicker is None, "Test requires portpicker")
|
||||
def testStartServer(self):
|
||||
port = portpicker.pick_unused_port()
|
||||
server = xla_client.profiler.start_server(port)
|
||||
del server
|
||||
|
||||
|
||||
class HloModuleGroupTest(absltest.TestCase):
|
||||
|
||||
def testHloModuleGroup(self):
|
||||
builder0 = xla_client.XlaBuilder("computation0")
|
||||
p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0)))
|
||||
p1 = ops.Parameter(builder0, 1,
|
||||
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
|
||||
root = ops.Mul(p0, p1)
|
||||
computation0 = builder0.build(root)
|
||||
|
||||
m = computation0.get_hlo_module()
|
||||
mg_name = "test_module_group"
|
||||
mg = xla_client._xla.HloModuleGroup(mg_name, [m])
|
||||
self.assertEqual(mg.name, mg_name)
|
||||
|
||||
modules = mg.to_modules()
|
||||
self.assertLen(modules, 1)
|
||||
self.assertEqual(m.to_string(), modules[0].to_string())
|
||||
|
||||
|
||||
class RunHloPassTest(absltest.TestCase):
|
||||
|
||||
def testHloDCE(self):
|
||||
b = xla_client.XlaBuilder("acomputation")
|
||||
p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0)))
|
||||
p1 = ops.Parameter(b, 1,
|
||||
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
|
||||
root = ops.Mul(p0, p1)
|
||||
|
||||
# Dead instructions
|
||||
p2 = ops.Parameter(b, 2, xla_client.shape_from_pyval(np.float32(0)))
|
||||
ops.Add(p2, p2)
|
||||
|
||||
hlo_module = b.build(root).get_hlo_module()
|
||||
self.assertTrue(xla_client._xla.HloDCE().run(hlo_module))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
3714
jaxlib/xla/xla_client_test.py
Normal file
3714
jaxlib/xla/xla_client_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -23,8 +23,11 @@ module = [
|
||||
"jax.experimental.jax2tf.tests.back_compat_testdata",
|
||||
"jax.experimental.jax2tf.tests.flax_models",
|
||||
"jax_cuda12_plugin.*",
|
||||
"jaxlib.*",
|
||||
"jaxlib.cpu_feature_guard",
|
||||
"jaxlib.cuda.*",
|
||||
"jaxlib.mlir.*",
|
||||
"jaxlib.utils",
|
||||
"jaxlib.xla_extension.utils",
|
||||
"jraph.*",
|
||||
"libtpu.*",
|
||||
"matplotlib.*",
|
||||
|
Loading…
x
Reference in New Issue
Block a user