diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 72095b3bf..22fda5a4e 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -19,6 +19,7 @@ import google_benchmark import jax import jax.numpy as jnp import numpy as np +from jax import lax partial = functools.partial @@ -36,6 +37,40 @@ def required_devices(num_devices_required): return helper1 +@google_benchmark.register +def eager_unary_dispatch(state): + a = jax.device_put(1) + lax.neg(a) + while state: + lax.neg(a) + + +@google_benchmark.register +def eager_unary(state): + a = jax.device_put(1) + lax.neg(a).block_until_ready() + while state: + lax.neg(a).block_until_ready() + + +@google_benchmark.register +def eager_binary_dispatch(state): + a = jax.device_put(1) + b = jax.device_put(2) + lax.add(a, b) + while state: + lax.add(a, b) + + +@google_benchmark.register +def eager_binary(state): + a = jax.device_put(1) + b = jax.device_put(2) + lax.add(a, b).block_until_ready() + while state: + lax.add(a, b).block_until_ready() + + @google_benchmark.register def jit_trivial_dispatch(state): """Benchmarks only the duration for jitted_f to return the future."""