Adjust pmap_bechmark.py values to be more realistic. (#2622)

This commit is contained in:
Skye Wanderman-Milne 2020-04-06 16:38:34 -07:00 committed by GitHub
parent e8f989e38f
commit 3fe8bd027c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -50,12 +50,12 @@ def pmap_shard_args_benchmark():
return benchmark_fn
params = []
for nargs in (10, 100, 101, 500):
nshards = min(4, jax.local_device_count())
for nargs in (10, 100, 101, 500, 1000, 5000):
nshards = min(8, jax.local_device_count())
params.append({"nargs": nargs, "nshards": nshards})
for nshards in (2, 4, 8, 100, 500):
if nshards > jax.local_device_count(): continue
params.append({"nargs": 10, "nshards": nshards})
params.append({"nargs": 100, "nshards": nshards})
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
@ -75,12 +75,12 @@ def pmap_shard_outputs_benchmark():
return benchmark_fn
params = []
for nouts in (10, 100, 500):
nshards = min(4, jax.local_device_count())
for nouts in (10, 100, 500, 1000, 5000):
nshards = min(8, jax.local_device_count())
params.append({"nouts": nouts, "nshards": nshards})
for nshards in (2, 4, 8, 100, 500):
if nshards > jax.local_device_count(): continue
params.append({"nouts": 10, "nshards": nshards})
params.append({"nouts": 100, "nshards": nshards})
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")