Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Skip the benchmarks properly via state.skip_with_error when enough devices are not present.
PiperOrigin-RevId: 485931295
This commit is contained in:
parent 91d134d65b
commit 532cd7ed74
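For context, the state.skip_with_error call used below comes from google_benchmark's Python bindings. A minimal sketch of the pattern (the benchmark name, body, and device count are illustrative, not part of this commit):

import google_benchmark
import jax
import jax.numpy as jnp

@google_benchmark.register
def bench_eight_device_op(state):
  # Skip (rather than crash) when the host has too few devices.
  if len(jax.devices()) < 8:
    state.skip_with_error("Requires 8 devices")
    return
  while state:
    jnp.ones((8,)).block_until_ready()

if __name__ == "__main__":
  google_benchmark.main()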
@@ -20,7 +20,6 @@ import operator
 import google_benchmark
 import jax
 from jax import lax
-from jax._src import test_util as jtu
 from jax._src import config as jax_config
 from jax.experimental import sparse
 from jax._src.api_util import shaped_abstractify  # technically not an api fn
@@ -31,6 +30,7 @@ from jax.interpreters import pxla
 from jax._src import array
 from jax._src import sharding
 from jax.experimental import pjit as pjit_lib
+from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
 
@@ -54,6 +54,17 @@ def required_devices(num_devices_required):
   return helper1
 
 
+def create_mesh(shape, axis_names, state):
+  size = np.prod(shape)
+  if len(jax.devices()) < size:
+    state.skip_with_error(f"Requires {size} devices")
+    return None
+  devices = sorted(jax.devices(), key=lambda d: d.id)
+  mesh_devices = np.array(devices[:size]).reshape(shape)
+  global_mesh = maps.Mesh(mesh_devices, axis_names)
+  return global_mesh
+
+
 def swap(a, b):
   return b, a
 
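The new helper can also be exercised outside a benchmark run. The sketch below is a hypothetical standalone check: it assumes the XLA_FLAGS trick to force 8 virtual CPU devices, a stub in place of google_benchmark's state, and the jax.experimental.maps import used in this file (newer jax versions move Mesh to jax.sharding):

import os
# Must be set before jax is imported; forces 8 virtual CPU devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import numpy as np
from jax.experimental import maps

def create_mesh(shape, axis_names, state):
  # Same helper as in the diff above.
  size = np.prod(shape)
  if len(jax.devices()) < size:
    state.skip_with_error(f"Requires {size} devices")
    return None
  devices = sorted(jax.devices(), key=lambda d: d.id)
  mesh_devices = np.array(devices[:size]).reshape(shape)
  return maps.Mesh(mesh_devices, axis_names)

class FakeState:
  # Stand-in for google_benchmark's state outside a benchmark run.
  def skip_with_error(self, msg):
    print("skipped:", msg)

mesh = create_mesh((4, 2), ("x", "y"), FakeState())
print(mesh)  # Mesh over a (4, 2) device array with axes ('x', 'y')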
@@ -610,7 +621,9 @@ def bench_are_op_shardings_equal(state):
 @google_benchmark.register
 @google_benchmark.option.unit(google_benchmark.kMillisecond)
 def bench_pjit_check_aval_sharding(state):
-  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+  mesh = create_mesh((4, 2), ('x', 'y'), state)
+  if mesh is None:
+    return
   s = sharding.MeshPspecSharding(mesh, pxla.PartitionSpec('x', 'y'))
   aval = jax.ShapedArray((8, 2), np.int32)
 
@@ -666,7 +679,9 @@ def bench_slicing_compilation2(state):
 
 def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
   spec = pjit_lib.PartitionSpec('x')
-  mesh = jtu.create_global_mesh((num_devices,), ('x',))
+  mesh = create_mesh((num_devices,), ('x',), state)
+  if mesh is None:
+    return
   s = sharding.MeshPspecSharding(mesh, spec)
   inp_data = np.arange(num_devices).astype(np.float32)
   x = array.make_array_from_callback(inp_data.shape, s, lambda idx: inp_data[idx])
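With this change, a machine with too few devices reports these benchmarks as skipped instead of aborting the whole run. To get enough devices locally, one can simulate them on CPU, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8 python benchmarks/api_benchmark.py (the file path and invocation are an assumption for illustration, not part of this commit).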