mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Expand device_put
benchmarks to run with different numbers of arrays and input types
For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays. PiperOrigin-RevId: 641716268
This commit is contained in:
parent
a8246ea67f
commit
6617a0d1ed
@ -678,10 +678,29 @@ def host_local_array_to_global_array(state):
|
||||
(input_data, input_data), global_mesh, (in_pspec, in_pspec))
|
||||
|
||||
@google_benchmark.register
|
||||
def device_put(state):
|
||||
x = np.array(1, np.int32)
|
||||
@google_benchmark.option.arg_names(['num_args'])
|
||||
@google_benchmark.option.args([1])
|
||||
@google_benchmark.option.args([10])
|
||||
@google_benchmark.option.args([100])
|
||||
@google_benchmark.option.args([1000])
|
||||
def device_put_from_numpy_array(state):
|
||||
x = [np.array(1, np.int32)] * state.range(0)
|
||||
while state:
|
||||
_ = jax.device_put(x).block_until_ready()
|
||||
_ = jax.block_until_ready(jax.device_put(x))
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_names(['num_args'])
|
||||
@google_benchmark.option.args([1])
|
||||
@google_benchmark.option.args([10])
|
||||
@google_benchmark.option.args([100])
|
||||
@google_benchmark.option.args([1000])
|
||||
def device_put_from_jax_array(state):
|
||||
x = [np.array(1, np.int32)] * state.range(0)
|
||||
x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0]))
|
||||
d = jax.devices()[1]
|
||||
while state:
|
||||
_ = jax.block_until_ready(jax.device_put(x, device=d))
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
|
Loading…
x
Reference in New Issue
Block a user