mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add pmap_shard_device_array_benchmark. (#2864)
Also renames pmap_shard_args_benchmark to pmap_shard_sharded_device_array_benchmark.
This commit is contained in:
parent
b277b55d1c
commit
4b0334338e
@ -30,7 +30,7 @@ from benchmarks import benchmark
|
||||
import numpy as onp
|
||||
|
||||
|
||||
def pmap_shard_args_benchmark():
|
||||
def pmap_shard_sharded_device_array_benchmark():
|
||||
"""Pmap benchmark focusing on shard_args fast path.
|
||||
|
||||
This is intended to measure how long it takes to dispatch a correctly-sharded
|
||||
@ -56,7 +56,35 @@ def pmap_shard_args_benchmark():
|
||||
for nshards in (2, 4, 8, 100, 500):
|
||||
if nshards > jax.local_device_count(): continue
|
||||
params.append({"nargs": 100, "nshards": nshards})
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params,
|
||||
"pmap_shard_sharded_device_array")
|
||||
|
||||
|
||||
def pmap_shard_device_array_benchmark():
  """Benchmark pmap dispatch when the arguments are DeviceArrays.

  Each benchmarked function invokes a pmapped sum ten times on ``nargs``
  arrays of shape ``(nshards, 4)``. The arrays are built once, outside the
  benchmarked function, and are asserted to be ``jax.xla.DeviceArray``
  instances before use.
  """

  def get_benchmark_fn(nargs, nshards):
    """Return a zero-argument callable running the pmapped sum ten times."""
    summed = pmap(lambda *xs: np.sum(xs))
    dims = (nshards, 4)
    # Inputs are created up front so the returned callable only performs the
    # pmap calls themselves.
    inputs = [np.array(onp.random.random(dims)) for _ in range(nargs)]
    assert all(isinstance(value, jax.xla.DeviceArray) for value in inputs)

    def benchmark_fn():
      for _ in range(10):
        summed(*inputs)

    return benchmark_fn

  # First sweep: vary the argument count with the shard count capped at 8
  # (or the number of local devices, whichever is smaller).
  params = [
      {"nargs": nargs, "nshards": min(8, jax.local_device_count())}
      for nargs in (10, 100, 500)
  ]
  # Second sweep: fix nargs at 100 and vary the shard count, keeping only
  # shard counts that do not exceed the local device count.
  params.extend(
      {"nargs": 100, "nshards": nshards}
      for nshards in (2, 4, 8)
      if nshards <= jax.local_device_count()
  )
  benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_device_array")
|
||||
|
||||
|
||||
def pmap_shard_outputs_benchmark():
|
||||
@ -112,7 +140,8 @@ def sharded_device_array_indexing_benchmark():
|
||||
|
||||
|
||||
def run_all_benchmarks():
  """Run the pmap sharding and sharded-array indexing suites in sequence."""
  # The sharded-device-array suite replaces the former
  # pmap_shard_args_benchmark; that old name is no longer defined, so it
  # must not be called here.
  pmap_shard_sharded_device_array_benchmark()
  pmap_shard_device_array_benchmark()
  pmap_shard_outputs_benchmark()
  sharded_device_array_indexing_benchmark()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user