Changed to pmap_benchmark to make it runnable in Google (#2448)

This commit is contained in:
George Necula 2020-03-19 06:56:59 +01:00 committed by GitHub
parent 2998a21505
commit cd7ab0a9e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,12 +18,15 @@ python3 pmap_benchmark.py
To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
"""
import numpy as onp
from absl import app
from benchmark import benchmark_suite
import jax
import jax.numpy as np
from jax import numpy as np
from jax import pmap
from jax.benchmarks import benchmark
from jax.config import config
import numpy as onp
def pmap_shard_args_benchmark():
@ -38,7 +41,8 @@ def pmap_shard_args_benchmark():
shape = (nshards, 4)
args = [onp.random.random(shape) for _ in range(nargs)]
sharded_args = pmap(lambda x: x)(args)
assert all(type(arg) == jax.pxla.ShardedDeviceArray for arg in sharded_args)
assert all(isinstance(arg, jax.pxla.ShardedDeviceArray)
for arg in sharded_args)
def benchmark_fn():
for _ in range(100):
pmap_fn(*sharded_args)
@ -51,7 +55,7 @@ def pmap_shard_args_benchmark():
for nshards in (2, 4, 8, 100, 500):
if nshards > jax.local_device_count(): continue
params.append({"nargs": 10, "nshards": nshards})
benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_args")
def pmap_shard_outputs_benchmark():
@ -76,7 +80,7 @@ def pmap_shard_outputs_benchmark():
for nshards in (2, 4, 8, 100, 500):
if nshards > jax.local_device_count(): continue
params.append({"nouts": 10, "nshards": nshards})
benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
benchmark.benchmark_suite(get_benchmark_fn, params, "pmap_shard_outputs")
def run_all_benchmarks():
@ -84,5 +88,10 @@ def run_all_benchmarks():
pmap_shard_outputs_benchmark()
if __name__ == "__main__":
def main(unused_argv):
run_all_benchmarks()
if __name__ == "__main__":
config.config_with_absl()
app.run(main)