Skip the benchmarks properly via state.skip_with_error when enough devices are not present.

PiperOrigin-RevId: 485931295
This commit is contained in:
Yash Katariya 2022-11-03 11:44:16 -07:00 committed by jax authors
parent 91d134d65b
commit 532cd7ed74

View File

@ -20,7 +20,6 @@ import operator
import google_benchmark
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src import config as jax_config
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
@ -31,6 +30,7 @@ from jax.interpreters import pxla
from jax._src import array
from jax._src import sharding
from jax.experimental import pjit as pjit_lib
from jax.experimental import maps
import jax.numpy as jnp
import numpy as np
@ -54,6 +54,17 @@ def required_devices(num_devices_required):
return helper1
def create_mesh(shape, axis_names, state):
size = np.prod(shape)
if len(jax.devices()) < size:
state.skip_with_error(f"Requires {size} devices")
return None
devices = sorted(jax.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(shape)
global_mesh = maps.Mesh(mesh_devices, axis_names)
return global_mesh
def swap(a, b):
return b, a
@ -610,7 +621,9 @@ def bench_are_op_shardings_equal(state):
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_pjit_check_aval_sharding(state):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mesh = create_mesh((4, 2), ('x', 'y'), state)
if mesh is None:
return
s = sharding.MeshPspecSharding(mesh, pxla.PartitionSpec('x', 'y'))
aval = jax.ShapedArray((8, 2), np.int32)
@ -666,7 +679,9 @@ def bench_slicing_compilation2(state):
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
spec = pjit_lib.PartitionSpec('x')
mesh = jtu.create_global_mesh((num_devices,), ('x',))
mesh = create_mesh((num_devices,), ('x',), state)
if mesh is None:
return
s = sharding.MeshPspecSharding(mesh, spec)
inp_data = np.arange(num_devices).astype(np.float32)
x = array.make_array_from_callback(inp_data.shape, s, lambda idx: inp_data[idx])