rocm_jax/benchmarks/api_benchmark.py
Peter Hawkins 428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
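
An illustrative before/after for the renames described above (a minimal sketch; the exact import sites vary across the file):

    # old (deprecated) spellings
    from jax.experimental.maps import Mesh
    from jax.experimental import PartitionSpec
    # new canonical spellings
    from jax.sharding import Mesh, NamedSharding, PartitionSpec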

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for JAX `api` functions."""
import enum
import functools
import operator
import google_benchmark
import jax
from jax import lax
from jax._src import config as jax_config
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax.interpreters import xla
from jax.interpreters import pxla
from jax._src import array
from jax._src import sharding
from jax.experimental import pjit as pjit_lib
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
from jax.config import config
config.parse_flags_with_absl()
partial = functools.partial
def required_devices(num_devices_required):
"""Helper to skip benchmarks that require more devices."""
def helper1(f):
@functools.wraps(f)
def helper2(state):
if jax.device_count() < num_devices_required:
state.skip_with_error(f"requires {num_devices_required} devices")
return
return f(state)
return helper2
return helper1
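# Illustrative usage of required_devices (a sketch only; the real benchmarks
# below apply it the same way): the wrapped benchmark is skipped with an error
# when fewer than the requested number of devices are available.
#
#   @google_benchmark.register
#   @required_devices(2)
#   def my_two_device_benchmark(state):  # hypothetical benchmark name
#     ...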
def create_mesh(shape, axis_names, state):
size = np.prod(shape)
if len(jax.devices()) < size:
state.skip_with_error(f"Requires {size} devices")
return None
devices = sorted(jax.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(shape)
global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
return global_mesh
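# Illustrative sketch (assuming at least 8 devices are available) of how the
# returned mesh is consumed by the sharding benchmarks further down:
#
#   mesh = create_mesh((4, 2), ('x', 'y'), state)
#   if mesh is None:
#     return  # not enough devices; the benchmark was skipped
#   s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))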
def swap(a, b):
return b, a
class AnEnum(enum.IntEnum):
A = 123
B = 456
@google_benchmark.register
def eager_unary_dispatch(state):
a = jax.device_put(1)
lax.neg(a)
while state:
lax.neg(a)
@google_benchmark.register
def eager_unary(state):
a = jax.device_put(1)
lax.neg(a).block_until_ready()
while state:
lax.neg(a).block_until_ready()
@google_benchmark.register
def eager_binary_dispatch(state):
a = jax.device_put(1)
b = jax.device_put(2)
lax.add(a, b)
while state:
lax.add(a, b)
@google_benchmark.register
def eager_binary(state):
a = jax.device_put(1)
b = jax.device_put(2)
lax.add(a, b).block_until_ready()
while state:
lax.add(a, b).block_until_ready()
@google_benchmark.register
def jit_trivial_dispatch(state):
"""Benchmarks only the duration for jitted_f to return the future."""
f = jax.jit(swap)
a, b = f(1, 2)
x = f(a, b)
while state:
x = f(a, b)
x[0].block_until_ready()
@google_benchmark.register
def jit_trivial(state):
f = jax.jit(swap)
a, b = f(1, 2)
f(a, b)
while state:
c, d = f(a, b)
c.block_until_ready()
d.block_until_ready()
@google_benchmark.register
def jit_simple_dispatch(state):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
f(a, b)
while state:
f(a, b)
@google_benchmark.register
def jit_simple(state):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
f(a, b)
while state:
f(a, b).block_until_ready()
@google_benchmark.register
def jit_simple_dispatch_array(state):
with jax_config.jax_array(True):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
f(a, b)
while state:
f(a, b)
@google_benchmark.register
def jit_simple_array(state):
with jax_config.jax_array(True):
a = jax.device_put(1)
b = jax.device_put(2)
f = jax.jit(operator.add)
f(a, b)
while state:
f(a, b).block_until_ready()
@google_benchmark.register
def jit_small_matmul(state):
x = np.random.uniform(size=(2, 2)).astype(np.float32)
x = jax.device_put(x)
f = jax.jit(lambda x: jnp.dot(x, x))
f(x).block_until_ready()
while state:
f(x).block_until_ready()
@google_benchmark.register
def jit_big_matmul(state):
x = np.random.uniform(size=(100, 100)).astype(np.float32)
x = jax.device_put(x)
f = jax.jit(lambda x: jnp.dot(x, x))
f(x).block_until_ready()
while state:
f(x).block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'jax_array'])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.args([1000, False])
@google_benchmark.option.args([1000, True])
@google_benchmark.option.args([2000, False])
@google_benchmark.option.args([2000, True])
def jit_simple_many_args_dispatch(state):
with jax_config.jax_array(state.range(1)):
args = [jax.device_put(i) for i in range(state.range(0))]
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
x = f(args)
x.block_until_ready()
while state:
x = f(args)
x.block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'jax_array'])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.args([1000, False])
@google_benchmark.option.args([1000, True])
@google_benchmark.option.args([2000, False])
@google_benchmark.option.args([2000, True])
def jit_simple_many_args(state):
with jax_config.jax_array(state.range(1)):
args = [jax.device_put(i) for i in range(state.range(0))]
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
f(args).block_until_ready()
while state:
f(args).block_until_ready()
def jit_simple_pruned_args_dispatch(n, state):
args = [jax.device_put(i) for i in range(n)]
f = jax.jit(lambda *xs: xs[0] + 1)
x = f(*args)
x.block_until_ready()
while state:
x = f(*args)
x.block_until_ready()
def jit_simple_pruned_args(n, state):
args = [jax.device_put(i) for i in range(n)]
f = jax.jit(lambda *xs: xs[0] + 1)
x = f(*args)
x.block_until_ready()
while state:
f(*args).block_until_ready()
benchmarks = []
for n in [10, 100, 1000, 2000]:
benchmarks += [
google_benchmark.register(partial(jit_simple_pruned_args_dispatch, n),
name=f"jit_simple_pruned_args_dispatch_{n}"),
google_benchmark.register(partial(jit_simple_pruned_args, n),
name=f"jit_simple_pruned_args_{n}")
]
@google_benchmark.register
def jit_dispatch_without_transfer(state):
  # We pick a realistic input size: 224x224 is typical for image
  # classification models and 128 is a TPU-friendly batch size.
imgs = np.ones((128, 224, 224), np.float32)
imgs = jax.device_put(imgs)
f = jax.jit(lambda x: x+1)
f(imgs)
while state:
f(imgs)
@google_benchmark.register
def jit_dispatch_with_transfer(state):
imgs = np.ones((128, 224, 224), np.float32)
f = jax.jit(lambda x: x+1)
f(imgs).block_until_ready()
while state:
x = f(imgs)
x.block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(2)
def pmap_trivial_2_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(swap)
a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))
while state:
c, d = f(a, b)
c.block_until_ready()
d.block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_trivial_dispatch_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(swap)
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
while state:
a, b = f(a, b)
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_trivial_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(swap)
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
while state:
c, d = f(a, b)
c.block_until_ready()
d.block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(2)
def pmap_simple_2_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda a, b: (a + b, a - b))
a, b = f(jnp.array([1, 2]), jnp.array([3, 4]))
while state:
c, d = f(a, b)
c.block_until_ready()
d.block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_dispatch_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda a, b: (a + b, a - b))
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
while state:
a, b = f(a, b)
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_8_devices(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda a, b: (a + b, a - b))
a, b = f(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]),
jnp.array([2, 3, 4, 5, 6, 7, 8, 9]))
while state:
c, d = f(a, b)
c.block_until_ready()
d.block_until_ready()
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_dispatch_8_devices_100_args(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
args.append(jnp.array(list(range(i, i+8))))
args = f(*args)
while state:
args = f(*args)
@google_benchmark.register
@google_benchmark.option.arg_name('jax_array')
@google_benchmark.option.arg(True)
@google_benchmark.option.arg(False)
@required_devices(8)
def pmap_simple_8_devices_100_args(state):
with jax_config.jax_array(state.range(0)):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
args.append(jnp.array(list(range(i, i+8))))
    # Warmup call (triggers compilation).
out = f(*args)
while state:
out = f(*args)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
def _run_sda_index_bench(state, num_devices):
x = jax.pmap(jnp.sin)(jnp.arange(num_devices))
jax.device_get(x)
while state:
for i in range(num_devices):
_ = x[i]
@google_benchmark.register
@required_devices(1)
def sda_index_1(state):
_run_sda_index_bench(state, 1)
@google_benchmark.register
@required_devices(2)
def sda_index_2(state):
_run_sda_index_bench(state, 2)
@google_benchmark.register
@required_devices(8)
def sda_index_8(state):
_run_sda_index_bench(state, 8)
def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False):
shape = (2000, 2000)
nse = 10000
size = np.prod(shape)
rng = np.random.RandomState(1701)
data = rng.randn(nse)
indices = np.unravel_index(
rng.choice(size, size=nse, replace=False), shape=shape)
mat = jnp.zeros(shape).at[indices].set(data)
f = sparse.BCOO.fromdense
if compile or jit:
    # Note: nse must be specified under jit, since the number of stored
    # elements (and hence the output buffer size) must be known statically.
f = jax.jit(partial(f, nse=nse))
if compile:
while state:
f.lower(mat).compile()
else:
f(mat).block_until_ready()
while state:
f(mat).block_until_ready()
@google_benchmark.register
def sparse_bcoo_fromdense(state):
return _sparse_bcoo_fromdense(state)
@google_benchmark.register
def sparse_bcoo_fromdense_jit(state):
return _sparse_bcoo_fromdense(state, jit=True)
@google_benchmark.register
def sparse_bcoo_fromdense_compile(state):
return _sparse_bcoo_fromdense(state, compile=True)
def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False):
shape = (2000, 2000)
nse = 10000
size = np.prod(shape)
rng = np.random.RandomState(1701)
data = rng.randn(nse)
indices = np.unravel_index(
rng.choice(size, size=nse, replace=False), shape=shape)
mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
f = lambda mat: mat.todense()
if jit or compile:
f = jax.jit(f)
if compile:
while state:
f.lower(mat).compile()
else:
f(mat).block_until_ready()
while state:
f(mat).block_until_ready()
@google_benchmark.register
def sparse_bcoo_todense(state):
return _sparse_bcoo_todense(state)
@google_benchmark.register
def sparse_bcoo_todense_jit(state):
return _sparse_bcoo_todense(state, jit=True)
@google_benchmark.register
def sparse_bcoo_todense_compile(state):
return _sparse_bcoo_todense(state, compile=True)
def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False):
shape = (2000, 2000)
nse = 10000
key = jax.random.PRNGKey(1701)
mat = sparse.random_bcoo(key, nse=nse, shape=shape, dtype=jnp.float32,
indices_dtype=jnp.int32, sorted_indices=True)
vec = jax.random.uniform(key, shape=(shape[1],), dtype=jnp.float32)
f = lambda mat, vec: mat @ vec
if jit or compile:
f = jax.jit(f)
if compile:
while state:
f.lower(mat, vec).compile()
else:
f(mat, vec).block_until_ready()
while state:
f(mat, vec).block_until_ready()
@google_benchmark.register
def sparse_bcoo_matvec(state):
return _sparse_bcoo_matvec(state)
@google_benchmark.register
def sparse_bcoo_matvec_jit(state):
return _sparse_bcoo_matvec(state, jit=True)
@google_benchmark.register
def sparse_bcoo_matvec_compile(state):
return _sparse_bcoo_matvec(state, compile=True)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_shaped_abstractify(state):
device, *_ = jax.devices()
args = [jax.device_put_replicated(1, [device])] * 1000
while state:
_ = [shaped_abstractify(x) for x in args]
def _run_benchmark_for_xla_abstractify(arg, state):
while state:
xla.abstractify(arg)
def bench_xla_abstractify():
_abstractify_args = [
(3, 'scalar_int'),
(3.5, 'scalar_float'),
(np.int32(3), 'scalar_numpy_int32'),
(np.uint32(7), 'scalar_numpy_uint32'),
(np.random.randn(3, 4, 5, 6), 'numpy_random'),
(np.arange(100, dtype=np.float32), 'numpy_arange_100_float32'),
(AnEnum.B, 'enum'),
]
benchmarks = []
for a, name in _abstractify_args:
benchmarks.extend([
google_benchmark.register(
partial(_run_benchmark_for_xla_abstractify, a),
name=f'bench_xla_abstractify_{name}'),
])
bench_xla_abstractify()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
def bench_are_op_shardings_equal(state):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.OTHER
op1.tile_assignment_dimensions = [4, 192, 16]
op1.tile_assignment_devices = list(range(12288))
op2 = xc.OpSharding()
op2.type = xc.OpSharding.Type.OTHER
op2.tile_assignment_dimensions = [4, 192, 16]
op2.tile_assignment_devices = list(range(12288))
while state:
pxla.are_op_shardings_equal(op1, op2)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_pjit_check_aval_sharding(state):
mesh = create_mesh((4, 2), ('x', 'y'), state)
if mesh is None:
return
s = sharding.NamedSharding(mesh, pxla.PartitionSpec('x', 'y'))
aval = jax.core.ShapedArray((8, 2), np.int32)
while state:
pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads(state):
def double_compose(f):
return lambda x: f(f(x))
f = jnp.sin
for _ in range(6):
f = double_compose(f)
f = double_compose(checkpoint(f))
while state:
y, _ = jax.vjp(f, 3.)
y.block_until_ready()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads_static_argnums(state):
def double_compose(f):
return lambda x, y: f(f(x, y), y)
f = lambda x, _: jnp.sin(x)
for _ in range(6):
f = double_compose(f)
f = double_compose(checkpoint(f, static_argnums=(1,)))
while state:
y, _ = jax.vjp(f, 3., True)
y.block_until_ready()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_slicing_compilation(state):
x = jnp.arange(3)
while state:
jax.jit(lambda x: (x[0], x[1], x[2])).lower(x).compile()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_slicing_compilation2(state):
x = jnp.arange(3)
while state:
jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_repeated_static_indexing(state):
x = jnp.arange(500)
while state:
jax.block_until_ready([x[i] for i in range(500)])
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_repeated_static_slicing(state):
x = jnp.arange(1000)
while state:
jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)])
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
spec = jax.sharding.PartitionSpec('x')
mesh = create_mesh((num_devices,), ('x',), state)
if mesh is None:
return
s = sharding.NamedSharding(mesh, spec)
inp_data = np.arange(num_devices).astype(np.float32)
x = array.make_array_from_callback(inp_data.shape, s, lambda idx: inp_data[idx])
x = [x for _ in range(num_args)]
prev_state = jax_config.FLAGS.experimental_cpp_pjit
jax_config.FLAGS.experimental_cpp_pjit = cpp_jit
in_axis_resources = sharding.NamedSharding(mesh, spec)
out_axis_resources = sharding.NamedSharding(mesh, spec)
f = pjit_lib.pjit(
lambda x: jax.tree_map(lambda x: x + 1, x),
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources)
if use_aot:
f = f.lower(x).compile()
x = f(x)
while state:
x = f(x)
jax_config.FLAGS.experimental_cpp_pjit = prev_state
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_1_device(state):
pjit_simple_benchmark(
state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4_device(state):
pjit_simple_benchmark(
state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4000_device(state):
pjit_simple_benchmark(
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_1_device(state):
pjit_simple_benchmark(
state,
num_devices=1,
num_args=state.range(0),
cpp_jit=state.range(1),
use_aot=True)
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4_device(state):
pjit_simple_benchmark(
state,
num_devices=4,
num_args=state.range(0),
cpp_jit=state.range(1),
use_aot=True)
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4000_device(state):
pjit_simple_benchmark(
state,
num_devices=4000,
num_args=state.range(0),
cpp_jit=state.range(1),
use_aot=True)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def host_local_array_to_global_array(state):
global_mesh = create_mesh((4, 2), ('x', 'y'), state)
input_shape = (8, 2)
input_data = np.arange(np.prod(input_shape)).reshape(input_shape)
in_pspec = pxla.PartitionSpec('x', 'y')
while state:
multihost_utils.host_local_array_to_global_array(
(input_data, input_data), global_mesh, (in_pspec, in_pspec))
@google_benchmark.register
def device_put(state):
x = np.array(1, np.int32)
while state:
_ = jax.device_put(x).block_until_ready()
def batch_inplace_while(inplace_op, state):
@jax.jit
@jax.vmap
def f(init_step, init_xs):
def cond(carry):
step, xs = carry
return step < xs.size
def body(carry):
step, xs = carry
if inplace_op == 'scatter':
xs = xs.at[step].set(1)
elif inplace_op == 'dynamic_update_slice':
xs = lax.dynamic_update_index_in_dim(xs, 1., step, 0)
else:
assert False
return step + 1, xs
return lax.while_loop(cond, body, (init_step, init_xs))
size = 100_000
args = jnp.array([0]), jnp.zeros((1, size))
f(*args) # compile
while state:
f(*args)
google_benchmark.register(
partial(batch_inplace_while, 'scatter'), name='batch_inplace_while_scatter')
google_benchmark.register(
partial(batch_inplace_while, 'dynamic_update_slice'),
name='batch_inplace_while_dynamic_update_slice')
if __name__ == "__main__":
google_benchmark.main()