mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add ShardedDeviceArray indexing benchmark. (#2549)
Example output: ``` ---------Benchmark summary for ShardedDeviceArray_indexing--------- indices_fn mean %std relative ------------------ -------- ------- ---------- integer_indices 0.16901 8.52522 1 integer_2D_indices 18.4918 0 109.412 ```
This commit is contained in:
parent
bfbd0b800f
commit
c28c46e191
@ -92,18 +92,25 @@ def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
|
||||
times = []
|
||||
for params in params_list:
|
||||
f = prepare(**params)
|
||||
subname = name + "".join("_%s=%s" % (n, p) for n, p in params.items())
|
||||
subname = name + "".join("_%s=%s" % (n, _param_str(p))
|
||||
for n, p in params.items())
|
||||
times.append(benchmark(f, name=subname,
|
||||
target_total_secs=target_total_secs))
|
||||
|
||||
print("---------Benchmark summary for %s---------" % name)
|
||||
param_names = list(params_list[0].keys())
|
||||
print(tabulate([tuple(params.values()) +
|
||||
print(tabulate([tuple(map(_param_str, params.values())) +
|
||||
(t.mean(), _pstd(t), t.mean() / times[0].mean())
|
||||
for params, t in safe_zip(params_list, times)],
|
||||
param_names + ["mean", "%std", "relative"]))
|
||||
print()
|
||||
|
||||
|
||||
def _param_str(param):
|
||||
if callable(param):
|
||||
return param.__name__
|
||||
return str(param)
|
||||
|
||||
|
||||
def _pstd(x):
|
||||
return x.std() / x.mean() * 100
|
||||
|
@ -84,9 +84,37 @@ def pmap_shard_outputs_benchmark():
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
|
||||
|
||||
|
||||
def sharded_device_array_indexing_benchmark():
|
||||
"""Benchmark focusing on ShardedDeviceArray indexing."""
|
||||
def get_benchmark_fn(indices_fn):
|
||||
nshards = min(8, jax.local_device_count())
|
||||
shape = (nshards, 8, 8)
|
||||
def benchmark_fn():
|
||||
arr = pmap(lambda x: x)(np.arange(np.prod(shape)).reshape(shape))
|
||||
indices = indices_fn()
|
||||
for idx in indices:
|
||||
arr[idx]
|
||||
return benchmark_fn
|
||||
|
||||
num_internal_iters = 1000
|
||||
|
||||
def integer_indices():
|
||||
return (i for _ in range(num_internal_iters) for i in range(8))
|
||||
|
||||
def integer_2D_indices():
|
||||
return ((i,i) for _ in range(num_internal_iters) for i in range(8))
|
||||
|
||||
params = []
|
||||
params.append({"indices_fn": integer_indices})
|
||||
params.append({"indices_fn": integer_2D_indices})
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params,
|
||||
"ShardedDeviceArray_indexing")
|
||||
|
||||
|
||||
def run_all_benchmarks():
|
||||
pmap_shard_args_benchmark()
|
||||
pmap_shard_outputs_benchmark()
|
||||
sharded_device_array_indexing_benchmark()
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
|
Loading…
x
Reference in New Issue
Block a user