From 73b8f6aee292bcacadfc79f45931714d0fa18fc7 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 13 Mar 2025 15:50:05 -0700 Subject: [PATCH] [JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant To prepare for the upcoming `BatchedDevicePut` implementation changes, this change makes `make_array_from_callback_*` benchmark code to be more homogeneous. Also it adds a variant that uses a partially replicated sharding. PiperOrigin-RevId: 736665856 --- benchmarks/api_benchmark.py | 51 ++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index c3be27f4a..cabebce22 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -21,14 +21,14 @@ import operator import google_benchmark import jax from jax import lax -from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation -from jax._src import core -from jax._src.lib import xla_client as xc from jax._src import array +from jax._src import core from jax._src import op_shardings +from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation +from jax._src.lib import xla_client as xc from jax._src.pjit import pjit_check_aval_sharding -from jax.experimental import pjit as pjit_lib from jax.experimental import multihost_utils +from jax.experimental import pjit as pjit_lib import jax.numpy as jnp import numpy as np @@ -860,29 +860,44 @@ def safe_zip(state): @google_benchmark.register def bench_make_array_from_callback_fully_replicated_sharding(state): - mesh = jax.sharding.Mesh( - np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y')) - shape = (8, 2) - np_arr = np.arange(16).reshape(shape) - s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + mesh = create_mesh((4, 2), ('x', 'y'), state) + if mesh is None: + return + input_shape = (8, 2) + np_arr = np.arange(math.prod(input_shape)).reshape(input_shape) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) while state: - jax.make_array_from_callback(shape, s, np_arr.__getitem__) + jax.make_array_from_callback(input_shape, s, np_arr.__getitem__) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) -def bench_make_array_from_callback_sharded(state): - global_mesh = create_mesh((4, 2), ('x', 'y'), state) +def bench_make_array_from_callback_partially_replicated_sharding(state): + mesh = create_mesh((4, 2), ('x', 'y'), state) + if mesh is None: + return input_shape = (8, 2) - input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + np_arr = np.arange(math.prod(input_shape)).reshape(input_shape) - def callback(index): - return input_data[index] - - s = jax.NamedSharding(global_mesh, jax.sharding.PartitionSpec('x', 'y')) + s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y')) while state: - jax.make_array_from_callback((8, 2), s, callback) + jax.make_array_from_callback(input_shape, s, np_arr.__getitem__) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def bench_make_array_from_callback_fully_sharded_sharding(state): + mesh = create_mesh((4, 2), ('x', 'y'), state) + if mesh is None: + return + input_shape = (8, 2) + np_arr = np.arange(math.prod(input_shape)).reshape(input_shape) + + s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) + while state: + jax.make_array_from_callback(input_shape, s, np_arr.__getitem__) + @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond)