diff --git a/benchmarks/random_benchmark.py b/benchmarks/random_benchmark.py
index 546550f32..5e730d1b5 100644
--- a/benchmarks/random_benchmark.py
+++ b/benchmarks/random_benchmark.py
@@ -48,8 +48,9 @@ def trivial_dispatch_typed_key(state):
   _bench_trivial_dispatch(state, key)
 
 
-def _bench_nontrivial_dispatch(state, key):
-  f = jax.jit(lambda key: jax.random.split(key))
+def _bench_nontrivial_dispatch(state, key, do_split=False):
+  key_op = jax.random.split if do_split else jax.random.normal
+  f = jax.jit(lambda key: key_op(key))
   _ = f(key)
   while state:
     f(key)
@@ -60,14 +61,66 @@ def _bench_nontrivial_dispatch(state, key):
 def nontrivial_dispatch_raw_key(state):
   key = jax.random.PRNGKey(0)
   _assert_raw_key(key)
-  _bench_nontrivial_dispatch(state, key)
+  _bench_nontrivial_dispatch(state, key, do_split=False)
 
 
 @google_benchmark.register
 def nontrivial_dispatch_typed_key(state):
   key = jax.random.key(0)
   _assert_typed_key(key)
-  _bench_nontrivial_dispatch(state, key)
+  _bench_nontrivial_dispatch(state, key, do_split=False)
+
+
+@google_benchmark.register
+def nontrivial_dispatch_raw_key_split(state):
+  key = jax.random.PRNGKey(0)
+  _assert_raw_key(key)
+  _bench_nontrivial_dispatch(state, key, do_split=True)
+
+
+@google_benchmark.register
+def nontrivial_dispatch_typed_key_split(state):
+  key = jax.random.key(0)
+  _assert_typed_key(key)
+  _bench_nontrivial_dispatch(state, key, do_split=True)
+
+
+
+def _bench_custom_container(state, key):
+  @jax.tree_util.register_pytree_node_class
+  class A:
+    def __init__(self, x):
+      self.x = x
+
+    def tree_flatten(self):
+      return (self.x,), None
+
+    @classmethod
+    def tree_unflatten(cls, aux, children):
+      x, = children
+      return cls(x)
+
+  f = jax.jit(
+      lambda key, a: jax.random.normal(key) + a.x)
+  a = A(5.)
+  _ = f(key, a)
+  while state:
+    f(key, a)
+  f(key, a).block_until_ready()
+
+
+@google_benchmark.register
+def custom_container_raw_key(state):
+  key = jax.random.PRNGKey(0)
+  _assert_raw_key(key)
+  _bench_custom_container(state, key)
+
+
+@google_benchmark.register
+def custom_container_typed_key(state):
+  key = jax.random.key(0)
+  _assert_typed_key(key)
+  _bench_custom_container(state, key)
 
 
 if __name__ == "__main__":
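For orientation, a minimal standalone sketch (not from the patch) of the call pattern the benchmarks above time: the first call to a jitted function traces, compiles, and fills the dispatch caches; every later call with the same input types should be served from C++, so the benchmark loop measures steady-state dispatch overhead.

    import jax

    f = jax.jit(lambda key: jax.random.normal(key))
    key = jax.random.key(0)  # typed key array; jax.random.PRNGKey(0) gives a raw one
    _ = f(key)               # warm-up: trace, compile, populate dispatch caches
    out = f(key)             # steady state: ideally handled entirely in C++
    out.block_until_ready()  # results are asynchronous; sync before timing stops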
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 14dfe46a8..64898cb25 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2707,6 +2707,14 @@ class MeshExecutableFastpathData(NamedTuple):
   kept_var_bitvec: Iterable[bool]
 
 
+def reflatten_outputs_for_dispatch(out_tree, out_flat):
+  # We arrive at dispatch having flattened according to the default
+  # pytree registry, but we want to re-flatten according to our
+  # dispatch-specific registry.
+  out_unflat = tree_util.tree_unflatten(out_tree, out_flat)
+  return tree_util.dispatch_registry.flatten(out_unflat, None)
+
+
 class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
       "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
@@ -2766,6 +2774,8 @@ class MeshExecutable(stages.XlaExecutable):
     def aot_cache_miss(*args, **kwargs):
       params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
       outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
+      out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
+          out_tree, out_flat)
       use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
 
       if use_fastpath:
@@ -2774,14 +2784,14 @@ class MeshExecutable(stages.XlaExecutable):
         kept_var_bitvec = [i in self._kept_var_idx
                            for i in range(len(args_flat))]
         fastpath_data = MeshExecutableFastpathData(
-            self.xla_executable, out_tree, self._in_shardings,
+            self.xla_executable, out_tree_dispatch, self._in_shardings,
            self._out_shardings, out_avals, out_committed, kept_var_bitvec)
       else:
         fastpath_data = None
       return outs, fastpath_data
 
     return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-                        tree_util.default_registry)
+                        tree_util.dispatch_registry)
 
   def create_cpp_call_for_apply_primitive(self, out_tree):
     # unsafe_call can be different than ExecuteReplicated for pathways.
@@ -2793,6 +2803,8 @@ class MeshExecutable(stages.XlaExecutable):
     def apply_primitive_cache_miss(*args):
       out_flat = self.unsafe_call(*args)
       outs = tree_util.tree_unflatten(out_tree, out_flat)
+      out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
+          out_tree, out_flat)
       use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
 
       if use_fastpath:
@@ -2801,14 +2813,14 @@ class MeshExecutable(stages.XlaExecutable):
         kept_var_bitvec = [i in self._kept_var_idx
                            for i in range(len(args))]
         fastpath_data = MeshExecutableFastpathData(
-            self.xla_executable, out_tree, self._in_shardings,
+            self.xla_executable, out_tree_dispatch, self._in_shardings,
            self._out_shardings, out_avals, out_committed, kept_var_bitvec)
       else:
         fastpath_data = None
       return outs, fastpath_data
 
     return xc._xla.pjit(self.unsafe_call.name, None, apply_primitive_cache_miss,
-                        [], [], [], tree_util.default_registry)
+                        [], [], [], tree_util.dispatch_registry)
 
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 379fb0eb8..06aaf05a6 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -192,6 +192,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
 
 
 def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
+  out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
+
   use_fastpath = (
       executable is not None and
       isinstance(executable, pxla.MeshExecutable) and
@@ -259,7 +261,7 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
   cpp_pjit_f = xc._xla.pjit(
       getattr(fun, "__name__", "<unnamed function>"), fun, cache_miss,
       static_argnums, static_argnames,
-      donate_argnums, tree_util.default_registry,
+      donate_argnums, tree_util.dispatch_registry,
       _get_cpp_global_cache(pjit_has_explicit_sharding))
 
   cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
@@ -1210,7 +1212,7 @@ def _pjit_call_impl(*args, jaxpr,
   has_explicit_sharding = _pjit_explicit_sharding(
       in_shardings, out_shardings, None, None)
   return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
-                      tree_util.default_registry,
+                      tree_util.dispatch_registry,
                       _get_cpp_global_cache(has_explicit_sharding))(*args)
 
 pjit_p.def_impl(_pjit_call_impl)
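The re-flattening step in _get_fastpath_data and in the MeshExecutable cache-miss hooks exists because the same outputs flatten differently under the two registries. A sketch of the contrast, using internal jax._src names introduced by this patch purely for illustration (it assumes the patch is applied):

    import jax
    from jax._src import tree_util as tree_util_internal

    key = jax.random.key(0)

    # Default registry: a typed key array is an opaque leaf.
    leaves, _ = jax.tree_util.tree_flatten(key)  # leaves == [key]

    # Dispatch registry: the key unpacks into its uint32 base array,
    # the representation the C++ fastpath knows how to cache and
    # reassemble into a PRNGKeyArrayImpl on the way out.
    flat, treedef = tree_util_internal.dispatch_registry.flatten(key, None)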
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 9d8f7bdb1..3b97c7a5d 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -36,6 +36,7 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import pretty_printer as pp
 from jax._src import sharding_specs
+from jax._src import tree_util as tree_util_internal
 from jax._src import typing
 from jax._src.api import jit, vmap
 from jax._src.config import config
@@ -400,6 +401,16 @@ basearray.Array.register(PRNGKeyArrayImpl)
 ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like  # type: ignore[has-type]
 
 
+def prngkeyarrayimpl_flatten(x):
+  return (x._base_array,), x.impl
+
+def prngkeyarrayimpl_unflatten(impl, children):
+  base_array, = children
+  return PRNGKeyArrayImpl(impl, base_array)
+
+tree_util_internal.dispatch_registry.register_node(
+    PRNGKeyArrayImpl, prngkeyarrayimpl_flatten, prngkeyarrayimpl_unflatten)
+
 
 # TODO(frostig): remove, rerouting callers directly to random_seed
 def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 4f370cb04..7314ed554 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -47,6 +47,7 @@ from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes as _dtypes
 from jax._src import monitoring
+from jax._src import stages
 from jax._src.interpreters import pxla
 from jax._src.config import (bool_env, config,
                              raise_persistent_cache_errors,
@@ -240,6 +241,22 @@ def count_pjit_cpp_cache_miss():
     pjit_lib._pjit_lower = original_pjit_lower
 
 
+@contextmanager
+def count_aot_jit_cpp_cache_miss():
+  original_call = stages.Compiled.call
+  count = [0]
+
+  def compiled_call_count(*args, **kwargs):
+    count[0] += 1
+    return original_call(*args, **kwargs)
+
+  stages.Compiled.call = compiled_call_count
+  try:
+    yield count
+  finally:
+    stages.Compiled.call = original_call
+
+
 @contextmanager
 def count_jit_and_pmap_compiles():
   # No need to clear any caches since we generally jit and pmap fresh callables
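The new counter is meant to be used like count_pjit_cpp_cache_miss: each entry into stages.Compiled.call is one trip through the Python dispatch path, so a count of 1 across two identical calls means the second call stayed on the C++ fastpath. A usage sketch mirroring the tests added below:

    import jax
    import jax._src.test_util as jtu

    key = jax.random.key(0)
    f = jax.jit(lambda key: jax.random.normal(key)).lower(key).compile()

    with jtu.count_aot_jit_cpp_cache_miss() as count:
      f(key).block_until_ready()  # first call runs through Compiled.call
      f(key).block_until_ready()  # second call should not re-enter Python
    assert count[0] == 1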
""" default_registry.register_node(nodetype, flatten_func, unflatten_func) + dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) diff --git a/tests/random_test.py b/tests/random_test.py index de8daae5f..85267e2ee 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1805,6 +1805,63 @@ class KeyArrayTest(jtu.JaxTestCase): self.assertIsInstance(k1, random.KeyArray) self.assertIsInstance(k2, random.KeyArray) + def test_cpp_dispatch_normal(self): + # Ensure we stay on the C++ dispatch path when calling a jitted + # function with a key array as an argument. + + @jax.jit + def f(key): + return jax.random.normal(key) + + key = self.make_keys() + with jtu.count_pjit_cpp_cache_miss() as count: + f(key).block_until_ready() + f(key).block_until_ready() + + self.assertEqual(count[0], 1) + + def test_cpp_dispatch_split(self): + # Ensure we stay on the C++ dispatch path when calling a jitted + # function with a key arrays as inputs and as outputs. + + @jax.jit + def f(key): + return jax.random.split(key) + + key = self.make_keys() + with jtu.count_pjit_cpp_cache_miss() as count: + f(key).block_until_ready() + f(key).block_until_ready() + + self.assertEqual(count[0], 1) + + def test_cpp_dispatch_aot_normal(self): + # Ensure we stay on the C++ dispatch path when calling an + # AOT-compiled function with a key array as an argument. + + key = self.make_keys() + f = jax.jit(lambda key: jax.random.normal(key)).lower(key).compile() + + with jtu.count_aot_jit_cpp_cache_miss() as count: + f(key).block_until_ready() + f(key).block_until_ready() + + self.assertEqual(count[0], 1) + + def test_cpp_dispatch_aot_split(self): + # Ensure we stay on the C++ dispatch path when calling an + # AOT-compiled function with a key arrays as inputs and as + # outputs. + + key = self.make_keys() + f = jax.jit(lambda key: jax.random.split(key)).lower(key).compile() + + with jtu.count_aot_jit_cpp_cache_miss() as count: + f(key).block_until_ready() + f(key).block_until_ready() + + self.assertEqual(count[0], 1) + # -- prng primitives def test_random_wrap_vmap(self):