diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 056509996..b9a8bb43b 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -29,6 +29,7 @@ from jax.interpreters import xla from jax.interpreters import pxla from jax._src import array from jax._src import sharding +from jax._src.pjit import pjit_check_aval_sharding from jax.experimental import pjit as pjit_lib from jax.experimental import multihost_utils import jax.numpy as jnp @@ -153,26 +154,24 @@ def jit_simple(state): @google_benchmark.register def jit_simple_dispatch_array(state): - with jax_config.jax_array(True): - a = jax.device_put(1) - b = jax.device_put(2) - f = jax.jit(operator.add) - f(a, b) + a = jax.device_put(1) + b = jax.device_put(2) + f = jax.jit(operator.add) + f(a, b) - while state: - f(a, b) + while state: + f(a, b) @google_benchmark.register def jit_simple_array(state): - with jax_config.jax_array(True): - a = jax.device_put(1) - b = jax.device_put(2) - f = jax.jit(operator.add) - f(a, b) + a = jax.device_put(1) + b = jax.device_put(2) + f = jax.jit(operator.add) + f(a, b) - while state: - f(a, b).block_until_ready() + while state: + f(a, b).block_until_ready() @google_benchmark.register @@ -200,44 +199,42 @@ def jit_big_matmul(state): @google_benchmark.register -@google_benchmark.option.arg_names(['num_args', 'jax_array']) -@google_benchmark.option.args([10, False]) -@google_benchmark.option.args([10, True]) -@google_benchmark.option.args([100, False]) -@google_benchmark.option.args([100, True]) -@google_benchmark.option.args([1000, False]) -@google_benchmark.option.args([1000, True]) -@google_benchmark.option.args([2000, False]) -@google_benchmark.option.args([2000, True]) +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +@google_benchmark.option.args([1000]) +@google_benchmark.option.args([2000]) +@google_benchmark.option.args([2000]) def jit_simple_many_args_dispatch(state): - with jax_config.jax_array(state.range(1)): - args = [jax.device_put(i) for i in range(state.range(0))] - f = jax.jit(lambda xs: functools.reduce(operator.add, xs)) - x = f(args) - x.block_until_ready() + args = [jax.device_put(i) for i in range(state.range(0))] + f = jax.jit(lambda xs: functools.reduce(operator.add, xs)) + x = f(args) + x.block_until_ready() - while state: - x = f(args) - x.block_until_ready() + while state: + x = f(args) + x.block_until_ready() @google_benchmark.register -@google_benchmark.option.arg_names(['num_args', 'jax_array']) -@google_benchmark.option.args([10, False]) -@google_benchmark.option.args([10, True]) -@google_benchmark.option.args([100, False]) -@google_benchmark.option.args([100, True]) -@google_benchmark.option.args([1000, False]) -@google_benchmark.option.args([1000, True]) -@google_benchmark.option.args([2000, False]) -@google_benchmark.option.args([2000, True]) +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +@google_benchmark.option.args([1000]) +@google_benchmark.option.args([2000]) +@google_benchmark.option.args([2000]) def jit_simple_many_args(state): - with jax_config.jax_array(state.range(1)): - args = [jax.device_put(i) for i in range(state.range(0))] - f = jax.jit(lambda xs: functools.reduce(operator.add, xs)) - f(args).block_until_ready() + args = [jax.device_put(i) for i in range(state.range(0))] + f = jax.jit(lambda xs: functools.reduce(operator.add, xs)) + f(args).block_until_ready() - while state: - f(args).block_until_ready() + while state: + f(args).block_until_ready() def jit_simple_pruned_args_dispatch(n, state): args = [jax.device_put(i) for i in range(n)] @@ -296,137 +293,105 @@ def jit_dispatch_with_transfer(state): @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(2) def pmap_trivial_2_devices(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(swap) - a, b = f(jnp.array([1, 2]), jnp.array([3, 4])) + f = jax.pmap(swap) + a, b = f(jnp.array([1, 2]), jnp.array([3, 4])) - while state: - c, d = f(a, b) - c.block_until_ready() - d.block_until_ready() + while state: + c, d = f(a, b) + c.block_until_ready() + d.block_until_ready() @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(8) def pmap_trivial_dispatch_8_devices(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(swap) - a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), - jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) + f = jax.pmap(swap) + a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), + jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) - while state: - a, b = f(a, b) + while state: + a, b = f(a, b) @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(8) def pmap_trivial_8_devices(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(swap) - a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), - jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) + f = jax.pmap(swap) + a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), + jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) - while state: - c, d = f(a, b) - c.block_until_ready() - d.block_until_ready() + while state: + c, d = f(a, b) + c.block_until_ready() + d.block_until_ready() @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(2) def pmap_simple_2_devices(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(lambda a, b: (a + b, a - b)) - a, b = f(jnp.array([1, 2]), jnp.array([3, 4])) + f = jax.pmap(lambda a, b: (a + b, a - b)) + a, b = f(jnp.array([1, 2]), jnp.array([3, 4])) - while state: - c, d = f(a, b) - c.block_until_ready() - d.block_until_ready() + while state: + c, d = f(a, b) + c.block_until_ready() + d.block_until_ready() @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(8) def pmap_simple_dispatch_8_devices(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(lambda a, b: (a + b, a - b)) - a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), - jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) + f = jax.pmap(lambda a, b: (a + b, a - b)) + a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), + jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) - while state: - a, b = f(a, b) + while state: + a, b = f(a, b) @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(8) def pmap_simple_8_devices(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(lambda a, b: (a + b, a - b)) - a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), - jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) + f = jax.pmap(lambda a, b: (a + b, a - b)) + a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]), + jnp.array([2, 3, 4, 5, 6, 7, 8, 9])) - while state: - c, d = f(a, b) - c.block_until_ready() - d.block_until_ready() + while state: + c, d = f(a, b) + c.block_until_ready() + d.block_until_ready() @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(8) def pmap_simple_dispatch_8_devices_100_args(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,)) - args = [] - for i in range(100): - args.append(jnp.array(list(range(i, i+8)))) + f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,)) + args = [] + for i in range(100): + args.append(jnp.array(list(range(i, i+8)))) + args = f(*args) + + while state: args = f(*args) - while state: - args = f(*args) - @google_benchmark.register -@google_benchmark.option.arg_name('jax_array') -@google_benchmark.option.arg(True) -@google_benchmark.option.arg(False) @required_devices(8) def pmap_simple_8_devices_100_args(state): - with jax_config.jax_array(state.range(0)): - f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,)) - args = [] - for i in range(100): - args.append(jnp.array(list(range(i, i+8)))) + f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,)) + args = [] + for i in range(100): + args.append(jnp.array(list(range(i, i+8)))) - # Warmup loop. + # Warmup loop. + out = f(*args) + + while state: out = f(*args) - - while state: - out = f(*args) - jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) def _run_sda_index_bench(state, num_devices): @@ -624,11 +589,11 @@ def bench_pjit_check_aval_sharding(state): mesh = create_mesh((4, 2), ('x', 'y'), state) if mesh is None: return - s = sharding.NamedSharding(mesh, pxla.PartitionSpec('x', 'y')) + s = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) aval = jax.core.ShapedArray((8, 2), np.int32) while state: - pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False) + pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False) @google_benchmark.register @@ -732,7 +697,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False): @google_benchmark.option.args([10, True]) @google_benchmark.option.args([100, False]) @google_benchmark.option.args([100, True]) -@jax_config.jax_array(True) def pjit_simple_1_device(state): pjit_simple_benchmark( state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1)) @@ -745,7 +709,6 @@ def pjit_simple_1_device(state): @google_benchmark.option.args([10, True]) @google_benchmark.option.args([100, False]) @google_benchmark.option.args([100, True]) -@jax_config.jax_array(True) def pjit_simple_4_device(state): pjit_simple_benchmark( state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1)) @@ -758,7 +721,6 @@ def pjit_simple_4_device(state): @google_benchmark.option.args([10, True]) @google_benchmark.option.args([100, False]) @google_benchmark.option.args([100, True]) -@jax_config.jax_array(True) def pjit_simple_4000_device(state): pjit_simple_benchmark( state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1)) @@ -772,7 +734,6 @@ def pjit_simple_4000_device(state): @google_benchmark.option.args([10, True]) @google_benchmark.option.args([100, False]) @google_benchmark.option.args([100, True]) -@jax_config.jax_array(True) def pjit_aot_1_device(state): pjit_simple_benchmark( state, @@ -790,7 +751,6 @@ def pjit_aot_1_device(state): @google_benchmark.option.args([10, True]) @google_benchmark.option.args([100, False]) @google_benchmark.option.args([100, True]) -@jax_config.jax_array(True) def pjit_aot_4_device(state): pjit_simple_benchmark( state, @@ -808,7 +768,6 @@ def pjit_aot_4_device(state): @google_benchmark.option.args([10, True]) @google_benchmark.option.args([100, False]) @google_benchmark.option.args([100, True]) -@jax_config.jax_array(True) def pjit_aot_4000_device(state): pjit_simple_benchmark( state, @@ -824,7 +783,7 @@ def host_local_array_to_global_array(state): global_mesh = create_mesh((4, 2), ('x', 'y'), state) input_shape = (8, 2) input_data = np.arange(np.prod(input_shape)).reshape(input_shape) - in_pspec = pxla.PartitionSpec('x', 'y') + in_pspec = jax.sharding.PartitionSpec('x', 'y') while state: multihost_utils.host_local_array_to_global_array(