Expand device_put benchmarks to run with different numbers of arrays and input types

For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays.

PiperOrigin-RevId: 641716268
This commit is contained in:
Junwhan Ahn 2024-06-09 13:01:10 -07:00 committed by jax authors
parent a8246ea67f
commit 6617a0d1ed

View File

@ -678,10 +678,29 @@ def host_local_array_to_global_array(state):
(input_data, input_data), global_mesh, (in_pspec, in_pspec))
@google_benchmark.register
def device_put(state):
x = np.array(1, np.int32)
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
def device_put_from_numpy_array(state):
x = [np.array(1, np.int32)] * state.range(0)
while state:
_ = jax.device_put(x).block_until_ready()
_ = jax.block_until_ready(jax.device_put(x))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
def device_put_from_jax_array(state):
x = [np.array(1, np.int32)] * state.range(0)
x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0]))
d = jax.devices()[1]
while state:
_ = jax.block_until_ready(jax.device_put(x, device=d))
@google_benchmark.register