1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Migrate xla_client and its Python tests out of XLA into JAX.

This change copies targets into jaxlib, and a subsequent change will delete them from XLA. We separate these into two phases because we cannot atomically change both JAX and XLA.

Future changes will migrate more of the C++ pieces of XLA:Python.

PiperOrigin-RevId: 739158120
This commit is contained in:
Peter Hawkins 2025-03-21 06:25:37 -07:00 committed by Charles Hofer
parent f6cbab70fd
commit 0d5458316b
15 changed files with 6106 additions and 15 deletions

@ -44,6 +44,7 @@ py_library_providing_imports_info(
"//jaxlib/mosaic/python:tpu_dialect",
"//jaxlib:cpu_feature_guard",
"//jaxlib:utils",
"//jaxlib/xla:xla_client",
"//jaxlib/triton",
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
"//jaxlib/mlir:arithmetic_dialect",
@ -60,6 +61,6 @@ py_library_providing_imports_info(
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",
# xla_client
# xla_extension
]),
)

@ -40,7 +40,7 @@ except Exception as err:
raise ImportError(msg) from err
# Checks the jaxlib version before importing anything else from jaxlib.
# Checks the jaxlib version before importing anything else.
# Returns the jaxlib version string.
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
minimum_jaxlib_version: str) -> tuple[int, ...]:
@ -77,20 +77,23 @@ version = check_jaxlib_version(
jaxlib_version=jaxlib.version.__version__,
minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
# Before importing any C compiled modules from jaxlib, first import the CPU
# Before importing any C compiled modules, first import the CPU
# feature guard module to verify that jaxlib was compiled in a way that only
# uses instructions that are present on this machine.
import jaxlib.cpu_feature_guard as cpu_feature_guard
cpu_feature_guard.check_cpu_features()
import jaxlib.utils as utils # noqa: F401
import jaxlib.xla_client as xla_client
import jaxlib.lapack as lapack # noqa: F401
import jaxlib.utils as utils # noqa: F401
import jaxlib.xla_extension as xla_extension # noqa: F401
from jaxlib.xla_extension import guard_lib as guard_lib # noqa: F401
from jaxlib.xla_extension import jax_jit as jax_jit # noqa: F401
from jaxlib.xla_extension import pmap_lib as pmap_lib # noqa: F401
from jaxlib.xla_extension import pytree as pytree # noqa: F401
import jaxlib.xla_client as xla_client # noqa: F401
from jaxlib.xla_extension import Device as Device # noqa: F401
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882
def _xla_gc_callback(*args):
@ -167,6 +170,3 @@ def _cuda_path() -> str | None:
return None
cuda_path = _cuda_path()
guard_lib = xla_client._xla.guard_lib
Device = xla_client._xla.Device

@ -81,7 +81,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic",
"//jaxlib/triton",
"@xla//xla/python:xla_extension",
"//jaxlib/xla:xla_client",
],
)
@ -94,7 +94,7 @@ symlink_files(
symlink_files(
name = "xla_client",
srcs = ["@xla//xla/python:xla_client"],
srcs = ["//jaxlib/xla:xla_client"],
dst = ".",
flatten = True,
)

@ -132,6 +132,9 @@ def pytype_strict_library(name, pytype_srcs = [], **kwargs):
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
native.py_library(name = name, data = data, **new_kwargs)
py_strict_library = native.py_library
py_strict_test = native.py_test
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}

162
jaxlib/xla/BUILD Normal file

