add simple single-primitive eager benchmarks

This commit is contained in:
Matthew Johnson 2021-03-18 21:46:46 -07:00
parent 75f191a2b8
commit 9802f3378e

View File

@ -19,6 +19,7 @@ import google_benchmark
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
partial = functools.partial
@ -36,6 +37,40 @@ def required_devices(num_devices_required):
return helper1
@google_benchmark.register
def eager_unary_dispatch(state):
a = jax.device_put(1)
lax.neg(a)
while state:
lax.neg(a)
@google_benchmark.register
def eager_unary(state):
a = jax.device_put(1)
lax.neg(a).block_until_ready()
while state:
lax.neg(a).block_until_ready()
@google_benchmark.register
def eager_binary_dispatch(state):
a = jax.device_put(1)
b = jax.device_put(2)
lax.add(a, b)
while state:
lax.add(a, b)
@google_benchmark.register
def eager_binary(state):
a = jax.device_put(1)
b = jax.device_put(2)
lax.add(a, b).block_until_ready()
while state:
lax.add(a, b).block_until_ready()
@google_benchmark.register
def jit_trivial_dispatch(state):
"""Benchmarks only the duration for jitted_f to return the future."""