Add a benchmark measuring device_put's speed for a 4GB input array

```
---------------------------------------------------------
Benchmark               Time             CPU   Iterations
---------------------------------------------------------
device_put_big        419 ms        0.363 ms           10
```

PiperOrigin-RevId: 607512568
This commit is contained in:
Yash Katariya 2024-02-15 17:46:33 -08:00 committed by jax authors
parent 7e7094c82d
commit 0b542ff585

View File

@ -686,6 +686,16 @@ def device_put(state):
_ = jax.device_put(x).block_until_ready()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def device_put_big(state):
x = np.arange(4000 * 10**6 // np.dtype('float32').itemsize, dtype=np.float32)
jax.device_put(x).block_until_ready()
while state:
_ = jax.device_put(x).block_until_ready()
@google_benchmark.register
def device_put_sharded(state):
arr_inp = [np.array(i) for i in range(jax.device_count())]