mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Changed to pmap_benchmark to make it runnable in Google (#2448)
This commit is contained in:
parent
2998a21505
commit
cd7ab0a9e0
@ -18,12 +18,15 @@ python3 pmap_benchmark.py
|
||||
|
||||
To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
|
||||
"""
|
||||
import numpy as onp
|
||||
from absl import app
|
||||
|
||||
from benchmark import benchmark_suite
|
||||
import jax
|
||||
import jax.numpy as np
|
||||
from jax import numpy as np
|
||||
from jax import pmap
|
||||
from jax.benchmarks import benchmark
|
||||
from jax.config import config
|
||||
|
||||
import numpy as onp
|
||||
|
||||
|
||||
def pmap_shard_args_benchmark():
|
||||
@ -38,7 +41,8 @@ def pmap_shard_args_benchmark():
|
||||
shape = (nshards, 4)
|
||||
args = [onp.random.random(shape) for _ in range(nargs)]
|
||||
sharded_args = pmap(lambda x: x)(args)
|
||||
assert all(type(arg) == jax.pxla.ShardedDeviceArray for arg in sharded_args)
|
||||
assert all(isinstance(arg, jax.pxla.ShardedDeviceArray)
|
||||
for arg in sharded_args)
|
||||
def benchmark_fn():
|
||||
for _ in range(100):
|
||||
pmap_fn(*sharded_args)
|
||||
@ -51,7 +55,7 @@ def pmap_shard_args_benchmark():
|
||||
for nshards in (2, 4, 8, 100, 500):
|
||||
if nshards > jax.local_device_count(): continue
|
||||
params.append({"nargs": 10, "nshards": nshards})
|
||||
benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
|
||||
|
||||
|
||||
def pmap_shard_outputs_benchmark():
|
||||
@ -76,7 +80,7 @@ def pmap_shard_outputs_benchmark():
|
||||
for nshards in (2, 4, 8, 100, 500):
|
||||
if nshards > jax.local_device_count(): continue
|
||||
params.append({"nouts": 10, "nshards": nshards})
|
||||
benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
|
||||
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
|
||||
|
||||
|
||||
def run_all_benchmarks():
|
||||
@ -84,5 +88,10 @@ def run_all_benchmarks():
|
||||
pmap_shard_outputs_benchmark()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main(unused_argv):
|
||||
run_all_benchmarks()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
app.run(main)
|
||||
|
Loading…
x
Reference in New Issue
Block a user