@ -0,0 +1,162 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"nanobind_extension",
"py_deps",
"py_strict_library",
"py_strict_test",
"pytype_strict_library",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//jax:internal"],
)
package_group(
name = "xla_python",
includes = [
"//jax:internal",
],
)
pytype_strict_library(
name = "xla_client",
srcs = ["xla_client.py"],
pytype_srcs = ["xla_client.pyi"],
visibility = [":xla_python"],
deps = py_deps([
"numpy",
"ml_dtypes",
]) + ["@xla//xla/python:xla_extension"],
)
py_strict_test(
name = "xla_client_backend_independent_test",
srcs = ["xla_client_backend_independent_test.py"],
deps = [
":xla_client",
] + py_deps([
"absl/testing",
"numpy",
"portpicker",
]),
)
py_strict_library(
name = "xla_client_test",
testonly = 1,
srcs = ["xla_client_test.py"],
visibility = [":xla_python"],
deps = [
":xla_client",
"//jax",
"//jax:test_util",
"//jaxlib",
] + py_deps([
"absl/flags",
"absl/logging",
"absl/testing",
"ml_dtypes",
"numpy",
]),
)
nanobind_extension(
name = "custom_calls_testlib",
testonly = 1,
srcs = ["custom_calls_testlib.cc"],
deps = [
"@com_google_absl//absl/status",
"@nanobind",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
],
)
py_strict_test(
name = "xla_client_test_cpu",
srcs = ["xla_client_test.py"],
args = ["--backend=cpu"],
env = {
"XLA_FLAGS": "--xla_force_host_platform_device_count=4",
},
main = "xla_client_test.py",
deps = [
":custom_calls_testlib",
":xla_client",
"//jax",
"//jax:test_util",
"//jaxlib",
] + py_deps([
"absl/flags",
"absl/logging",
"absl/testing",
"ml_dtypes",
"numpy",
]),
)
py_strict_test(
name = "weakref_lru_cache_test",
srcs = ["weakref_lru_cache_test.py"],
deps = [
":xla_client",
] + py_deps([
"absl/flags",
"absl/logging",
"absl/testing",
]),
)
py_strict_test(
name = "pytree_test",
srcs = ["pytree_test.py"],
deps = [
":xla_client",
] + py_deps([
"absl/flags",
"absl/logging",
"absl/testing",
]),
)
py_strict_test(
name = "config_test",
srcs = ["config_test.py"],
deps = [
":xla_client",
] + py_deps([
"absl/flags",
"absl/logging",
"absl/testing",
]),
)
py_strict_test(
name = "jax_jit_test",
srcs = ["jax_jit_test.py"],
deps = [
":xla_client",
] + py_deps([
"absl/flags",
"absl/logging",
"absl/testing",
"numpy",
]),
)

71
jaxlib/xla/config_test.py Normal file

@ -0,0 +1,71 @@
# Copyright 2024 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import threading
from absl.testing import absltest
from jax.jaxlib.xla import xla_client
config = xla_client._xla.config
class ConfigTest(absltest.TestCase):
def testBasic(self):
c = config.Config(1)
self.assertEqual(c.value, 1)
self.assertEqual(c.get_global(), 1)
self.assertEqual(c.get_local(), config.unset)
c.set_global(2)
self.assertEqual(c.value, 2)
self.assertEqual(c.get_global(), 2)
self.assertEqual(c.get_local(), config.unset)
c.set_local(3)
self.assertEqual(c.value, 3)
self.assertEqual(c.get_global(), 2)
self.assertEqual(c.get_local(), 3)
c.set_global(4)
self.assertEqual(c.value, 3)
self.assertEqual(c.get_global(), 4)
self.assertEqual(c.get_local(), 3)
c.set_local(config.unset)
self.assertEqual(c.value, 4)
self.assertEqual(c.get_global(), 4)
self.assertEqual(c.get_local(), config.unset)
def testThreading(self):
c = config.Config(1)
def Body():
for i in range(100):
c.set_local(i)
self.assertEqual(c.get_local(), i)
self.assertEqual(c.get_global(), 1)
self.assertEqual(c.value, i)
threads = [threading.Thread(target=Body) for _ in range(4)]
for t in threads:
t.start()
for t in threads:
t.join()
if __name__ == "__main__":
absltest.main()

