rocm_jax/benchmarks/random_benchmark.py
Roy Frostig 6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00

128 lines
3.0 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for JAX random."""
import google_benchmark
import jax
from jax import dtypes
def _assert_raw_key(key):
assert key.dtype == "uint32"
def _assert_typed_key(key):
assert dtypes.issubdtype(key.dtype, dtypes.prng_key)
def _bench_trivial_dispatch(state, key):
f = jax.jit(lambda key: key)
_ = f(key)
while state:
f(key)
f(key).block_until_ready()
@google_benchmark.register
def trivial_dispatch_raw_key(state):
key = jax.random.PRNGKey(0)
_assert_raw_key(key)
_bench_trivial_dispatch(state, key)
@google_benchmark.register
def trivial_dispatch_typed_key(state):
key = jax.random.key(0)
_assert_typed_key(key)
_bench_trivial_dispatch(state, key)
def _bench_nontrivial_dispatch(state, key, do_split=False):
key_op = jax.random.split if do_split else jax.random.normal
f = jax.jit(lambda key: key_op(key))
_ = f(key)
while state:
f(key)
f(key).block_until_ready()
@google_benchmark.register
def nontrivial_dispatch_raw_key(state):
key = jax.random.PRNGKey(0)
_assert_raw_key(key)
_bench_nontrivial_dispatch(state, key, do_split=False)
@google_benchmark.register
def nontrivial_dispatch_typed_key(state):
key = jax.random.key(0)
_assert_typed_key(key)
_bench_nontrivial_dispatch(state, key, do_split=False)
@google_benchmark.register
def nontrivial_dispatch_raw_key_split(state):
key = jax.random.PRNGKey(0)
_assert_raw_key(key)
_bench_nontrivial_dispatch(state, key, do_split=True)
@google_benchmark.register
def nontrivial_dispatch_typed_key_split(state):
key = jax.random.key(0)
_assert_typed_key(key)
_bench_nontrivial_dispatch(state, key, do_split=True)
def _bench_custom_container(state, key):
@jax.tree_util.register_pytree_node_class
class A:
def __init__(self, x):
self.x = x
def tree_flatten(self):
return (self.x,), None
@classmethod
def tree_unflatten(cls, aux, children):
x, = children
return cls(x)
f = jax.jit(
lambda key, a: jax.random.normal(key) + a.x)
a = A(5.)
_ = f(key, a)
while state:
f(key, a)
f(key, a).block_until_ready()
@google_benchmark.register
def custom_container_raw_key(state):
key = jax.random.PRNGKey(0)
_assert_raw_key(key)
_bench_custom_container(state, key)
@google_benchmark.register
def custom_container_typed_key(state):
key = jax.random.key(0)
_assert_typed_key(key)
_bench_custom_container(state, key)
if __name__ == "__main__":
google_benchmark.main()