mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant
To prepare for the upcoming `BatchedDevicePut` implementation changes, this change makes the `make_array_from_callback_*` benchmark code more homogeneous. It also adds a variant that uses a partially replicated sharding. PiperOrigin-RevId: 736665856
This commit is contained in:
parent
e615e2acb3
commit
73b8f6aee2
@ -21,14 +21,14 @@ import operator
|
|||||||
import google_benchmark
|
import google_benchmark
|
||||||
import jax
|
import jax
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
|
|
||||||
from jax._src import core
|
|
||||||
from jax._src.lib import xla_client as xc
|
|
||||||
from jax._src import array
|
from jax._src import array
|
||||||
|
from jax._src import core
|
||||||
from jax._src import op_shardings
|
from jax._src import op_shardings
|
||||||
|
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
|
||||||
|
from jax._src.lib import xla_client as xc
|
||||||
from jax._src.pjit import pjit_check_aval_sharding
|
from jax._src.pjit import pjit_check_aval_sharding
|
||||||
from jax.experimental import pjit as pjit_lib
|
|
||||||
from jax.experimental import multihost_utils
|
from jax.experimental import multihost_utils
|
||||||
|
from jax.experimental import pjit as pjit_lib
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -860,29 +860,44 @@ def safe_zip(state):
|
|||||||
|
|
||||||
@google_benchmark.register
def bench_make_array_from_callback_fully_replicated_sharding(state):
  """Benchmarks make_array_from_callback with a fully replicated sharding.

  Args:
    state: google_benchmark state object driving the timing loop.
  """
  # create_mesh returns None when the required 8 devices are unavailable;
  # skip the benchmark in that case.
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  input_shape = (8, 2)
  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)

  # An empty PartitionSpec replicates the array across every device.
  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
  while state:
    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
||||||
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_make_array_from_callback_partially_replicated_sharding(state):
  """Benchmarks make_array_from_callback with a partially replicated sharding.

  Args:
    state: google_benchmark state object driving the timing loop.
  """
  # create_mesh returns None when the required 8 devices are unavailable;
  # skip the benchmark in that case.
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  input_shape = (8, 2)
  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)

  # PartitionSpec(None, 'y') shards only the second dimension, so the array
  # is replicated along the 'x' mesh axis (partial replication).
  # jax.sharding.NamedSharding used for consistency with the sibling
  # benchmarks (jax.NamedSharding is the same class re-exported).
  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y'))
  while state:
    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
|
|
||||||
|
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_make_array_from_callback_fully_sharded_sharding(state):
  """Benchmarks make_array_from_callback with a fully sharded sharding.

  Args:
    state: google_benchmark state object driving the timing loop.
  """
  # create_mesh returns None when the required 8 devices are unavailable;
  # skip the benchmark in that case.
  mesh = create_mesh((4, 2), ('x', 'y'), state)
  if mesh is None:
    return
  input_shape = (8, 2)
  np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)

  # PartitionSpec('x', 'y') shards both dimensions: no replication.
  # jax.sharding.NamedSharding used for consistency with the sibling
  # benchmarks (jax.NamedSharding is the same class re-exported).
  s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
  while state:
    jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
||||||
@google_benchmark.register
|
@google_benchmark.register
|
||||||
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user