From 6617a0d1edb1ec0a144347c40239e5ece75fc354 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Sun, 9 Jun 2024 13:01:10 -0700 Subject: [PATCH] Expand `device_put` benchmarks to run with different numbers of arrays and input types For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays. PiperOrigin-RevId: 641716268 --- benchmarks/api_benchmark.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 75cd38d10..2b5e61ffb 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -678,10 +678,29 @@ def host_local_array_to_global_array(state): (input_data, input_data), global_mesh, (in_pspec, in_pspec)) @google_benchmark.register -def device_put(state): - x = np.array(1, np.int32) +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([1]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +def device_put_from_numpy_array(state): + x = [np.array(1, np.int32)] * state.range(0) while state: - _ = jax.device_put(x).block_until_ready() + _ = jax.block_until_ready(jax.device_put(x)) + + +@google_benchmark.register +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([1]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +def device_put_from_jax_array(state): + x = [np.array(1, np.int32)] * state.range(0) + x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0])) + d = jax.devices()[1] + while state: + _ = jax.block_until_ready(jax.device_put(x, device=d)) @google_benchmark.register