mirror of https://github.com/ROCm/jax.git (synced 2025-04-13 02:16:06 +00:00)
[JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant
To prepare for the upcoming `BatchedDevicePut` implementation changes, this change makes the `make_array_from_callback_*` benchmark code more homogeneous. It also adds a variant that uses a partially replicated sharding.

PiperOrigin-RevId: 736665856
parent e615e2acb3
commit 73b8f6aee2
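The headline addition is a benchmark that builds an array with a partially replicated sharding. As a standalone sketch (not part of the diff, and assuming at least 8 devices are available), `PartitionSpec(None, 'y')` on a (4, 2) mesh leaves the first array dimension unsharded, so each row block is replicated across the 4 devices of the 'x' axis and split across the 2 devices of the 'y' axis:

import math

import jax
import numpy as np

mesh = jax.sharding.Mesh(
    np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y'))
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y'))

input_shape = (8, 2)
np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)

# make_array_from_callback invokes the callback once per addressable device
# with the index (a tuple of slices) of the shard that device needs;
# ndarray.__getitem__ serves those slices directly, which is why the
# benchmarks below can drop their hand-written callback closures.
arr = jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)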
@@ -21,14 +21,14 @@ import operator
 
 import google_benchmark
 import jax
 from jax import lax
-from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
-from jax._src import core
-from jax._src.lib import xla_client as xc
 from jax._src import array
+from jax._src import core
 from jax._src import op_shardings
+from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
+from jax._src.lib import xla_client as xc
 from jax._src.pjit import pjit_check_aval_sharding
-from jax.experimental import pjit as pjit_lib
 from jax.experimental import multihost_utils
+from jax.experimental import pjit as pjit_lib
 import jax.numpy as jnp
 import numpy as np
@@ -860,29 +860,44 @@ def safe_zip(state):
 
 @google_benchmark.register
 def bench_make_array_from_callback_fully_replicated_sharding(state):
-  mesh = jax.sharding.Mesh(
-      np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y'))
-  shape = (8, 2)
-  np_arr = np.arange(16).reshape(shape)
-  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+  mesh = create_mesh((4, 2), ('x', 'y'), state)
+  if mesh is None:
+    return
+  input_shape = (8, 2)
+  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
+
+  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
   while state:
-    jax.make_array_from_callback(shape, s, np_arr.__getitem__)
+    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
 
 
 @google_benchmark.register
 @google_benchmark.option.unit(google_benchmark.kMillisecond)
-def bench_make_array_from_callback_sharded(state):
-  global_mesh = create_mesh((4, 2), ('x', 'y'), state)
+def bench_make_array_from_callback_partially_replicated_sharding(state):
+  mesh = create_mesh((4, 2), ('x', 'y'), state)
+  if mesh is None:
+    return
   input_shape = (8, 2)
-  input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
+  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
 
-  def callback(index):
-    return input_data[index]
-
-  s = jax.NamedSharding(global_mesh, jax.sharding.PartitionSpec('x', 'y'))
+  s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y'))
   while state:
-    jax.make_array_from_callback((8, 2), s, callback)
+    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
 
 
 @google_benchmark.register
 @google_benchmark.option.unit(google_benchmark.kMillisecond)
+def bench_make_array_from_callback_fully_sharded_sharding(state):
+  mesh = create_mesh((4, 2), ('x', 'y'), state)
+  if mesh is None:
+    return
+  input_shape = (8, 2)
+  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
+
+  s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
+  while state:
+    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
+
+
+@google_benchmark.register
+@google_benchmark.option.unit(google_benchmark.kMillisecond)
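All three benchmarks now go through a shared `create_mesh` helper whose definition lies outside this hunk. A hypothetical reconstruction of it, assuming it skips the benchmark and returns None when fewer devices are available than the mesh shape requires:

import math

import jax
import numpy as np


def create_mesh(shape, axis_names, state):
  # Assumed behavior: if the host has fewer devices than the mesh needs,
  # mark the benchmark as skipped and signal the caller to bail out early.
  size = math.prod(shape)
  if len(jax.devices()) < size:
    state.skip_with_error(f'requires {size} devices')
    return None
  devices = sorted(jax.devices(), key=lambda d: d.id)
  return jax.sharding.Mesh(np.array(devices[:size]).reshape(shape), axis_names)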