mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
add simple single-primitive eager benchmarks
This commit is contained in:
parent
75f191a2b8
commit
9802f3378e
@ -19,6 +19,7 @@ import google_benchmark
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import lax
|
||||
|
||||
|
||||
partial = functools.partial
|
||||
@ -36,6 +37,40 @@ def required_devices(num_devices_required):
|
||||
return helper1
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def eager_unary_dispatch(state):
|
||||
a = jax.device_put(1)
|
||||
lax.neg(a)
|
||||
while state:
|
||||
lax.neg(a)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def eager_unary(state):
|
||||
a = jax.device_put(1)
|
||||
lax.neg(a).block_until_ready()
|
||||
while state:
|
||||
lax.neg(a).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def eager_binary_dispatch(state):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
lax.add(a, b)
|
||||
while state:
|
||||
lax.add(a, b)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def eager_binary(state):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
lax.add(a, b).block_until_ready()
|
||||
while state:
|
||||
lax.add(a, b).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_trivial_dispatch(state):
|
||||
"""Benchmarks only the duration for jitted_f to return the future."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user