mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Adjust pmap_bechmark.py values to be more realistic. (#2622)
This commit is contained in:
parent
e8f989e38f
commit
3fe8bd027c
@ -50,12 +50,12 @@ def pmap_shard_args_benchmark():
|
||||
return benchmark_fn
|
||||
|
||||
params = []
|
||||
for nargs in (10, 100, 101, 500):
|
||||
nshards = min(4, jax.local_device_count())
|
||||
for nargs in (10, 100, 101, 500, 1000, 5000):
|
||||
nshards = min(8, jax.local_device_count())
|
||||
params.append({"nargs": nargs, "nshards": nshards})
|
||||
for nshards in (2, 4, 8, 100, 500):
|
||||
if nshards > jax.local_device_count(): continue
|
||||
params.append({"nargs": 10, "nshards": nshards})
|
||||
params.append({"nargs": 100, "nshards": nshards})
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
|
||||
|
||||
|
||||
@ -75,12 +75,12 @@ def pmap_shard_outputs_benchmark():
|
||||
return benchmark_fn
|
||||
|
||||
params = []
|
||||
for nouts in (10, 100, 500):
|
||||
nshards = min(4, jax.local_device_count())
|
||||
for nouts in (10, 100, 500, 1000, 5000):
|
||||
nshards = min(8, jax.local_device_count())
|
||||
params.append({"nouts": nouts, "nshards": nshards})
|
||||
for nshards in (2, 4, 8, 100, 500):
|
||||
if nshards > jax.local_device_count(): continue
|
||||
params.append({"nouts": 10, "nshards": nshards})
|
||||
params.append({"nouts": 100, "nshards": nshards})
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user