mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a benchmark measuring device_put's speed for a 4GB input array
``` --------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------- device_put_big 419 ms 0.363 ms 10 ``` PiperOrigin-RevId: 607512568
This commit is contained in:
parent
7e7094c82d
commit
0b542ff585
@ -686,6 +686,16 @@ def device_put(state):
|
||||
_ = jax.device_put(x).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
||||
def device_put_big(state):
|
||||
x = np.arange(4000 * 10**6 // np.dtype('float32').itemsize, dtype=np.float32)
|
||||
jax.device_put(x).block_until_ready()
|
||||
|
||||
while state:
|
||||
_ = jax.device_put(x).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def device_put_sharded(state):
|
||||
arr_inp = [np.array(i) for i in range(jax.device_count())]
|
||||
|
Loading…
x
Reference in New Issue
Block a user