Add ShardedDeviceArray indexing benchmark. (#2549)

Example output:
```
---------Benchmark summary for ShardedDeviceArray_indexing---------
indices_fn              mean     %std    relative
------------------  --------  -------  ----------
integer_indices      0.16901  8.52522       1
integer_2D_indices  18.4918   0           109.412
```
This commit is contained in:
Skye Wanderman-Milne 2020-03-31 15:52:41 -07:00 committed by GitHub
parent bfbd0b800f
commit c28c46e191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 2 deletions

View File

@ -92,18 +92,25 @@ def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
times = []
for params in params_list:
f = prepare(**params)
subname = name + "".join("_%s=%s" % (n, p) for n, p in params.items())
subname = name + "".join("_%s=%s" % (n, _param_str(p))
for n, p in params.items())
times.append(benchmark(f, name=subname,
target_total_secs=target_total_secs))
print("---------Benchmark summary for %s---------" % name)
param_names = list(params_list[0].keys())
print(tabulate([tuple(params.values()) +
print(tabulate([tuple(map(_param_str, params.values())) +
(t.mean(), _pstd(t), t.mean() / times[0].mean())
for params, t in safe_zip(params_list, times)],
param_names + ["mean", "%std", "relative"]))
print()
def _param_str(param):
if callable(param):
return param.__name__
return str(param)
def _pstd(x):
return x.std() / x.mean() * 100

View File

@ -84,9 +84,37 @@ def pmap_shard_outputs_benchmark():
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
def sharded_device_array_indexing_benchmark():
"""Benchmark focusing on ShardedDeviceArray indexing."""
def get_benchmark_fn(indices_fn):
nshards = min(8, jax.local_device_count())
shape = (nshards, 8, 8)
def benchmark_fn():
arr = pmap(lambda x: x)(np.arange(np.prod(shape)).reshape(shape))
indices = indices_fn()
for idx in indices:
arr[idx]
return benchmark_fn
num_internal_iters = 1000
def integer_indices():
return (i for _ in range(num_internal_iters) for i in range(8))
def integer_2D_indices():
return ((i,i) for _ in range(num_internal_iters) for i in range(8))
params = []
params.append({"indices_fn": integer_indices})
params.append({"indices_fn": integer_2D_indices})
benchmark.benchmark_suite(get_benchmark_fn, params,
"ShardedDeviceArray_indexing")
def run_all_benchmarks():
pmap_shard_args_benchmark()
pmap_shard_outputs_benchmark()
sharded_device_array_indexing_benchmark()
def main(unused_argv):