@ -0,0 +1,128 @@
/* Copyright 2024 The JAX Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include "third_party/nanobind/include/nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
namespace xla::ffi {
namespace nb = ::nanobind;
// Implement custom calls as static functions with XLA FFI types in the function
// signature that gives access to the arguments and results buffers together
// with their types and dimensions. See `ffi/api/ffi_test.cc` for more XLA FFI
// examples and features (e.g. binding attributes, custom user-defined structs
// and arbitrary execution context).
static Error AlwaysFail(Result<AnyBuffer>) {
return Error(XLA_FFI_Error_Code_INTERNAL, "Failed intentionally");
}
static Error AlwaysSucceed(Result<AnyBuffer>) { return Error::Success(); }
static Error Subtract(BufferR0<DataType::F32> a, BufferR0<DataType::F32> b,
Result<BufferR0<DataType::F32>> out) {
*out->typed_data() = *a.typed_data() - *b.typed_data();
return Error::Success();
}
static Error SubtractCst(BufferR0<DataType::F32> a,
Result<BufferR0<DataType::F32>> out, float cst) {
*out->typed_data() = *a.typed_data() - cst;
return Error::Success();
}
// Define XLA FFI handlers from the implementations defined above using explicit
// XLA FFI binding API to describe type signatures of custom calls.
XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, Ffi::Bind().Ret<AnyBuffer>());
XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed,
Ffi::Bind().Ret<AnyBuffer>());
XLA_FFI_DEFINE_HANDLER(kSubtract, Subtract,
Ffi::Bind()
.Arg<BufferR0<DataType::F32>>()
.Arg<BufferR0<DataType::F32>>()
.Ret<BufferR0<DataType::F32>>());
XLA_FFI_DEFINE_HANDLER(kSubtractCst, SubtractCst,
Ffi::Bind()
.Arg<BufferR0<DataType::F32>>()
.Ret<BufferR0<DataType::F32>>()
.Attr<float>("cst"));
// XLA FFI calls can also be stateful.
struct TestFfiState {
static TypeId id;
explicit TestFfiState(int32_t value) : value(value) {}
int32_t value;
};
TypeId TestFfiState::id = {};
static ErrorOr<std::unique_ptr<TestFfiState>> StateInstantiate() {
return std::make_unique<TestFfiState>(42);
}
static Error StateExecute(TestFfiState* state,
Result<BufferR0<DataType::S32>> out) {
*out->typed_data() = state->value;
return Error::Success();
}
XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
Ffi::BindInstantiate());
XLA_FFI_DEFINE_HANDLER(
kStateExecute, StateExecute,
Ffi::Bind().Ctx<State<TestFfiState>>().Ret<BufferR0<DataType::S32>>());
template <typename T>
static auto BindFunction(T* fn) {
return nb::capsule(reinterpret_cast<void*>(fn));
}
template <typename T>
static auto BindTypeId(T* typeId) {
return nb::capsule(reinterpret_cast<void*>(typeId));
}
// Custom calls registration library that exports function pointers to XLA FFI
// handlers to the python users.
NB_MODULE(custom_calls_testlib, m) {
m.def("registrations", []() {
nb::dict dict;
dict["always_fail"] = BindFunction(kAlwaysFail);
dict["always_succeed"] = BindFunction(kAlwaysSucceed);
dict["subtract_f32"] = BindFunction(kSubtract);
dict["subtract_f32_cst"] = BindFunction(kSubtractCst);
nb::dict bundle;
bundle["instantiate"] = BindFunction(kStateInstantiate);
bundle["execute"] = BindFunction(kStateExecute);
dict["stateful"] = bundle;
return dict;
});
m.def("type_ids", []() {
nb::dict type_ids;
type_ids["test_ffi_state"] = BindTypeId(&TestFfiState::id);
return type_ids;
});
}
} // namespace xla::ffi

@ -0,0 +1,47 @@
# Copyright 2024 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for jax_jit helper functions."""
from absl.testing import absltest
from jax.jaxlib.xla import xla_client
jax_jit = xla_client._xla.jax_jit
pytree = xla_client._xla.pytree
pytree_registry = pytree.default_registry()
class JaxJitTest(absltest.TestCase):
def testParseArguments(self):
sig, args = jax_jit.parse_arguments(
positional_args=[1, 2, 3],
keyword_args=[4, 5],
kwnames=("a", "b"),
static_argnums=[0, 2],
static_argnames=["a"],
pytree_registry=pytree_registry,
)
self.assertEqual(args, [2, 5])
self.assertEqual(sig.static_args, [1, 3, 4])
self.assertEqual(sig.static_arg_names, ["a"])
_, leaf = pytree_registry.flatten(0)
self.assertEqual(sig.dynamic_arg_names, ["b"])
self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf])
if __name__ == "__main__":
absltest.main()

144
jaxlib/xla/pytree_test.py Normal file

@ -0,0 +1,144 @@
# Copyright 2023 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import dataclasses
import gc
from absl.testing import absltest
from jax.jaxlib.xla import xla_client
pytree = xla_client._xla.pytree
ExampleType = collections.namedtuple("ExampleType", "field0 field1")
registry = pytree.PyTreeRegistry()
class ExampleType2:
def __init__(self, field0, field1):
self.field0 = field0
self.field1 = field1
def to_iterable(self):
return [self.field0, self.field1], (None,)
def from_iterable(state, values):
del state
return ExampleType2(field0=values[0], field1=values[1])
registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable)
@dataclasses.dataclass
class Custom:
a: int
b: str
registry.register_dataclass_node(Custom, ["a"], ["b"])
class PyTreeTest(absltest.TestCase):
def roundtrip(self, example):
original = registry.flatten(example)[1]
self.assertEqual(
pytree.PyTreeDef.deserialize_using_proto(
registry, original.serialize_using_proto()
),
original,
)
def testSerializeDeserializeNoPickle(self):
o = object()
self.roundtrip(({"a": o, "b": o}, [o, (o, o), None]))
def testSerializeWithFallback(self):
o = object()
with self.assertRaises(ValueError):
self.roundtrip({"a": ExampleType(field0=o, field1=o)})
def testRegisteredType(self):
o = object()
with self.assertRaises(ValueError):
self.roundtrip({"a": ExampleType2(field0=o, field1=o)})
def roundtrip_node_data(self, example):
original = registry.flatten(example)[1]
restored = pytree.PyTreeDef.make_from_node_data_and_children(
registry, original.node_data(), original.children()
)
self.assertEqual(restored, original)
def testRoundtripNodeData(self):
o = object()
self.roundtrip_node_data([o, o, o])
self.roundtrip_node_data((o, o, o))
self.roundtrip_node_data({"a": o, "b": o})
self.roundtrip_node_data({22: o, 88: o})
self.roundtrip_node_data(None)
self.roundtrip_node_data(o)
self.roundtrip_node_data(ExampleType(field0=o, field1=o))
self.roundtrip_node_data(ExampleType2(field0=o, field1=o))
def testCompose(self):
x = registry.flatten(0)[1]
y = registry.flatten((0, 0))[1]
self.assertEqual((x.compose(y)).num_leaves, 2)
def testDataclassMakeFromNodeData(self):
c = Custom(1, "a")
c_leafs, c_tree = registry.flatten(c)
c_tree2 = c_tree.make_from_node_data_and_children(
registry, c_tree.node_data(), c_tree.children()
)
self.assertEqual(c_tree2.unflatten(c_leafs), c)
self.assertEqual(str(c_tree2), str(c_tree))
def testTpTraverse(self):
self.assertContainsSubset(
[
pytree.PyTreeRegistry,
ExampleType2,
ExampleType2.to_iterable,
from_iterable,
],
gc.get_referents(registry),
)
k1 = "k1"
k2 = "k2"
t = ExampleType("a", "b")
_, treedef = registry.flatten([1, {k1: 2, k2: t}, 5, t])
self.assertContainsSubset(
[
pytree.PyTreeDef,
registry,
k1,
k2,
ExampleType,
],
gc.get_referents(treedef),
)
if __name__ == "__main__":
absltest.main()

@ -0,0 +1,257 @@
# Copyright 2023 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import gc
import threading
import time
import weakref
from absl.testing import absltest
from jax.jaxlib.xla import xla_client
class WeakrefLRUCacheTest(absltest.TestCase):
def testMultiThreaded(self):
insert_evs = [threading.Event() for _ in range(2)]
insert_evs_i = 0
class WRKey:
pass
class ClashingKey:
def __eq__(self, other):
return False
def __hash__(self):
return 333 # induce maximal caching problems.
class GilReleasingCacheKey:
def __eq__(self, other):
nonlocal insert_evs_i
if isinstance(other, GilReleasingCacheKey) and insert_evs_i < len(
insert_evs
):
insert_evs[insert_evs_i].set()
insert_evs_i += 1
time.sleep(0.01)
return False
def __hash__(self):
return 333 # induce maximal caching problems.
def CacheFn(obj, gil_releasing_cache_key):
del obj
del gil_releasing_cache_key
return None
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 2048)
wrkey = WRKey()
def Body():
for insert_ev in insert_evs:
insert_ev.wait()
for _ in range(20):
cache(wrkey, ClashingKey())
t = threading.Thread(target=Body)
t.start()
for _ in range(3):
cache(wrkey, GilReleasingCacheKey())
t.join()
def testAnotherMultiThreaded(self):
num_workers = 5
barrier = threading.Barrier(num_workers)
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
class WRKey:
pass
def WorkerAddToCache():
barrier.wait()
wrkey = WRKey()
for i in range(10):
cache(wrkey, i)
def WorkerCleanCache():
barrier.wait()
for _ in range(10):
cache.cache_clear()
workers = [
threading.Thread(target=WorkerAddToCache)
for _ in range(num_workers - 1)
] + [threading.Thread(target=WorkerCleanCache)]
for t in workers:
t.start()
for t in workers:
t.join()
def testKwargsDictOrder(self):
miss_id = 0
class WRKey:
pass
def CacheFn(obj, kwkey1, kwkey2):
del obj, kwkey1, kwkey2
nonlocal miss_id
miss_id += 1
return miss_id
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4)
wrkey = WRKey()
self.assertEqual(cache(wrkey, kwkey1="a", kwkey2="b"), 1)
self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2)
self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1)
def testGetKeys(self):
def CacheFn(obj, arg):
del obj
return arg + "extra"
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4)
class WRKey:
pass
wrkey = WRKey()
self.assertEmpty(cache.cache_keys())
cache(wrkey, "arg1")
cache(wrkey, "arg2")
self.assertLen(cache.cache_keys(), 2)
def testNonWeakreferenceableKey(self):
class NonWRKey:
__slots__ = ()
non_wr_key = NonWRKey()
with self.assertRaises(TypeError):
weakref.ref(non_wr_key)
cache = xla_client.weakref_lru_cache(lambda: None, lambda x: 2048)
for _ in range(100):
with self.assertRaises(TypeError):
cache(non_wr_key)
def testCrashingKey(self):
class WRKey:
pass
class CrashingKey:
# A key that raises exceptions if eq or hash is called.
def __eq__(self, other):
raise ValueError("eq")
def __hash__(self):
raise ValueError("hash")
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
wrkey = WRKey()
with self.assertRaises(ValueError):
for _ in range(100):
cache(wrkey, CrashingKey())
def testPrintingStats(self):
class WRKey:
pass
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
wrkey = WRKey()
for i in range(10):
cache(wrkey, i)
for i in range(5):
cache(wrkey, i)
self.assertEqual(
repr(cache.cache_info()),
"WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)",
)
def testGCKeys(self):
class WRKey:
def __init__(self, x):
self.x = x
def __eq__(self, other):
return self.x == other.x
def __hash__(self):
return hash(self.x)
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
keys = [WRKey(i) for i in range(10)]
for i in range(10):
cache(keys[i], i)
# Delete some keys, to exercise the weakref callback behavior.
del keys[::2]
for key in keys:
cache(key, 7)
def testTpTraverse(self):
class WRKey:
pass
def CacheContextFn():
return None
def CallFn(x, y, *args, **kwargs):
del x, args, kwargs
return y
cache = xla_client.weakref_lru_cache(CacheContextFn, CallFn, 2048)
keys = [WRKey() for _ in range(10)]
values = [str(i) for i in range(10)]
args = [str(i) for i in range(10)]
kwargs = {"a": "b"}
for key, value in zip(keys, values):
cache(key, value, *args, **kwargs)
expected_refs = (
[
CacheContextFn,
CallFn,
xla_client._xla.WeakrefLRUCache,
kwargs,
]
+ [weakref.getweakrefs(key)[0] for key in keys]
+ values
+ args
)
# Can't use assertContainsSubset because it doesn't support kwargs since
# dicts aren't hashable.
for ref in expected_refs:
self.assertIn(ref, gc.get_referents(cache))
if __name__ == "__main__":
absltest.main()

1044
jaxlib/xla/xla_client.py Normal file

File diff suppressed because it is too large Load Diff

322
jaxlib/xla/xla_client.pyi Normal file

@ -0,0 +1,322 @@
# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
import enum
from typing import Any, Union
import numpy
from jaxlib import xla_extension as _xla
from jaxlib.xla_extension import ArrayImpl as ArrayImpl
from jaxlib.xla_extension import AutotuneCacheMode as AutotuneCacheMode
from jaxlib.xla_extension import Client as Client
from jaxlib.xla_extension import CompileOptions as CompileOptions
from jaxlib.xla_extension import Device as Device
from jaxlib.xla_extension import DeviceAssignment as DeviceAssignment
from jaxlib.xla_extension import DeviceList as DeviceList
from jaxlib.xla_extension import DeviceTopology as DeviceTopology
from jaxlib.xla_extension import DistributedRuntimeClient as DistributedRuntimeClient
from jaxlib.xla_extension import FftType as FftType
from jaxlib.xla_extension import Frame as Frame
from jaxlib.xla_extension import GSPMDSharding as GSPMDSharding
from jaxlib.xla_extension import HloSharding as HloSharding
from jaxlib.xla_extension import HostBufferSemantics as HostBufferSemantics
from jaxlib.xla_extension import ifrt_programs as ifrt_programs
from jaxlib.xla_extension import Layout as Layout
from jaxlib.xla_extension import LoadedExecutable as LoadedExecutable
from jaxlib.xla_extension import Memory as Memory
from jaxlib.xla_extension import NamedSharding as NamedSharding
from jaxlib.xla_extension import ops as ops
from jaxlib.xla_extension import OpSharding as OpSharding
from jaxlib.xla_extension import PjRtLayout as PjRtLayout
from jaxlib.xla_extension import PmapSharding as PmapSharding
from jaxlib.xla_extension import PrimitiveType as PrimitiveType
from jaxlib.xla_extension import ArrayCopySemantics as ArrayCopySemantics
from jaxlib.xla_extension import profiler as profiler
from jaxlib.xla_extension import Shape as Shape
from jaxlib.xla_extension import Sharding as Sharding
from jaxlib.xla_extension import SingleDeviceSharding as SingleDeviceSharding
from jaxlib.xla_extension import Traceback as Traceback
from jaxlib.xla_extension import XlaBuilder as XlaBuilder
from jaxlib.xla_extension import XlaComputation as XlaComputation
from jaxlib.xla_extension import XlaOp as XlaOp
_version: int
mlir_api_version: int
bfloat16: type[numpy.generic]
# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
# float4_e2m1fn: type[numpy.generic]
# float8_e3m4: type[numpy.generic]
# float8_e4m3: type[numpy.generic]
# float8_e8m0fnu: type[numpy.generic]
float8_e4m3fn: type[numpy.generic]
float8_e4m3b11fnuz: type[numpy.generic]
float8_e4m3fnuz: type[numpy.generic]
float8_e5m2: type[numpy.generic]
float8_e5m2fnuz: type[numpy.generic]
XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype]
_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]
def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType:
...
def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ...
def heap_profile(client: Client) -> bytes:
...
XlaRuntimeError = _xla.XlaRuntimeError
def make_cpu_client(
asynchronous: bool = ...,
distributed_client: DistributedRuntimeClient | None = ...,
node_id: int = ...,
num_nodes: int = ...,
collectives: _xla.CpuCollectives | None = ...,
num_devices: int | None = ...,
) -> Client:
...
def make_gpu_client(
distributed_client: DistributedRuntimeClient | None = ...,
node_id: int = ...,
num_nodes: int = ...,
platform_name: str | None = ...,
allowed_devices: set[int] | None = ...,
mock: bool | None = ...,
mock_gpu_topology: str | None = ...,
) -> Client:
...
def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None) -> Client:
...
def make_tfrt_tpu_c_api_device_topology(
topology_name: str | None = None, **kwargs
) -> DeviceTopology:
...
def make_c_api_device_topology(c_api: Any, topology_name: str = '', **kwargs) -> DeviceTopology:
...
def get_topology_for_devices(devices: list[Device]) -> DeviceTopology:
...
def make_tpu_client(
library_path: str | None, options: _NameValueMapping | None = None
) -> Client:
...
def make_c_api_client(
plugin_name: str,
options: _NameValueMapping | None = None,
distributed_client: DistributedRuntimeClient | None = None,
) -> Client:
...
def pjrt_plugin_loaded(plugin_name: str) -> bool:
...
def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
...
def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
...
def pjrt_plugin_initialized(plugin_name: str) -> bool:
...
def initialize_pjrt_plugin(plugin_name: str) -> None:
...
def generate_pjrt_gpu_plugin_options() -> _NameValueMapping:
...
class OpMetadata:
def __init__(
self,
op_type: str | None = ...,
op_name: str | None = ...,
source_file: str | None = ...,
source_line: int | None = ...,
):
...
op_type: str | None
op_name: str | None
source_file: str | None
source_line: int | None
class PaddingConfigDimension:
edge_padding_low: int
edge_padding_high: int
interior_padding: int
class PaddingConfig:
dimensions: list[PaddingConfigDimension]
def make_padding_config(
padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]],
) -> PaddingConfig:
...
class PaddingType(enum.Enum):
VALID = 1
SAME = 2
class DotDimensionNumbers:
lhs_contracting_dimensions: list[int]
rhs_contracting_dimensions: list[int]
lhs_batch_dimensions: list[int]
rhs_batch_dimensions: list[int]
def make_dot_dimension_numbers(
dimension_numbers: Union[
DotDimensionNumbers,
tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]],
],
) -> DotDimensionNumbers:
...
class ConvolutionDimensionNumbers:
input_batch_dimension: int
input_feature_dimension: int
input_spatial_dimensions: list[int]
kernel_input_feature_dimension: int
kernel_output_feature_dimension: int
kernel_spatial_dimensions: list[int]
output_batch_dimension: int
output_feature_dimension: int
output_spatial_dimensions: list[int]
def make_convolution_dimension_numbers(
dimension_numbers: Union[
None, ConvolutionDimensionNumbers, tuple[str, str, str]
],
num_spatial_dimensions: int,
) -> ConvolutionDimensionNumbers:
...
class PrecisionConfig:
Precision = _xla.PrecisionConfig_Precision
operand_precision: list[_xla.PrecisionConfig_Precision]
class ResultAccuracy:
mode: _xla.ResultAccuracy_Mode
atol: float
rtol: float
ulps: int
class GatherDimensionNumbers:
offset_dims: list[int]
collapsed_slice_dims: list[int]
start_index_map: list[int]
index_vector_dim: int
operand_batching_dims: list[int]
start_indices_batching_dims: list[int]
class ScatterDimensionNumbers:
update_window_dims: list[int]
inserted_window_dims: list[int]
scatter_dims_to_operand_dims: list[int]
index_vector_dim: int
input_batching_dims: list[int]
scatter_indices_batching_dims: list[int]
class ReplicaGroup:
replica_ids: list[int]
def make_replica_groups(
replica_groups: Sequence[Sequence[int]] | None,
) -> list[ReplicaGroup]:
...
def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...) -> _xla.WeakrefLRUCache:
...
def batched_copy_array_to_devices_with_sharding(
arrays: Sequence[ArrayImpl],
devices: Sequence[list[Device]],
sharding: Sequence[Any],
array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]: ...
def batched_device_put(
aval: Any,
sharding: Any,
shards: Sequence[Any],
devices: list[Device],
committed: bool = ...,
force_copy: bool = ...,
host_buffer_semantics: Any = ...,
) -> ArrayImpl: ...
def reorder_shards(
x: ArrayImpl,
dst_sharding: Any,
array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl: ...
def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ...
def check_and_canonicalize_memory_kind(
memory_kind: str | None, device_list: DeviceList
) -> str | None: ...
def array_result_handler(
aval: Any,
sharding: Any,
committed: bool,
_skip_checks: bool = ...) -> Callable:
...
class CustomCallTargetTraits(enum.IntFlag):
DEFAULT = 0
COMMAND_BUFFER_COMPATIBLE = 1
def register_custom_call_target(
name: str,
fn: Any,
platform: str = ...,
api_version: int = ...,
traits: CustomCallTargetTraits = ...,
) -> None: ...
def register_custom_call_handler(
xla_platform_name: str, handler: Any
) -> None: ...
def custom_call_targets(platform: str) -> dict[str, Any]: ...
def register_custom_type_id(
type_name: str,
type_id: Any,
platform: str = ...,
) -> None: ...
def register_custom_type_id_handler(platform: str, handler: Any) -> None: ...
def encode_inspect_sharding_callback(handler: Any) -> bytes: ...
register_custom_call_partitioner = _xla.register_custom_call_partitioner
register_custom_call_as_batch_partitionable = (
_xla.register_custom_call_as_batch_partitionable
)

@ -0,0 +1,195 @@
# Copyright 2017 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Backend-independent tests for the Python XLA client."""
import unittest
from absl.testing import absltest
import numpy as np
from jax.jaxlib.xla import xla_client
# pylint: disable=g-import-not-at-top
try:
import portpicker
except ImportError:
portpicker = None
# pylint: enable=g-import-not-at-top
ops = xla_client.ops
class ShapeTest(absltest.TestCase):
def testInvalidShapes(self):
with self.assertRaisesRegex(xla_client.XlaRuntimeError, "invalid shape"):
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4])
with self.assertRaisesRegex(
RuntimeError, "layout minor_to_major field contains 1 element.*"):
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3])
with self.assertRaisesRegex(
RuntimeError, "layout minor_to_major field has out-of-bounds value.*"):
xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4],
[1, -1])
class ComputationPrinting(absltest.TestCase):
def ExampleComputation(self):
builder = xla_client.XlaBuilder("acomputation")
p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
p1 = ops.Parameter(builder, 1,
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
x = ops.Mul(p0, p1)
ops.Add(x, x)
return builder.build()
def testComputationToHloText(self):
computation = self.ExampleComputation()
hlo_text = computation.as_hlo_text()
self.assertTrue(hlo_text.startswith("HloModule acomputation"))
def testComputationToHloGraph(self):
computation = self.ExampleComputation()
hlo_dot_graph = computation.as_hlo_dot_graph()
self.assertTrue(hlo_dot_graph.startswith("digraph "))
def testHloModuleToHloText(self):
computation = self.ExampleComputation()
hlo_text = computation.as_hlo_module().to_string()
self.assertTrue(hlo_text.startswith("HloModule acomputation"))
def testHloModuleFromText(self):
hlo_module_text = """HloModule test
add {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry {
p0 = f32[2,3] parameter(0)
start = f32[2,3] all-reduce-start(p0), to_apply=add
ROOT done = f32[2,3] all-reduce-done(start)
}"""
hlo_module = xla_client._xla.hlo_module_from_text(hlo_module_text)
hlo_text = hlo_module.to_string()
self.assertTrue(hlo_text.startswith("HloModule test"))
def testHloModuleToHloGraph(self):
computation = self.ExampleComputation()
hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph(
computation.as_hlo_module())
self.assertTrue(hlo_dot_graph.startswith("digraph "))
class ComputationHashTest(absltest.TestCase):
def testHash(self):
builder0 = xla_client.XlaBuilder("computation0")
p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0)))
p1 = ops.Parameter(builder0, 1,
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
ops.Mul(p0, p1)
computation0 = builder0.build()
builder1 = xla_client.XlaBuilder("computation1")
p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0)))
p1 = ops.Parameter(builder1, 1,
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
ops.Mul(p0, p1)
computation1 = builder1.build()
self.assertEqual(computation0.hash(), computation1.hash())
class AliasTest(absltest.TestCase):
def testSetUpAlias(self):
c = xla_client.XlaBuilder(self.id())
p1 = ops.Parameter(
c, 0,
xla_client.shape_from_pyval(np.array(
1.0, np.float32)).with_major_to_minor_layout_if_absent())
p2 = ops.Parameter(
c, 1,
xla_client.shape_from_pyval(np.array(
1.0, np.float32)).with_major_to_minor_layout_if_absent())
out = ops.Add(p1, p2)
c.setup_alias([], 0, [])
c.build(out)
class ProfilerTest(absltest.TestCase):
def testTraceMe(self):
# TODO(phawkins): These tests just check that the TraceMe context manager
# acts like a context manager and doesn't explode. Ideally we'd check that
# the profiler saw the traceme too.
with xla_client.profiler.TraceMe("test1"):
pass
with xla_client.profiler.TraceMe("test2", foo=123):
pass
with self.assertRaises(ValueError):
with xla_client.profiler.TraceMe("test3"):
raise ValueError("test")
@unittest.skipIf(portpicker is None, "Test requires portpicker")
def testStartServer(self):
port = portpicker.pick_unused_port()
server = xla_client.profiler.start_server(port)
del server
class HloModuleGroupTest(absltest.TestCase):
def testHloModuleGroup(self):
builder0 = xla_client.XlaBuilder("computation0")
p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0)))
p1 = ops.Parameter(builder0, 1,
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
root = ops.Mul(p0, p1)
computation0 = builder0.build(root)
m = computation0.get_hlo_module()
mg_name = "test_module_group"
mg = xla_client._xla.HloModuleGroup(mg_name, [m])
self.assertEqual(mg.name, mg_name)
modules = mg.to_modules()
self.assertLen(modules, 1)
self.assertEqual(m.to_string(), modules[0].to_string())
class RunHloPassTest(absltest.TestCase):
def testHloDCE(self):
b = xla_client.XlaBuilder("acomputation")
p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0)))
p1 = ops.Parameter(b, 1,
xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
root = ops.Mul(p0, p1)
# Dead instructions
p2 = ops.Parameter(b, 2, xla_client.shape_from_pyval(np.float32(0)))
ops.Add(p2, p2)
hlo_module = b.build(root).get_hlo_module()
self.assertTrue(xla_client._xla.HloDCE().run(hlo_module))
if __name__ == "__main__":
absltest.main()

File diff suppressed because it is too large Load Diff

@ -23,8 +23,11 @@ module = [
"jax.experimental.jax2tf.tests.back_compat_testdata",
"jax.experimental.jax2tf.tests.flax_models",
"jax_cuda12_plugin.*",
"jaxlib.*",
"jaxlib.cpu_feature_guard",
"jaxlib.cuda.*",
"jaxlib.mlir.*",
"jaxlib.utils",
"jaxlib.xla_extension.utils",
"jraph.*",
"libtpu.*",
"matplotlib.*",