Update api_benchmark to not use any deprecated APIs.

PiperOrigin-RevId: 512941633
This commit is contained in:
Lena Martens 2023-02-28 08:32:46 -08:00 committed by jax authors
parent f66f6ec98a
commit 4f48f94649

View File

@ -29,6 +29,7 @@ from jax.interpreters import xla
from jax.interpreters import pxla
from jax._src import array
from jax._src import sharding
from jax._src.pjit import pjit_check_aval_sharding
from jax.experimental import pjit as pjit_lib
from jax.experimental import multihost_utils
import jax.numpy as jnp
@ -153,7 +154,6 @@ def jit_simple(state):
@google_benchmark.register
def jit_simple_dispatch_array(state):
with jax_config.jax_array(True):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
@ -165,7 +165,6 @@ def jit_simple_dispatch_array(state):
@google_benchmark.register
def jit_simple_array(state):
with jax_config.jax_array(True):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
@ -200,17 +199,16 @@ def jit_big_matmul(state):
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'jax_array'])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.args([1000, False])
@google_benchmark.option.args([1000, True])
@google_benchmark.option.args([2000, False])
@google_benchmark.option.args([2000, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([10])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
@google_benchmark.option.args([1000])
@google_benchmark.option.args([2000])
@google_benchmark.option.args([2000])
def jit_simple_many_args_dispatch(state):
with jax_config.jax_array(state.range(1)):
args = [jax.device_put(i) for i in range(state.range(0))]
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
x = f(args)
@ -221,17 +219,16 @@ def jit_simple_many_args_dispatch(state):
x.block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'jax_array'])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.args([1000, False])
@google_benchmark.option.args([1000, True])
@google_benchmark.option.args([2000, False])
@google_benchmark.option.args([2000, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([10])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
@google_benchmark.option.args([1000])
@google_benchmark.option.args([2000])
@google_benchmark.option.args([2000])
def jit_simple_many_args(state):
with jax_config.jax_array(state.range(1)):
args = [jax.device_put(i) for i in range(state.range(0))]
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
f(args).block_until_ready()
@ -296,12 +293,8 @@ def jit_dispatch_with_transfer(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(2)
def pmap_trivial_2_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(swap)
a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))
@ -312,12 +305,8 @@ def pmap_trivial_2_devices(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_trivial_dispatch_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(swap)
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
@ -327,12 +316,8 @@ def pmap_trivial_dispatch_8_devices(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_trivial_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(swap)
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
@ -344,12 +329,8 @@ def pmap_trivial_8_devices(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(2)
def pmap_simple_2_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda a, b: (a + b, a - b))
a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))
@ -360,12 +341,8 @@ def pmap_simple_2_devices(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_dispatch_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda a, b: (a + b, a - b))
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
@ -375,12 +352,8 @@ def pmap_simple_dispatch_8_devices(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda a, b: (a + b, a - b))
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
@ -392,12 +365,8 @@ def pmap_simple_8_devices(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_dispatch_8_devices_100_args(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
@ -410,12 +379,8 @@ def pmap_simple_dispatch_8_devices_100_args(state):
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_8_devices_100_args(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
@ -624,11 +589,11 @@ def bench_pjit_check_aval_sharding(state):
mesh = create_mesh((4, 2), ('x', 'y'), state)
if mesh is None:
return
s = sharding.NamedSharding(mesh, pxla.PartitionSpec('x', 'y'))
s = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
aval = jax.core.ShapedArray((8, 2), np.int32)
while state:
pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
@google_benchmark.register
@ -732,7 +697,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_1_device(state):
pjit_simple_benchmark(
state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))
@ -745,7 +709,6 @@ def pjit_simple_1_device(state):
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4_device(state):
pjit_simple_benchmark(
state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))
@ -758,7 +721,6 @@ def pjit_simple_4_device(state):
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4000_device(state):
pjit_simple_benchmark(
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
@ -772,7 +734,6 @@ def pjit_simple_4000_device(state):
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_1_device(state):
pjit_simple_benchmark(
state,
@ -790,7 +751,6 @@ def pjit_aot_1_device(state):
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4_device(state):
pjit_simple_benchmark(
state,
@ -808,7 +768,6 @@ def pjit_aot_4_device(state):
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4000_device(state):
pjit_simple_benchmark(
state,
@ -824,7 +783,7 @@ def host_local_array_to_global_array(state):
global_mesh = create_mesh((4, 2), ('x', 'y'), state)
input_shape = (8, 2)
input_data = np.arange(np.prod(input_shape)).reshape(input_shape)
in_pspec = pxla.PartitionSpec('x', 'y')
in_pspec = jax.sharding.PartitionSpec('x', 'y')
while state:
multihost_utils.host_local_array_to_global_array(