Improve API benchmarks.

Add benchmarks for different dispatch arg arities.

Add more blocking before and after benchmark loops that don't otherwise block.
This commit is contained in:
Peter Hawkins 2021-03-03 20:50:45 -05:00
parent 6c102d9fbf
commit cdd36b1113

View File

@ -21,6 +21,8 @@ import jax.numpy as jnp
import numpy as np
partial = functools.partial
def required_devices(num_devices_required):
"""Helper to skip benchmarks that require more devices."""
def helper1(f):
@ -39,10 +41,10 @@ def jit_trivial_dispatch(state):
"""Benchmarks only the duration for jitted_f to return the future."""
f = jax.jit(swap)
a, b = f(1, 2)
f(a, b)
x = f(a, b)
while state:
f(a, b)
x = f(a, b)
x[0].block_until_ready()
@google_benchmark.register
@ -79,25 +81,33 @@ def jit_simple(state):
f(a, b).block_until_ready()
@google_benchmark.register
def jit_simple_many_args_dispatch(state):
args = [jax.device_put(i) for i in range(50)]
def jit_simple_many_args_dispatch(n, state):
args = [jax.device_put(i) for i in range(n)]
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
f(args)
x = f(args)
x.block_until_ready()
while state:
f(args)
x = f(args)
x.block_until_ready()
@google_benchmark.register
def jit_simple_many_args(state):
args = [jax.device_put(i) for i in range(50)]
def jit_simple_many_args(n, state):
args = [jax.device_put(i) for i in range(n)]
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
f(args)
f(args).block_until_ready()
while state:
f(args).block_until_ready()
benchmarks = []
for n in [10, 100, 1000, 2000]:
benchmarks += [
google_benchmark.register(partial(jit_simple_many_args_dispatch, n),
name=f"jit_simple_many_args_dispatch_{n}"),
google_benchmark.register(partial(jit_simple_many_args, n),
name=f"jit_simple_many_args_{n}")
]
@google_benchmark.register
def jit_dispatch_without_transfer(state):
@ -118,10 +128,11 @@ def jit_dispatch_with_transfer(state):
imgs = np.ones((128, 224, 224), np.float32)
f = jax.api.jit(lambda x: x+1)
f(imgs)
f(imgs).block_until_ready()
while state:
f(imgs)
x = f(imgs)
x.block_until_ready()
@google_benchmark.register