mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Improve API benchmarks.
Add benchmarks for different dispatch arg arities. Add more blocking before and after benchmark loops that don't otherwise block.
This commit is contained in:
parent
6c102d9fbf
commit
cdd36b1113
@ -21,6 +21,8 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
partial = functools.partial
|
||||
|
||||
def required_devices(num_devices_required):
|
||||
"""Helper to skip benchmarks that require more devices."""
|
||||
def helper1(f):
|
||||
@ -39,10 +41,10 @@ def jit_trivial_dispatch(state):
|
||||
"""Benchmarks only the duration for jitted_f to return the future."""
|
||||
f = jax.jit(swap)
|
||||
a, b = f(1, 2)
|
||||
f(a, b)
|
||||
|
||||
x = f(a, b)
|
||||
while state:
|
||||
f(a, b)
|
||||
x = f(a, b)
|
||||
x[0].block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@ -79,25 +81,33 @@ def jit_simple(state):
|
||||
f(a, b).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_simple_many_args_dispatch(state):
|
||||
args = [jax.device_put(i) for i in range(50)]
|
||||
def jit_simple_many_args_dispatch(n, state):
|
||||
args = [jax.device_put(i) for i in range(n)]
|
||||
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
|
||||
f(args)
|
||||
x = f(args)
|
||||
x.block_until_ready()
|
||||
|
||||
while state:
|
||||
f(args)
|
||||
x = f(args)
|
||||
x.block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_simple_many_args(state):
|
||||
args = [jax.device_put(i) for i in range(50)]
|
||||
def jit_simple_many_args(n, state):
|
||||
args = [jax.device_put(i) for i in range(n)]
|
||||
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
|
||||
f(args)
|
||||
f(args).block_until_ready()
|
||||
|
||||
while state:
|
||||
f(args).block_until_ready()
|
||||
|
||||
benchmarks = []
|
||||
for n in [10, 100, 1000, 2000]:
|
||||
benchmarks += [
|
||||
google_benchmark.register(partial(jit_simple_many_args_dispatch, n),
|
||||
name=f"jit_simple_many_args_dispatch_{n}"),
|
||||
google_benchmark.register(partial(jit_simple_many_args, n),
|
||||
name=f"jit_simple_many_args_{n}")
|
||||
]
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_dispatch_without_transfer(state):
|
||||
@ -118,10 +128,11 @@ def jit_dispatch_with_transfer(state):
|
||||
imgs = np.ones((128, 224, 224), np.float32)
|
||||
|
||||
f = jax.api.jit(lambda x: x+1)
|
||||
f(imgs)
|
||||
f(imgs).block_until_ready()
|
||||
|
||||
while state:
|
||||
f(imgs)
|
||||
x = f(imgs)
|
||||
x.block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
|
Loading…
x
Reference in New Issue
Block a user