Add benchmarks for accessing index and replica id in addressable_shards

PiperOrigin-RevId: 517974091
This commit is contained in:
Yash Katariya 2023-03-20 08:21:55 -07:00 committed by jax authors
parent f4abde222a
commit 1faa7a8edd

View File

@ -595,6 +595,34 @@ def bench_pjit_check_aval_sharding(state):
pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
@google_benchmark.register
def bench_addressable_shards_index(state):
mesh = create_mesh((4, 2), ('x', 'y'), state)
if mesh is None:
return
shape = (8, 2)
inp = np.arange(np.prod(shape)).reshape(shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
arr = jax.device_put(inp, s)
while state:
[s.index for s in arr.addressable_shards]
@google_benchmark.register
def bench_addressable_shards_replica_id(state):
mesh = create_mesh((32, 16), ('x', 'y'), state)
if mesh is None:
return
shape = (64, 32)
inp = np.arange(np.prod(shape)).reshape(shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
arr = jax.device_put(inp, s)
while state:
[s.replica_id for s in arr.addressable_shards]
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads(state):