From 4b0334338e94515eb8a1b6fba3093317cfc6ff98 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 27 Apr 2020 17:21:05 -0700 Subject: [PATCH] Add pmap_shard_device_array_benchmark. (#2864) Also renames pmap_shard_args_benchmark to pmap_shard_sharded_device_array_benchmark. --- benchmarks/pmap_benchmark.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/benchmarks/pmap_benchmark.py b/benchmarks/pmap_benchmark.py index fb2dca27c..34efa2885 100644 --- a/benchmarks/pmap_benchmark.py +++ b/benchmarks/pmap_benchmark.py @@ -30,7 +30,7 @@ from benchmarks import benchmark import numpy as onp -def pmap_shard_args_benchmark(): +def pmap_shard_sharded_device_array_benchmark(): """Pmap benchmark focusing on shard_args fast path. This is intended to measure how long it takes to dispatch a correctly-sharded @@ -56,7 +56,35 @@ def pmap_shard_args_benchmark(): for nshards in (2, 4, 8, 100, 500): if nshards > jax.local_device_count(): continue params.append({"nargs": 100, "nshards": nshards}) - benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args") + benchmark.benchmark_suite(get_benchmark_fn, params, + "pmap_shard_sharded_device_array") + + +def pmap_shard_device_array_benchmark(): + """Pmap benchmark focusing on shard_args DeviceArray path. + + This is intended to measure how long it takes to dispatch a DeviceArray to + pmap. + """ + + def get_benchmark_fn(nargs, nshards): + pmap_fn = pmap(lambda *args: np.sum(args)) + shape = (nshards, 4) + args = [np.array(onp.random.random(shape)) for _ in range(nargs)] + assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args) + def benchmark_fn(): + for _ in range(10): + pmap_fn(*args) + return benchmark_fn + + params = [] + for nargs in (10, 100, 500): + nshards = min(8, jax.local_device_count()) + params.append({"nargs": nargs, "nshards": nshards}) + for nshards in (2, 4, 8): + if nshards > jax.local_device_count(): continue + params.append({"nargs": 100, "nshards": nshards}) + benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_device_array") def pmap_shard_outputs_benchmark(): @@ -112,7 +140,8 @@ def sharded_device_array_indexing_benchmark(): def run_all_benchmarks(): - pmap_shard_args_benchmark() + pmap_shard_sharded_device_array_benchmark() + pmap_shard_device_array_benchmark() pmap_shard_outputs_benchmark() sharded_device_array_indexing_benchmark()