[JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant

To prepare for the upcoming `BatchedDevicePut` implementation changes, this
change makes the `make_array_from_callback_*` benchmark code more
homogeneous. It also adds a variant that uses a partially replicated sharding.

PiperOrigin-RevId: 736665856
This commit is contained in:
Hyeontaek Lim 2025-03-13 15:50:05 -07:00 committed by jax authors
parent e615e2acb3
commit 73b8f6aee2

View File

@ -21,14 +21,14 @@ import operator
import google_benchmark
import jax
from jax import lax
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src import core
from jax._src.lib import xla_client as xc
from jax._src import array
from jax._src import core
from jax._src import op_shardings
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax._src.pjit import pjit_check_aval_sharding
from jax.experimental import pjit as pjit_lib
from jax.experimental import multihost_utils
from jax.experimental import pjit as pjit_lib
import jax.numpy as jnp
import numpy as np
@ -860,29 +860,44 @@ def safe_zip(state):
@google_benchmark.register
def bench_make_array_from_callback_fully_replicated_sharding(state):
mesh = jax.sharding.Mesh(
np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y'))
shape = (8, 2)
np_arr = np.arange(16).reshape(shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
mesh = create_mesh((4, 2), ('x', 'y'), state)
if mesh is None:
return
input_shape = (8, 2)
np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
while state:
jax.make_array_from_callback(shape, s, np_arr.__getitem__)
jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_make_array_from_callback_sharded(state):
global_mesh = create_mesh((4, 2), ('x', 'y'), state)
def bench_make_array_from_callback_partially_replicated_sharding(state):
mesh = create_mesh((4, 2), ('x', 'y'), state)
if mesh is None:
return
input_shape = (8, 2)
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
def callback(index):
return input_data[index]
s = jax.NamedSharding(global_mesh, jax.sharding.PartitionSpec('x', 'y'))
s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y'))
while state:
jax.make_array_from_callback((8, 2), s, callback)
jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_make_array_from_callback_fully_sharded_sharding(state):
  """Benchmark `jax.make_array_from_callback` with a `PartitionSpec('x', 'y')`.

  An (8, 2) NumPy array is used as the host-side data source, and the target
  sharding is a `NamedSharding` over a (4, 2) mesh with axes ('x', 'y') that
  partitions both array dimensions.

  Args:
    state: the google_benchmark state object; the timed loop runs while it is
      truthy.
  """
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    # NOTE(review): create_mesh is defined elsewhere; presumably it returns
    # None when the requested mesh cannot be built -- confirm there.
    return

  global_shape = (8, 2)
  flat_values = np.arange(math.prod(global_shape))
  host_data = flat_values.reshape(global_shape)
  sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))

  while state:
    # The bound `__getitem__` is passed directly as the per-shard callback so
    # that no extra Python wrapper is included in the measured time.
    jax.make_array_from_callback(global_shape, sharding, host_data.__getitem__)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)