Add pmap_shard_device_array_benchmark. (#2864)

Also renames pmap_shard_args_benchmark to pmap_shard_sharded_device_array_benchmark.
This commit is contained in:
Skye Wanderman-Milne 2020-04-27 17:21:05 -07:00 committed by GitHub
parent b277b55d1c
commit 4b0334338e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -30,7 +30,7 @@ from benchmarks import benchmark
import numpy as onp
def pmap_shard_args_benchmark():
def pmap_shard_sharded_device_array_benchmark():
"""Pmap benchmark focusing on shard_args fast path.
This is intended to measure how long it takes to dispatch a correctly-sharded
@ -56,7 +56,35 @@ def pmap_shard_args_benchmark():
for nshards in (2, 4, 8, 100, 500):
if nshards > jax.local_device_count(): continue
params.append({"nargs": 100, "nshards": nshards})
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
benchmark.benchmark_suite(get_benchmark_fn, params,
"pmap_shard_sharded_device_array")
def pmap_shard_device_array_benchmark():
"""Pmap benchmark focusing on shard_args DeviceArray path.
This is intended to measure how long it takes to dispatch a DeviceArray to
pmap.
"""
def get_benchmark_fn(nargs, nshards):
pmap_fn = pmap(lambda *args: np.sum(args))
shape = (nshards, 4)
args = [np.array(onp.random.random(shape)) for _ in range(nargs)]
assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
def benchmark_fn():
for _ in range(10):
pmap_fn(*args)
return benchmark_fn
params = []
for nargs in (10, 100, 500):
nshards = min(8, jax.local_device_count())
params.append({"nargs": nargs, "nshards": nshards})
for nshards in (2, 4, 8):
if nshards > jax.local_device_count(): continue
params.append({"nargs": 100, "nshards": nshards})
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_device_array")
def pmap_shard_outputs_benchmark():
@ -112,7 +140,8 @@ def sharded_device_array_indexing_benchmark():
def run_all_benchmarks():
pmap_shard_args_benchmark()
pmap_shard_sharded_device_array_benchmark()
pmap_shard_device_array_benchmark()
pmap_shard_outputs_benchmark()
sharded_device_array_indexing_benchmark()