fast dispatch for functions over typed PRNG key arrays

Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.
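
For example (a minimal sketch, not part of this change: typed key arrays are
not registered in the default pytree registry, so they are treated as leaves):

    import jax

    key = jax.random.key(0)  # new-style typed RNG key
    # Under the default registry the key array is a single pytree leaf;
    # its underlying uint32 buffer is not exposed.
    assert len(jax.tree_util.tree_leaves(key)) == 1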

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.
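
As a sketch of the intended contrast (internal APIs introduced by this change;
`jax._src.tree_util` is not public and is subject to change):

    import jax
    from jax._src import tree_util  # internal module

    key = jax.random.key(0)

    # Default registry: the typed key itself is the only leaf.
    leaves, treedef = tree_util.tree_flatten(key)

    # Dispatch registry: flattening unravels the key into its underlying
    # uint32 base array, so the C++ fastpath sees only ordinary arrays.
    dispatch_leaves, dispatch_treedef = tree_util.dispatch_registry.flatten(
        key, None)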

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
Author:    Roy Frostig
Date:      2023-09-13 09:43:14 -07:00
Committed: jax authors
Parent:    84951288bd
Commit:    6abefa1977
7 changed files with 179 additions and 10 deletions

View File: benchmarks/api_benchmark.py

@@ -48,8 +48,9 @@ def trivial_dispatch_typed_key(state):
   _bench_trivial_dispatch(state, key)
 
 
-def _bench_nontrivial_dispatch(state, key):
-  f = jax.jit(lambda key: jax.random.split(key))
+def _bench_nontrivial_dispatch(state, key, do_split=False):
+  key_op = jax.random.split if do_split else jax.random.normal
+  f = jax.jit(lambda key: key_op(key))
   _ = f(key)
   while state:
     f(key)
@@ -60,14 +61,66 @@ def _bench_nontrivial_dispatch(state, key):
 @google_benchmark.register
 def nontrivial_dispatch_raw_key(state):
   key = jax.random.PRNGKey(0)
   _assert_raw_key(key)
-  _bench_nontrivial_dispatch(state, key)
+  _bench_nontrivial_dispatch(state, key, do_split=False)
 
 
 @google_benchmark.register
 def nontrivial_dispatch_typed_key(state):
   key = jax.random.key(0)
   _assert_typed_key(key)
-  _bench_nontrivial_dispatch(state, key)
+  _bench_nontrivial_dispatch(state, key, do_split=False)
+
+
+@google_benchmark.register
+def nontrivial_dispatch_raw_key_split(state):
+  key = jax.random.PRNGKey(0)
+  _assert_raw_key(key)
+  _bench_nontrivial_dispatch(state, key, do_split=True)
+
+
+@google_benchmark.register
+def nontrivial_dispatch_typed_key_split(state):
+  key = jax.random.key(0)
+  _assert_typed_key(key)
+  _bench_nontrivial_dispatch(state, key, do_split=True)
+
+
+def _bench_custom_container(state, key):
+
+  @jax.tree_util.register_pytree_node_class
+  class A:
+    def __init__(self, x):
+      self.x = x
+
+    def tree_flatten(self):
+      return (self.x,), None
+
+    @classmethod
+    def tree_unflatten(cls, aux, children):
+      x, = children
+      return cls(x)
+
+  f = jax.jit(
+      lambda key, a: jax.random.normal(key) + a.x)
+  a = A(5.)
+  _ = f(key, a)
+  while state:
+    f(key, a)
+  f(key, a).block_until_ready()
+
+
+@google_benchmark.register
+def custom_container_raw_key(state):
+  key = jax.random.PRNGKey(0)
+  _assert_raw_key(key)
+  _bench_custom_container(state, key)
+
+
+@google_benchmark.register
+def custom_container_typed_key(state):
+  key = jax.random.key(0)
+  _assert_typed_key(key)
+  _bench_custom_container(state, key)
 
 
 if __name__ == "__main__":

View File: jax/_src/interpreters/pxla.py

@@ -2707,6 +2707,14 @@ class MeshExecutableFastpathData(NamedTuple):
   kept_var_bitvec: Iterable[bool]
 
 
+def reflatten_outputs_for_dispatch(out_tree, out_flat):
+  # We arrive at dispatch having flattened according to the default
+  # pytree registry, but we want to re-flatten according to our
+  # dispatch-specific registry.
+  out_unflat = tree_util.tree_unflatten(out_tree, out_flat)
+  return tree_util.dispatch_registry.flatten(out_unflat, None)
+
+
 class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
       "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
@@ -2766,6 +2774,8 @@ class MeshExecutable(stages.XlaExecutable):
     def aot_cache_miss(*args, **kwargs):
       params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
       outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
+      out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
+          out_tree, out_flat)
       use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
 
       if use_fastpath:
@@ -2774,14 +2784,14 @@ class MeshExecutable(stages.XlaExecutable):
         kept_var_bitvec = [i in self._kept_var_idx
                            for i in range(len(args_flat))]
         fastpath_data = MeshExecutableFastpathData(
-            self.xla_executable, out_tree, self._in_shardings,
+            self.xla_executable, out_tree_dispatch, self._in_shardings,
             self._out_shardings, out_avals, out_committed, kept_var_bitvec)
       else:
         fastpath_data = None
       return outs, fastpath_data
 
     return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-                        tree_util.default_registry)
+                        tree_util.dispatch_registry)
 
   def create_cpp_call_for_apply_primitive(self, out_tree):
     # unsafe_call can be different than ExecuteReplicated for pathways.
@@ -2793,6 +2803,8 @@ class MeshExecutable(stages.XlaExecutable):
     def apply_primitive_cache_miss(*args):
       out_flat = self.unsafe_call(*args)
       outs = tree_util.tree_unflatten(out_tree, out_flat)
+      out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
+          out_tree, out_flat)
       use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
 
       if use_fastpath:
@@ -2801,14 +2813,14 @@ class MeshExecutable(stages.XlaExecutable):
         kept_var_bitvec = [i in self._kept_var_idx
                            for i in range(len(args))]
         fastpath_data = MeshExecutableFastpathData(
-            self.xla_executable, out_tree, self._in_shardings,
+            self.xla_executable, out_tree_dispatch, self._in_shardings,
             self._out_shardings, out_avals, out_committed, kept_var_bitvec)
       else:
         fastpath_data = None
       return outs, fastpath_data
 
     return xc._xla.pjit(self.unsafe_call.name, None, apply_primitive_cache_miss,
-                        [], [], [], tree_util.default_registry)
+                        [], [], [], tree_util.dispatch_registry)
 
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
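
To illustrate what the `reflatten_outputs_for_dispatch` helper above achieves,
here is a hedged sketch (internal APIs; the `(2,)` uint32 shape assumes the
default threefry PRNG implementation):

    import jax
    from jax._src import tree_util  # internal module

    key = jax.random.key(0)
    out_flat, out_tree = tree_util.tree_flatten((key, 1.0))
    # out_flat == [key, 1.0]: the typed key is a leaf under the default
    # registry, exactly as jit's tracing machinery sees it.

    outs = tree_util.tree_unflatten(out_tree, out_flat)
    dispatch_flat, _ = tree_util.dispatch_registry.flatten(outs, None)
    # dispatch_flat replaces the typed key with its uint32 base array
    # (shape (2,) under threefry), alongside the float 1.0, so every
    # leaf handed to the C++ fastpath is an ordinary array.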

View File: jax/_src/pjit.py

@@ -192,6 +192,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
 
 
 def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
+  out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
+
   use_fastpath = (
       executable is not None and
       isinstance(executable, pxla.MeshExecutable) and
@@ -259,7 +261,7 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
   cpp_pjit_f = xc._xla.pjit(
       getattr(fun, "__name__", "<unnamed function>"),
       fun, cache_miss, static_argnums, static_argnames,
-      donate_argnums, tree_util.default_registry,
+      donate_argnums, tree_util.dispatch_registry,
       _get_cpp_global_cache(pjit_has_explicit_sharding))
 
   cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
@@ -1210,7 +1212,7 @@ def _pjit_call_impl(*args, jaxpr,
   has_explicit_sharding = _pjit_explicit_sharding(
       in_shardings, out_shardings, None, None)
   return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
-                      tree_util.default_registry,
+                      tree_util.dispatch_registry,
                       _get_cpp_global_cache(has_explicit_sharding))(*args)
 
 pjit_p.def_impl(_pjit_call_impl)

View File: jax/_src/prng.py

@@ -36,6 +36,7 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import pretty_printer as pp
 from jax._src import sharding_specs
+from jax._src import tree_util as tree_util_internal
 from jax._src import typing
 from jax._src.api import jit, vmap
 from jax._src.config import config
@@ -400,6 +401,16 @@ basearray.Array.register(PRNGKeyArrayImpl)
 ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like  # type: ignore[has-type]
 
 
+def prngkeyarrayimpl_flatten(x):
+  return (x._base_array,), x.impl
+
+def prngkeyarrayimpl_unflatten(impl, children):
+  base_array, = children
+  return PRNGKeyArrayImpl(impl, base_array)
+
+tree_util_internal.dispatch_registry.register_node(
+    PRNGKeyArrayImpl, prngkeyarrayimpl_flatten, prngkeyarrayimpl_unflatten)
+
+
 # TODO(frostig): remove, rerouting callers directly to random_seed
 def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:

View File: jax/_src/test_util.py

@@ -47,6 +47,7 @@ from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes as _dtypes
 from jax._src import monitoring
+from jax._src import stages
 from jax._src.interpreters import pxla
 from jax._src.config import (bool_env, config,
                              raise_persistent_cache_errors,
@@ -240,6 +241,22 @@ def count_pjit_cpp_cache_miss():
     pjit_lib._pjit_lower = original_pjit_lower
 
 
+@contextmanager
+def count_aot_jit_cpp_cache_miss():
+  original_call = stages.Compiled.call
+  count = [0]
+
+  def compiled_call_count(*args, **kwargs):
+    count[0] += 1
+    return original_call(*args, **kwargs)
+
+  stages.Compiled.call = compiled_call_count
+  try:
+    yield count
+  finally:
+    stages.Compiled.call = original_call
+
+
 @contextmanager
 def count_jit_and_pmap_compiles():
   # No need to clear any caches since we generally jit and pmap fresh callables

View File: jax/_src/tree_util.py

@@ -44,6 +44,22 @@ default_registry = pytree.default_registry()
 default_registry.__module__ = __name__
 default_registry.__name__ = "default_registry"
 
+# A special, internal pytree registry that includes everything in
+# `default_registry`, plus internal Python-defined types that we want
+# to teach the fast dispatch path ("C++ dispatch") how to flatten and
+# unflatten. A key example is PRNG key arrays, which are currently a
+# Python-defined class (in `jax._src.prng`). These ought to be a leaf
+# node everywhere in the system (e.g. in Jaxpr), but we want to unpack
+# and repack them across the fast dispatch boundary. If we were to
+# skip registering such types here, the fast dispatch path would not
+# know how to handle them as arguments. It would instead always
+# indicate a "cache miss" and dispatch on the slow path.
+dispatch_registry = pytree.PyTreeRegistry(
+    enable_none=True, enable_tuple=True, enable_namedtuple=True,
+    enable_list=True, enable_dict=True)
+dispatch_registry.__module__ = __name__
+dispatch_registry.__name__ = "dispatch_registry"
+
 def tree_flatten(tree: Any,
                  is_leaf: Callable[[Any], bool] | None = None
                  ) -> tuple[list[Leaf], PyTreeDef]:
@@ -162,6 +178,7 @@ def register_pytree_node(nodetype: type[T],
     ``nodetype``.
   """
   default_registry.register_node(nodetype, flatten_func, unflatten_func)
+  dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
   _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)

View File: tests/random_test.py

@@ -1805,6 +1805,63 @@ class KeyArrayTest(jtu.JaxTestCase):
     self.assertIsInstance(k1, random.KeyArray)
     self.assertIsInstance(k2, random.KeyArray)
 
+  def test_cpp_dispatch_normal(self):
+    # Ensure we stay on the C++ dispatch path when calling a jitted
+    # function with a key array as an argument.
+    @jax.jit
+    def f(key):
+      return jax.random.normal(key)
+
+    key = self.make_keys()
+    with jtu.count_pjit_cpp_cache_miss() as count:
+      f(key).block_until_ready()
+      f(key).block_until_ready()
+    self.assertEqual(count[0], 1)
+
+  def test_cpp_dispatch_split(self):
+    # Ensure we stay on the C++ dispatch path when calling a jitted
+    # function with key arrays as inputs and as outputs.
+    @jax.jit
+    def f(key):
+      return jax.random.split(key)
+
+    key = self.make_keys()
+    with jtu.count_pjit_cpp_cache_miss() as count:
+      f(key).block_until_ready()
+      f(key).block_until_ready()
+    self.assertEqual(count[0], 1)
+
+  def test_cpp_dispatch_aot_normal(self):
+    # Ensure we stay on the C++ dispatch path when calling an
+    # AOT-compiled function with a key array as an argument.
+    key = self.make_keys()
+    f = jax.jit(lambda key: jax.random.normal(key)).lower(key).compile()
+    with jtu.count_aot_jit_cpp_cache_miss() as count:
+      f(key).block_until_ready()
+      f(key).block_until_ready()
+    self.assertEqual(count[0], 1)
+
+  def test_cpp_dispatch_aot_split(self):
+    # Ensure we stay on the C++ dispatch path when calling an
+    # AOT-compiled function with key arrays as inputs and as
+    # outputs.
+    key = self.make_keys()
+    f = jax.jit(lambda key: jax.random.split(key)).lower(key).compile()
+    with jtu.count_aot_jit_cpp_cache_miss() as count:
+      f(key).block_until_ready()
+      f(key).block_until_ready()
+    self.assertEqual(count[0], 1)
+
   # -- prng primitives
 
   def test_random_wrap_vmap(self):