mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add benchmarks for jax.Array
PiperOrigin-RevId: 471889808
This commit is contained in:
parent
bd425e5dc5
commit
d17e516ea7
@ -20,6 +20,7 @@ import google_benchmark
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import config as jax_config
|
||||
from jax.experimental import sparse
|
||||
from jax._src.api_util import shaped_abstractify # technically not an api fn
|
||||
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
|
||||
@ -127,6 +128,29 @@ def jit_simple(state):
|
||||
while state:
|
||||
f(a, b).block_until_ready()
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_simple_dispatch_array(state):
|
||||
with jax_config.jax_array(True):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
f = jax.jit(operator.add)
|
||||
f(a, b)
|
||||
|
||||
while state:
|
||||
f(a, b)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_simple_array(state):
|
||||
with jax_config.jax_array(True):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
f = jax.jit(operator.add)
|
||||
f(a, b)
|
||||
|
||||
while state:
|
||||
f(a, b).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_small_matmul(state):
|
||||
|
Loading…
x
Reference in New Issue
Block a user