mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Update api_benchmark to not use any deprecated APIs.
PiperOrigin-RevId: 512941633
This commit is contained in:
parent
f66f6ec98a
commit
4f48f94649
@ -29,6 +29,7 @@ from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax._src.pjit import pjit_check_aval_sharding
|
||||
from jax.experimental import pjit as pjit_lib
|
||||
from jax.experimental import multihost_utils
|
||||
import jax.numpy as jnp
|
||||
@ -153,7 +154,6 @@ def jit_simple(state):
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_simple_dispatch_array(state):
|
||||
with jax_config.jax_array(True):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
f = jax.jit(operator.add)
|
||||
@ -165,7 +165,6 @@ def jit_simple_dispatch_array(state):
|
||||
|
||||
@google_benchmark.register
|
||||
def jit_simple_array(state):
|
||||
with jax_config.jax_array(True):
|
||||
a = jax.device_put(1)
|
||||
b = jax.device_put(2)
|
||||
f = jax.jit(operator.add)
|
||||
@ -200,17 +199,16 @@ def jit_big_matmul(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_names(['num_args', 'jax_array'])
|
||||
@google_benchmark.option.args([10, False])
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@google_benchmark.option.args([1000, False])
|
||||
@google_benchmark.option.args([1000, True])
|
||||
@google_benchmark.option.args([2000, False])
|
||||
@google_benchmark.option.args([2000, True])
|
||||
@google_benchmark.option.arg_names(['num_args'])
|
||||
@google_benchmark.option.args([10])
|
||||
@google_benchmark.option.args([10])
|
||||
@google_benchmark.option.args([100])
|
||||
@google_benchmark.option.args([100])
|
||||
@google_benchmark.option.args([1000])
|
||||
@google_benchmark.option.args([1000])
|
||||
@google_benchmark.option.args([2000])
|
||||
@google_benchmark.option.args([2000])
|
||||
def jit_simple_many_args_dispatch(state):
|
||||
with jax_config.jax_array(state.range(1)):
|
||||
args = [jax.device_put(i) for i in range(state.range(0))]
|
||||
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
|
||||
x = f(args)
|
||||
@ -221,17 +219,16 @@ def jit_simple_many_args_dispatch(state):
|
||||
x.block_until_ready()
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_names(['num_args', 'jax_array'])
|
||||
@google_benchmark.option.args([10, False])
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@google_benchmark.option.args([1000, False])
|
||||
@google_benchmark.option.args([1000, True])
|
||||
@google_benchmark.option.args([2000, False])
|
||||
@google_benchmark.option.args([2000, True])
|
||||
@google_benchmark.option.arg_names(['num_args'])
|
||||
@google_benchmark.option.args([10])
|
||||
@google_benchmark.option.args([10])
|
||||
@google_benchmark.option.args([100])
|
||||
@google_benchmark.option.args([100])
|
||||
@google_benchmark.option.args([1000])
|
||||
@google_benchmark.option.args([1000])
|
||||
@google_benchmark.option.args([2000])
|
||||
@google_benchmark.option.args([2000])
|
||||
def jit_simple_many_args(state):
|
||||
with jax_config.jax_array(state.range(1)):
|
||||
args = [jax.device_put(i) for i in range(state.range(0))]
|
||||
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
|
||||
f(args).block_until_ready()
|
||||
@ -296,12 +293,8 @@ def jit_dispatch_with_transfer(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(2)
|
||||
def pmap_trivial_2_devices(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(swap)
|
||||
a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))
|
||||
|
||||
@ -312,12 +305,8 @@ def pmap_trivial_2_devices(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(8)
|
||||
def pmap_trivial_dispatch_8_devices(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(swap)
|
||||
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
|
||||
@ -327,12 +316,8 @@ def pmap_trivial_dispatch_8_devices(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(8)
|
||||
def pmap_trivial_8_devices(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(swap)
|
||||
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
|
||||
@ -344,12 +329,8 @@ def pmap_trivial_8_devices(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(2)
|
||||
def pmap_simple_2_devices(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(lambda a, b: (a + b, a - b))
|
||||
a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))
|
||||
|
||||
@ -360,12 +341,8 @@ def pmap_simple_2_devices(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(8)
|
||||
def pmap_simple_dispatch_8_devices(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(lambda a, b: (a + b, a - b))
|
||||
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
|
||||
@ -375,12 +352,8 @@ def pmap_simple_dispatch_8_devices(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(8)
|
||||
def pmap_simple_8_devices(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(lambda a, b: (a + b, a - b))
|
||||
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
|
||||
@ -392,12 +365,8 @@ def pmap_simple_8_devices(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(8)
|
||||
def pmap_simple_dispatch_8_devices_100_args(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
|
||||
args = []
|
||||
for i in range(100):
|
||||
@ -410,12 +379,8 @@ def pmap_simple_dispatch_8_devices_100_args(state):
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.arg_name('jax_array')
|
||||
@google_benchmark.option.arg(True)
|
||||
@google_benchmark.option.arg(False)
|
||||
@required_devices(8)
|
||||
def pmap_simple_8_devices_100_args(state):
|
||||
with jax_config.jax_array(state.range(0)):
|
||||
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
|
||||
args = []
|
||||
for i in range(100):
|
||||
@ -624,11 +589,11 @@ def bench_pjit_check_aval_sharding(state):
|
||||
mesh = create_mesh((4, 2), ('x', 'y'), state)
|
||||
if mesh is None:
|
||||
return
|
||||
s = sharding.NamedSharding(mesh, pxla.PartitionSpec('x', 'y'))
|
||||
s = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
|
||||
aval = jax.core.ShapedArray((8, 2), np.int32)
|
||||
|
||||
while state:
|
||||
pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
|
||||
pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@ -732,7 +697,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@jax_config.jax_array(True)
|
||||
def pjit_simple_1_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))
|
||||
@ -745,7 +709,6 @@ def pjit_simple_1_device(state):
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@jax_config.jax_array(True)
|
||||
def pjit_simple_4_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))
|
||||
@ -758,7 +721,6 @@ def pjit_simple_4_device(state):
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@jax_config.jax_array(True)
|
||||
def pjit_simple_4000_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
|
||||
@ -772,7 +734,6 @@ def pjit_simple_4000_device(state):
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@jax_config.jax_array(True)
|
||||
def pjit_aot_1_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state,
|
||||
@ -790,7 +751,6 @@ def pjit_aot_1_device(state):
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@jax_config.jax_array(True)
|
||||
def pjit_aot_4_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state,
|
||||
@ -808,7 +768,6 @@ def pjit_aot_4_device(state):
|
||||
@google_benchmark.option.args([10, True])
|
||||
@google_benchmark.option.args([100, False])
|
||||
@google_benchmark.option.args([100, True])
|
||||
@jax_config.jax_array(True)
|
||||
def pjit_aot_4000_device(state):
|
||||
pjit_simple_benchmark(
|
||||
state,
|
||||
@ -824,7 +783,7 @@ def host_local_array_to_global_array(state):
|
||||
global_mesh = create_mesh((4, 2), ('x', 'y'), state)
|
||||
input_shape = (8, 2)
|
||||
input_data = np.arange(np.prod(input_shape)).reshape(input_shape)
|
||||
in_pspec = pxla.PartitionSpec('x', 'y')
|
||||
in_pspec = jax.sharding.PartitionSpec('x', 'y')
|
||||
|
||||
while state:
|
||||
multihost_utils.host_local_array_to_global_array(
|
||||
|
Loading…
x
Reference in New Issue
Block a user