Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00.

To prepare for the upcoming `BatchedDevicePut` implementation changes, this change makes the `make_array_from_callback_*` benchmark code more homogeneous. It also adds a variant that uses a partially replicated sharding.

PiperOrigin-RevId: 736665856
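
For orientation, here is a minimal sketch (assuming at least eight devices and the 4x2 mesh used by the benchmarks in this file) of the three sharding flavors the `bench_make_array_from_callback_*` benchmarks exercise:

    import jax
    import numpy as np

    mesh = jax.sharding.Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
    P = jax.sharding.PartitionSpec
    # Fully replicated: every device holds the whole array.
    fully_replicated = jax.sharding.NamedSharding(mesh, P())
    # Partially replicated: sharded over 'y', replicated over 'x'.
    partially_replicated = jax.sharding.NamedSharding(mesh, P(None, 'y'))
    # Fully sharded: each device holds one tile of the array.
    fully_sharded = jax.sharding.NamedSharding(mesh, P('x', 'y'))
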
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Microbenchmarks for JAX `api` functions."""
|
|
|
|
import enum
|
|
import functools
|
|
import math
|
|
import operator
|
|
|
|
import google_benchmark
|
|
import jax
|
|
from jax import lax
|
|
from jax._src import array
|
|
from jax._src import core
|
|
from jax._src import op_shardings
|
|
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
|
|
from jax._src.lib import xla_client as xc
|
|
from jax._src.pjit import pjit_check_aval_sharding
|
|
from jax.experimental import multihost_utils
|
|
from jax.experimental import pjit as pjit_lib
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
|
|
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
partial = functools.partial
|
|
|
|
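
# Each benchmark below follows the google_benchmark protocol: setup and warmup
# run before the `while state:` loop, and only iterations of that loop are
# timed.
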
def required_devices(num_devices_required):
  """Helper to skip benchmarks that require more devices."""
  def helper1(f):
    @functools.wraps(f)
    def helper2(state):
      if jax.device_count() < num_devices_required:
        state.skip_with_error(f"requires {num_devices_required} devices")
        return
      return f(state)
    return helper2
  return helper1


def create_mesh(shape, axis_names, state):
  """Creates a device mesh of the given shape, or skips the benchmark."""
  size = math.prod(shape)
  if len(jax.devices()) < size:
    state.skip_with_error(f"Requires {size} devices")
    return None
  devices = sorted(jax.devices(), key=lambda d: d.id)
  mesh_devices = np.array(devices[:size]).reshape(shape)
  global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
  return global_mesh


def swap(a, b):
  return b, a


class AnEnum(enum.IntEnum):
  A = 123
  B = 456


@google_benchmark.register
def eager_unary_dispatch(state):
  a = jax.device_put(1)
  lax.neg(a)  # warmup
  while state:
    lax.neg(a)


@google_benchmark.register
def eager_unary(state):
  a = jax.device_put(1)
  lax.neg(a).block_until_ready()  # warmup
  while state:
    lax.neg(a).block_until_ready()


@google_benchmark.register
def eager_binary_dispatch(state):
  a = jax.device_put(1)
  b = jax.device_put(2)
  lax.add(a, b)  # warmup
  while state:
    lax.add(a, b)


@google_benchmark.register
def eager_binary(state):
  a = jax.device_put(1)
  b = jax.device_put(2)
  lax.add(a, b).block_until_ready()  # warmup
  while state:
    lax.add(a, b).block_until_ready()


@google_benchmark.register
def jit_trivial_dispatch(state):
  """Benchmarks only the duration for jitted_f to return the future."""
  f = jax.jit(swap)
  a, b = f(1, 2)
  x = f(a, b)  # warmup

  while state:
    x = f(a, b)
  x[0].block_until_ready()


@google_benchmark.register
def jit_trivial(state):
  f = jax.jit(swap)
  a, b = f(1, 2)
  f(a, b)  # warmup

  while state:
    c, d = f(a, b)
    c.block_until_ready()
    d.block_until_ready()


@google_benchmark.register
def jit_simple_dispatch(state):
  a = jax.device_put(1)
  b = jax.device_put(2)
  f = jax.jit(operator.add)
  f(a, b)  # warmup

  while state:
    f(a, b)


@google_benchmark.register
def jit_simple(state):
  a = jax.device_put(1)
  b = jax.device_put(2)
  f = jax.jit(operator.add)
  f(a, b)  # warmup

  while state:
    f(a, b).block_until_ready()


@google_benchmark.register
def jit_simple_dispatch_array(state):
  a = jax.device_put(1)
  b = jax.device_put(2)
  f = jax.jit(operator.add)
  f(a, b)  # warmup

  while state:
    f(a, b)


@google_benchmark.register
def jit_simple_array(state):
  a = jax.device_put(1)
  b = jax.device_put(2)
  f = jax.jit(operator.add)
  f(a, b)  # warmup

  while state:
    f(a, b).block_until_ready()


@google_benchmark.register
def jit_small_matmul(state):
  x = np.random.uniform(size=(2, 2)).astype(np.float32)
  x = jax.device_put(x)

  f = jax.jit(lambda x: jnp.dot(x, x))
  f(x).block_until_ready()

  while state:
    f(x).block_until_ready()


@google_benchmark.register
def jit_big_matmul(state):
  x = np.random.uniform(size=(100, 100)).astype(np.float32)
  x = jax.device_put(x)

  f = jax.jit(lambda x: jnp.dot(x, x))
  f(x).block_until_ready()

  while state:
    f(x).block_until_ready()


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
@google_benchmark.option.args([2000])
def jit_simple_many_args_dispatch(state):
  args = [jax.device_put(i) for i in range(state.range(0))]
  f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
  x = f(args)
  x.block_until_ready()

  while state:
    x = f(args)
  x.block_until_ready()


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
@google_benchmark.option.args([2000])
def jit_simple_many_args(state):
  args = [jax.device_put(i) for i in range(state.range(0))]
  f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
  f(args).block_until_ready()

  while state:
    f(args).block_until_ready()


def jit_simple_pruned_args_dispatch(n, state):
  args = [jax.device_put(i) for i in range(n)]
  f = jax.jit(lambda *xs: xs[0] + 1)
  x = f(*args)
  x.block_until_ready()

  while state:
    x = f(*args)
  x.block_until_ready()


def jit_simple_pruned_args(n, state):
  args = [jax.device_put(i) for i in range(n)]
  f = jax.jit(lambda *xs: xs[0] + 1)
  x = f(*args)
  x.block_until_ready()

  while state:
    f(*args).block_until_ready()


benchmarks = []
for n in [10, 100, 1000, 2000]:
  benchmarks += [
      google_benchmark.register(partial(jit_simple_pruned_args_dispatch, n),
                                name=f"jit_simple_pruned_args_dispatch_{n}"),
      google_benchmark.register(partial(jit_simple_pruned_args, n),
                                name=f"jit_simple_pruned_args_{n}")
  ]


@google_benchmark.register
def jit_dispatch_without_transfer(state):
  # We pick up a realistic input. 224 is usual for classification and 128 a
  # TPU-friendly batch-size.
  imgs = np.ones((128, 224, 224), np.float32)
  imgs = jax.device_put(imgs)

  f = jax.jit(lambda x: x + 1)
  f(imgs)

  while state:
    f(imgs)


@google_benchmark.register
def jit_dispatch_with_transfer(state):
  imgs = np.ones((128, 224, 224), np.float32)

  f = jax.jit(lambda x: x + 1)
  f(imgs).block_until_ready()

  while state:
    x = f(imgs)
    x.block_until_ready()


@google_benchmark.register
@required_devices(2)
def pmap_trivial_2_devices(state):
  f = jax.pmap(swap)
  a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))

  while state:
    c, d = f(a, b)
    c.block_until_ready()
    d.block_until_ready()


@google_benchmark.register
@required_devices(8)
def pmap_trivial_dispatch_8_devices(state):
  f = jax.pmap(swap)
  a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
           jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))

  while state:
    a, b = f(a, b)


@google_benchmark.register
@required_devices(8)
def pmap_trivial_8_devices(state):
  f = jax.pmap(swap)
  a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
           jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))

  while state:
    c, d = f(a, b)
    c.block_until_ready()
    d.block_until_ready()


@google_benchmark.register
@required_devices(2)
def pmap_simple_2_devices(state):
  f = jax.pmap(lambda a, b: (a + b, a - b))
  a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))

  while state:
    c, d = f(a, b)
    c.block_until_ready()
    d.block_until_ready()


@google_benchmark.register
@required_devices(8)
def pmap_simple_dispatch_8_devices(state):
  f = jax.pmap(lambda a, b: (a + b, a - b))
  a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
           jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))

  while state:
    a, b = f(a, b)


@google_benchmark.register
@required_devices(8)
def pmap_simple_8_devices(state):
  f = jax.pmap(lambda a, b: (a + b, a - b))
  a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
           jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))

  while state:
    c, d = f(a, b)
    c.block_until_ready()
    d.block_until_ready()


@google_benchmark.register
@required_devices(8)
def pmap_simple_dispatch_8_devices_100_args(state):
  f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
  args = []
  for i in range(100):
    args.append(jnp.array(list(range(i, i + 8))))

  args = f(*args)

  while state:
    args = f(*args)


@google_benchmark.register
@required_devices(8)
def pmap_simple_8_devices_100_args(state):
  f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
  args = []
  for i in range(100):
    args.append(jnp.array(list(range(i, i + 8))))

  # Warmup loop.
  out = f(*args)

  while state:
    out = f(*args)
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)


def _run_sda_index_bench(state, num_devices):
  x = jax.pmap(jnp.sin)(jnp.arange(num_devices))
  jax.device_get(x)
  while state:
    for i in range(num_devices):
      _ = x[i]


@google_benchmark.register
@required_devices(1)
def sda_index_1(state):
  _run_sda_index_bench(state, 1)


@google_benchmark.register
@required_devices(2)
def sda_index_2(state):
  _run_sda_index_bench(state, 2)


@google_benchmark.register
@required_devices(8)
def sda_index_8(state):
  _run_sda_index_bench(state, 8)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_shaped_abstractify(state):
  device, *_ = jax.devices()
  args = [jax.device_put_replicated(1, [device])] * 1000
  while state:
    _ = [core.shaped_abstractify(x) for x in args]


def _run_benchmark_for_xla_abstractify(arg, state):
  while state:
    core.abstractify(arg)


def bench_xla_abstractify():
  _abstractify_args = [
      (3, 'scalar_int'),
      (3.5, 'scalar_float'),
      (np.int32(3), 'scalar_numpy_int32'),
      (np.uint32(7), 'scalar_numpy_uint32'),
      (np.random.randn(3, 4, 5, 6), 'numpy_random'),
      (np.arange(100, dtype=np.float32), 'numpy_arange_100_float32'),
      (AnEnum.B, 'enum'),
  ]
  benchmarks = []
  for a, name in _abstractify_args:
    benchmarks.extend([
        google_benchmark.register(
            partial(_run_benchmark_for_xla_abstractify, a),
            name=f'bench_xla_abstractify_{name}'),
    ])

bench_xla_abstractify()


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
def bench_are_op_shardings_equal(state):
  op1 = xc.OpSharding()
  op1.type = xc.OpSharding.Type.OTHER
  op1.tile_assignment_dimensions = [4, 192, 16]
  op1.tile_assignment_devices = list(range(12288))

  op2 = xc.OpSharding()
  op2.type = xc.OpSharding.Type.OTHER
  op2.tile_assignment_dimensions = [4, 192, 16]
  op2.tile_assignment_devices = list(range(12288))

  while state:
    op_shardings.are_op_shardings_equal(op1, op2)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_pjit_check_aval_sharding(state):
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
  aval = jax.core.ShapedArray((8, 2), np.int32)

  while state:
    pjit_check_aval_sharding([s] * 100, [aval] * 100, [''] * 100, 'benchmark', False)


@google_benchmark.register
def bench_addressable_shards_index(state):
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  shape = (8, 2)
  inp = np.arange(math.prod(shape)).reshape(shape)
  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
  arr = jax.device_put(inp, s)

  while state:
    [s.index for s in arr.addressable_shards]


@google_benchmark.register
def bench_addressable_shards_replica_id(state):
  mesh = create_mesh((32, 16), ('x', 'y'), state)
  if mesh is None:
    return
  shape = (64, 32)
  inp = np.arange(math.prod(shape)).reshape(shape)
  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
  arr = jax.device_put(inp, s)

  while state:
    [s.replica_id for s in arr.addressable_shards]


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads(state):
  def double_compose(f):
    return lambda x: f(f(x))

  f = jnp.sin
  for _ in range(6):
    f = double_compose(f)
  f = double_compose(checkpoint(f))

  while state:
    y, _ = jax.vjp(f, 3.)
    y.block_until_ready()


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads_static_argnums(state):
  def double_compose(f):
    return lambda x, y: f(f(x, y), y)

  f = lambda x, _: jnp.sin(x)
  for _ in range(6):
    f = double_compose(f)
  f = double_compose(checkpoint(f, static_argnums=(1,)))

  while state:
    y, _ = jax.vjp(f, 3., True)
    y.block_until_ready()


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_slicing_compilation(state):
  x = jnp.arange(3)
  while state:
    jax.jit(lambda x: (x[0], x[1], x[2])).lower(x).compile()


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_slicing_compilation2(state):
  x = jnp.arange(3)
  while state:
    jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_repeated_static_indexing(state):
  x = jnp.arange(500)
  while state:
    jax.block_until_ready([x[i] for i in range(500)])


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_repeated_static_slicing(state):
  x = jnp.arange(1000)
  while state:
    jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)])


def pjit_simple_benchmark(state, num_devices, num_args, use_aot=False):
  spec = jax.sharding.PartitionSpec('x')
  mesh = create_mesh((num_devices,), ('x',), state)
  if mesh is None:
    return
  s = jax.sharding.NamedSharding(mesh, spec)
  inp_data = np.arange(num_devices).astype(np.float32)
  x = array.make_array_from_callback(inp_data.shape, s, lambda idx: inp_data[idx])

  x = [x for _ in range(num_args)]

  in_axis_resources = jax.sharding.NamedSharding(mesh, spec)
  out_axis_resources = jax.sharding.NamedSharding(mesh, spec)

  f = pjit_lib.pjit(
      lambda x: jax.tree.map(lambda x: x + 1, x),
      in_shardings=in_axis_resources,
      out_shardings=out_axis_resources,
  )

  if use_aot:
    f = f.lower(x).compile()

  x = f(x)  # warmup

  while state:
    x = f(x)


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_simple_1_device(state):
  pjit_simple_benchmark(state, num_devices=1, num_args=state.range(0))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_simple_4_device(state):
  pjit_simple_benchmark(state, num_devices=4, num_args=state.range(0))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_simple_4000_device(state):
  pjit_simple_benchmark(state, num_devices=4000, num_args=state.range(0))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_aot_1_device(state):
  pjit_simple_benchmark(
      state,
      num_devices=1,
      num_args=state.range(0),
      use_aot=True)


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_aot_4_device(state):
  pjit_simple_benchmark(
      state,
      num_devices=4,
      num_args=state.range(0),
      use_aot=True)


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_aot_4000_device(state):
  pjit_simple_benchmark(
      state,
      num_devices=4000,
      num_args=state.range(0),
      use_aot=True)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def host_local_array_to_global_array(state):
  global_mesh = create_mesh((4, 2), ('x', 'y'), state)
  if global_mesh is None:
    return
  input_shape = (8, 2)
  input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
  in_pspec = jax.sharding.PartitionSpec('x', 'y')

  while state:
    multihost_utils.host_local_array_to_global_array(
        (input_data, input_data), global_mesh, (in_pspec, in_pspec))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
def device_put_from_numpy_array(state):
  x = [np.array(1, np.int32)] * state.range(0)
  while state:
    _ = jax.block_until_ready(jax.device_put(x))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
@google_benchmark.option.args([1000])
def device_put_from_jax_array(state):
  if len(jax.devices()) < 2:
    state.skip_with_error('requires 2 devices')
    return
  x = [np.array(1, np.int32)] * state.range(0)
  x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0]))
  d = jax.devices()[1]
  while state:
    _ = jax.block_until_ready(jax.device_put(x, device=d))


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def device_put_big(state):
  x = np.arange(4000 * 10**6 // np.dtype('float32').itemsize, dtype=np.float32)
  jax.device_put(x).block_until_ready()  # warmup

  while state:
    _ = jax.device_put(x).block_until_ready()


@google_benchmark.register
def device_put_sharded(state):
  arr_inp = [np.array(i) for i in range(jax.device_count())]
  dev = jax.devices()

  while state:
    _ = jax.device_put_sharded(arr_inp, dev).block_until_ready()


@google_benchmark.register
@required_devices(8)
def device_get_8_devices(state):
  mesh = jax.sharding.Mesh(
      np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y')
  )
  sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('x', 'y')
  )
  inp = jax.device_put(np.zeros((8, 4), dtype=np.float32), sharding)

  @jax.jit
  def fn(x):
    y = x + x
    return [y for _ in range(50)]

  jax.device_get(fn(inp))  # warmup

  while state:
    jax.device_get(fn(inp))


@google_benchmark.register
@required_devices(8)
def np_asarray_8_devices(state):
  mesh = jax.sharding.Mesh(
      np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y')
  )
  sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('x', 'y')
  )
  inp = jax.device_put(np.zeros((8, 4), dtype=np.float32), sharding)

  @jax.jit
  def fn(x):
    y = x + x
    return [y for _ in range(50)]

  jax.device_get(fn(inp))  # warmup

  while state:
    [np.asarray(x) for x in fn(inp)]


@google_benchmark.register
@required_devices(8)
def jax_array_arrays_8_devices(state):
  mesh = jax.sharding.Mesh(
      np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y')
  )
  sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('x', 'y')
  )
  inp = jax.device_put(np.zeros((8, 4), dtype=np.float32), sharding)

  @jax.jit
  def fn(x):
    y = x + x
    return [y for _ in range(200)]

  jax.device_get(fn(inp))  # warmup

  while state:
    [x._arrays for x in fn(inp)]


def batch_inplace_while(inplace_op, state):

  @jax.jit
  @jax.vmap
  def f(init_step, init_xs):

    def cond(carry):
      step, xs = carry
      return step < xs.size

    def body(carry):
      step, xs = carry
      if inplace_op == 'scatter':
        xs = xs.at[step].set(1)
      elif inplace_op == 'dynamic_update_slice':
        xs = lax.dynamic_update_index_in_dim(xs, 1., step, 0)
      else:
        assert False
      return step + 1, xs

    return lax.while_loop(cond, body, (init_step, init_xs))

  size = 100_000
  args = jnp.array([0]), jnp.zeros((1, size))
  jax.block_until_ready(f(*args))  # compile
  while state:
    jax.block_until_ready(f(*args))


google_benchmark.register(
    partial(batch_inplace_while, 'scatter'), name='batch_inplace_while_scatter')
google_benchmark.register(
    partial(batch_inplace_while, 'dynamic_update_slice'),
    name='batch_inplace_while_dynamic_update_slice')


@google_benchmark.register
def serial_dot_products(state):
  SIZE = 50

  @jax.jit
  @jax.vmap
  @jax.grad
  def f(x):
    out = 0
    for i in range(SIZE):
      y = x @ jnp.array([i, i + 1], dtype=jnp.float32)
      out = out + y * x[0]
    return out

  x = jax.random.normal(jax.random.key(0), (2, 2))
  f(x).block_until_ready()  # compile
  while state:
    f(x).block_until_ready()


@google_benchmark.register
@google_benchmark.option.arg_names(['arg_lengths', 'num_args'])
@google_benchmark.option.args_product([[0, 1, 2, 5, 10, 100], [1, 2, 3]])
def safe_map(state):
  args = tuple(list(range(state.range(0))) for _ in range(state.range(1)))
  def f(*args): return tuple(args)
  while state:
    jax.util.safe_map(f, *args)


@google_benchmark.register
@google_benchmark.option.arg_names(['arg_lengths', 'num_args'])
@google_benchmark.option.args_product([[0, 1, 2, 5, 10, 100], [1, 2, 3]])
def safe_zip(state):
  args = tuple(list(range(state.range(0))) for _ in range(state.range(1)))
  while state:
    jax.util.safe_zip(*args)


@google_benchmark.register
def bench_make_array_from_callback_fully_replicated_sharding(state):
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  input_shape = (8, 2)
  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)

  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
  while state:
    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_make_array_from_callback_partially_replicated_sharding(state):
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  input_shape = (8, 2)
  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)

  s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y'))
  while state:
    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_make_array_from_callback_fully_sharded_sharding(state):
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  input_shape = (8, 2)
  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)

  s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
  while state:
    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def benchmark_lorentz63_cache_hits(state):
  @jax.jit
  def lorentz63(state, dt=0.01, sigma=10, beta=8/3, rho=28):
    x, y, z = state
    x_t = sigma * (y - x)
    y_t = (rho - z) * x - y
    z_t = x * y - beta * z
    return jnp.array([x + x_t * dt, y + y_t * dt, z + z_t * dt])

  def training_step(initial_conditions, steps=1, unroll=False):
    def forward_sim(x0):
      if unroll:
        x = x0
        for _ in range(steps):
          x = lorentz63(x)
        return x
      else:
        return jax.lax.fori_loop(0, steps, lambda _, x: lorentz63(x), x0)

    def loss(x0):
      out = jax.vmap(jax.remat(forward_sim))(x0)
      return jnp.square(out).sum()

    return jax.value_and_grad(loss)(initial_conditions)

  x = jnp.ones((8, 3))
  while state:
    jax.make_jaxpr(lambda x: training_step(x, 100, unroll=True))(x)


@google_benchmark.register
def jit_add_chain(state):
  SIZE = 100

  @jax.jit
  def g(x, y):
    return lax.add(x, y)

  x = jax.random.normal(jax.random.key(0), (2, 2))
  while state:
    # Re-defining f each iteration forces a fresh trace of the add chain.
    @jax.jit
    def f(x):
      for i in range(SIZE):
        x = g(x, x)
      return x
    f(x).block_until_ready()


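# Usage sketch: google_benchmark picks up standard benchmark flags via absl
# (parse_flags_with_absl() is called above), so a subset of benchmarks can
# typically be run with, e.g. (assuming this file is saved as
# api_benchmark.py, as in the JAX repository's benchmarks/ directory):
#
#   python api_benchmark.py --benchmark_filter=jit_simple
#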
if __name__ == "__main__":
  google_benchmark.main()