diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 4dbc8109d..229941c40 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -696,6 +696,17 @@ def pjit_simple_4000_device(state): pjit_simple_benchmark( state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1)) +@google_benchmark.register +@jax_config.jax_array(True) +def simple_dispatch_array_pjit(state): + a = jax.device_put(1) + b = jax.device_put(2) + f = pjit_lib.pjit(operator.add) + f(a, b) + + while state: + f(a, b) + if __name__ == "__main__": google_benchmark.main()