mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed) RNG key arrays, but it would always do so off of the fast (C++-based) dispatch path. In other words, switching from old-style `uint32` RNG keys to new-style keys would regress dispatch times. With this change, dispatch happens on the fast path again and performance regressions ought to be minimal. We currently maintain only one pytree registry, for all registered pytree node types. We want RNG key arrays to also be treated as pytree leaves everywhere *except* during dispatch. In other words: we want operations on (typed) RNG key arrays to appear in Jaxpr, but we want to unravel those arrays into their underlying `uint32` arrays only during dispatch. To do this, we add a new internal pytree registry that dispatch respects uniquely. This registry includes all items in the default registry, but also the RNG key array type. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 565077758
This commit is contained in:
parent
84951288bd
commit
6abefa1977
@ -48,8 +48,9 @@ def trivial_dispatch_typed_key(state):
|
||||
_bench_trivial_dispatch(state, key)
|
||||
|
||||
|
||||
def _bench_nontrivial_dispatch(state, key):
|
||||
f = jax.jit(lambda key: jax.random.split(key))
|
||||
def _bench_nontrivial_dispatch(state, key, do_split=False):
|
||||
key_op = jax.random.split if do_split else jax.random.normal
|
||||
f = jax.jit(lambda key: key_op(key))
|
||||
_ = f(key)
|
||||
while state:
|
||||
f(key)
|
||||
@ -60,14 +61,66 @@ def _bench_nontrivial_dispatch(state, key):
|
||||
def nontrivial_dispatch_raw_key(state):
|
||||
key = jax.random.PRNGKey(0)
|
||||
_assert_raw_key(key)
|
||||
_bench_nontrivial_dispatch(state, key)
|
||||
_bench_nontrivial_dispatch(state, key, do_split=False)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def nontrivial_dispatch_typed_key(state):
|
||||
key = jax.random.key(0)
|
||||
_assert_typed_key(key)
|
||||
_bench_nontrivial_dispatch(state, key)
|
||||
_bench_nontrivial_dispatch(state, key, do_split=False)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def nontrivial_dispatch_raw_key_split(state):
|
||||
key = jax.random.PRNGKey(0)
|
||||
_assert_raw_key(key)
|
||||
_bench_nontrivial_dispatch(state, key, do_split=True)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def nontrivial_dispatch_typed_key_split(state):
|
||||
key = jax.random.key(0)
|
||||
_assert_typed_key(key)
|
||||
_bench_nontrivial_dispatch(state, key, do_split=True)
|
||||
|
||||
|
||||
|
||||
def _bench_custom_container(state, key):
|
||||
@jax.tree_util.register_pytree_node_class
|
||||
class A:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def tree_flatten(self):
|
||||
return (self.x,), None
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux, children):
|
||||
x, = children
|
||||
return cls(x)
|
||||
|
||||
f = jax.jit(
|
||||
lambda key, a: jax.random.normal(key) + a.x)
|
||||
a = A(5.)
|
||||
_ = f(key, a)
|
||||
while state:
|
||||
f(key, a)
|
||||
f(key, a).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def custom_container_raw_key(state):
|
||||
key = jax.random.PRNGKey(0)
|
||||
_assert_raw_key(key)
|
||||
_bench_custom_container(state, key)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def custom_container_typed_key(state):
|
||||
key = jax.random.key(0)
|
||||
_assert_typed_key(key)
|
||||
_bench_custom_container(state, key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2707,6 +2707,14 @@ class MeshExecutableFastpathData(NamedTuple):
|
||||
kept_var_bitvec: Iterable[bool]
|
||||
|
||||
|
||||
def reflatten_outputs_for_dispatch(out_tree, out_flat):
|
||||
# We arrive at dispatch having flattened according to the default
|
||||
# pytree registry, but we want to re-flatten according to our
|
||||
# dispatch-specific registry.
|
||||
out_unflat = tree_util.tree_unflatten(out_tree, out_flat)
|
||||
return tree_util.dispatch_registry.flatten(out_unflat, None)
|
||||
|
||||
|
||||
class MeshExecutable(stages.XlaExecutable):
|
||||
__slots__ = [
|
||||
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
|
||||
@ -2766,6 +2774,8 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
def aot_cache_miss(*args, **kwargs):
|
||||
params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
|
||||
outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
|
||||
out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
|
||||
out_tree, out_flat)
|
||||
use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
|
||||
|
||||
if use_fastpath:
|
||||
@ -2774,14 +2784,14 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
kept_var_bitvec = [i in self._kept_var_idx
|
||||
for i in range(len(args_flat))]
|
||||
fastpath_data = MeshExecutableFastpathData(
|
||||
self.xla_executable, out_tree, self._in_shardings,
|
||||
self.xla_executable, out_tree_dispatch, self._in_shardings,
|
||||
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
|
||||
else:
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data
|
||||
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.default_registry)
|
||||
tree_util.dispatch_registry)
|
||||
|
||||
def create_cpp_call_for_apply_primitive(self, out_tree):
|
||||
# unsafe_call can be different than ExecuteReplicated for pathways.
|
||||
@ -2793,6 +2803,8 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
def apply_primitive_cache_miss(*args):
|
||||
out_flat = self.unsafe_call(*args)
|
||||
outs = tree_util.tree_unflatten(out_tree, out_flat)
|
||||
out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
|
||||
out_tree, out_flat)
|
||||
use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
|
||||
|
||||
if use_fastpath:
|
||||
@ -2801,14 +2813,14 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
kept_var_bitvec = [i in self._kept_var_idx
|
||||
for i in range(len(args))]
|
||||
fastpath_data = MeshExecutableFastpathData(
|
||||
self.xla_executable, out_tree, self._in_shardings,
|
||||
self.xla_executable, out_tree_dispatch, self._in_shardings,
|
||||
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
|
||||
else:
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data
|
||||
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, apply_primitive_cache_miss,
|
||||
[], [], [], tree_util.default_registry)
|
||||
[], [], [], tree_util.dispatch_registry)
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
|
@ -192,6 +192,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
|
||||
|
||||
|
||||
def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
|
||||
out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
|
||||
|
||||
use_fastpath = (
|
||||
executable is not None and
|
||||
isinstance(executable, pxla.MeshExecutable) and
|
||||
@ -259,7 +261,7 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun, cache_miss, static_argnums, static_argnames,
|
||||
donate_argnums, tree_util.default_registry,
|
||||
donate_argnums, tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding))
|
||||
|
||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||
@ -1210,7 +1212,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
has_explicit_sharding = _pjit_explicit_sharding(
|
||||
in_shardings, out_shardings, None, None)
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
tree_util.default_registry,
|
||||
tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
|
||||
pjit_p.def_impl(_pjit_call_impl)
|
||||
|
@ -36,6 +36,7 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import tree_util as tree_util_internal
|
||||
from jax._src import typing
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.config import config
|
||||
@ -400,6 +401,16 @@ basearray.Array.register(PRNGKeyArrayImpl)
|
||||
|
||||
ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type]
|
||||
|
||||
def prngkeyarrayimpl_flatten(x):
|
||||
return (x._base_array,), x.impl
|
||||
|
||||
def prngkeyarrayimpl_unflatten(impl, children):
|
||||
base_array, = children
|
||||
return PRNGKeyArrayImpl(impl, base_array)
|
||||
|
||||
tree_util_internal.dispatch_registry.register_node(
|
||||
PRNGKeyArrayImpl, prngkeyarrayimpl_flatten, prngkeyarrayimpl_unflatten)
|
||||
|
||||
|
||||
# TODO(frostig): remove, rerouting callers directly to random_seed
|
||||
def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:
|
||||
|
@ -47,6 +47,7 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import monitoring
|
||||
from jax._src import stages
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.config import (bool_env, config,
|
||||
raise_persistent_cache_errors,
|
||||
@ -240,6 +241,22 @@ def count_pjit_cpp_cache_miss():
|
||||
pjit_lib._pjit_lower = original_pjit_lower
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_aot_jit_cpp_cache_miss():
|
||||
original_call = stages.Compiled.call
|
||||
count = [0]
|
||||
|
||||
def compiled_call_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_call(*args, **kwargs)
|
||||
|
||||
stages.Compiled.call = compiled_call_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
stages.Compiled.call = original_call
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_jit_and_pmap_compiles():
|
||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||
|
@ -44,6 +44,22 @@ default_registry = pytree.default_registry()
|
||||
default_registry.__module__ = __name__
|
||||
default_registry.__name__ = "default_registry"
|
||||
|
||||
# A special, internal pytree registry that includes everything in
|
||||
# `default_registry`, plus internal Python-defined types that we want
|
||||
# to teach the fast dispatch path ("C++ dispatch") how to flatten and
|
||||
# unflatten. A key example is PRNG key arrays, which are currently a
|
||||
# Python-defined class (in `jax._src.prng`). These ought to be a leaf
|
||||
# node everywhere in the system (e.g. in Jaxpr), but we want to unpack
|
||||
# and repack them across the fast dispatch boundary. If we were to
|
||||
# skip registering such types here, the fast dispatch path would not
|
||||
# know how to handle them as arguments. It would instead always
|
||||
# indicate a "cache miss" and dispatch on the slow path.
|
||||
dispatch_registry = pytree.PyTreeRegistry(
|
||||
enable_none=True, enable_tuple=True, enable_namedtuple=True,
|
||||
enable_list=True, enable_dict=True)
|
||||
dispatch_registry.__module__ = __name__
|
||||
dispatch_registry.__name__ = "dispatch_registry"
|
||||
|
||||
def tree_flatten(tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> tuple[list[Leaf], PyTreeDef]:
|
||||
@ -162,6 +178,7 @@ def register_pytree_node(nodetype: type[T],
|
||||
``nodetype``.
|
||||
"""
|
||||
default_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
||||
|
||||
|
||||
|
@ -1805,6 +1805,63 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(k1, random.KeyArray)
|
||||
self.assertIsInstance(k2, random.KeyArray)
|
||||
|
||||
def test_cpp_dispatch_normal(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling a jitted
|
||||
# function with a key array as an argument.
|
||||
|
||||
@jax.jit
|
||||
def f(key):
|
||||
return jax.random.normal(key)
|
||||
|
||||
key = self.make_keys()
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_cpp_dispatch_split(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling a jitted
|
||||
# function with a key arrays as inputs and as outputs.
|
||||
|
||||
@jax.jit
|
||||
def f(key):
|
||||
return jax.random.split(key)
|
||||
|
||||
key = self.make_keys()
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_cpp_dispatch_aot_normal(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling an
|
||||
# AOT-compiled function with a key array as an argument.
|
||||
|
||||
key = self.make_keys()
|
||||
f = jax.jit(lambda key: jax.random.normal(key)).lower(key).compile()
|
||||
|
||||
with jtu.count_aot_jit_cpp_cache_miss() as count:
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_cpp_dispatch_aot_split(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling an
|
||||
# AOT-compiled function with a key arrays as inputs and as
|
||||
# outputs.
|
||||
|
||||
key = self.make_keys()
|
||||
f = jax.jit(lambda key: jax.random.split(key)).lower(key).compile()
|
||||
|
||||
with jtu.count_aot_jit_cpp_cache_miss() as count:
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
# -- prng primitives
|
||||
|
||||
def test_random_wrap_vmap(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user