Add benchmarks for jax.Array

PiperOrigin-RevId: 471889808
This commit is contained in:
Kuangyuan Chen 2022-09-02 14:44:32 -07:00 committed by jax authors
parent bd425e5dc5
commit d17e516ea7

View File

@ -20,6 +20,7 @@ import google_benchmark
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src import config as jax_config
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
@ -127,6 +128,29 @@ def jit_simple(state):
while state:
f(a, b).block_until_ready()
@google_benchmark.register
def jit_simple_dispatch_array(state):
with jax_config.jax_array(True):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
f(a, b)
while state:
f(a, b)
@google_benchmark.register
def jit_simple_array(state):
with jax_config.jax_array(True):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
f(a, b)
while state:
f(a, b).block_until_ready()
@google_benchmark.register
def jit_small_matmul(state):