diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 2bceaf723..e6539b797 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -686,6 +686,16 @@ def device_put(state): _ = jax.device_put(x).block_until_ready() +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def device_put_big(state): + x = np.arange(4000 * 10**6 // np.dtype('float32').itemsize, dtype=np.float32) + jax.device_put(x).block_until_ready() + + while state: + _ = jax.device_put(x).block_until_ready() + + @google_benchmark.register def device_put_sharded(state): arr_inp = [np.array(i) for i in range(jax.device_count())]