From 0b542ff585b30b24aceadd12a2335b0fa26b8209 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 15 Feb 2024 17:46:33 -0800 Subject: [PATCH] Add a benchmark measuring device_put's speed for a 4GB input array ``` --------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------- device_put_big 419 ms 0.363 ms 10 ``` PiperOrigin-RevId: 607512568 --- benchmarks/api_benchmark.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 2bceaf723..e6539b797 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -686,6 +686,16 @@ def device_put(state): _ = jax.device_put(x).block_until_ready() +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def device_put_big(state): + x = np.arange(4000 * 10**6 // np.dtype('float32').itemsize, dtype=np.float32) + jax.device_put(x).block_until_ready() + + while state: + _ = jax.device_put(x).block_until_ready() + + @google_benchmark.register def device_put_sharded(state): arr_inp = [np.array(i) for i in range(jax.device_count())]