diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 2562b82ce..5a125667d 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -595,6 +595,34 @@ def bench_pjit_check_aval_sharding(state): pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False) +@google_benchmark.register +def bench_addressable_shards_index(state): + mesh = create_mesh((4, 2), ('x', 'y'), state) + if mesh is None: + return + shape = (8, 2) + inp = np.arange(np.prod(shape)).reshape(shape) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) + arr = jax.device_put(inp, s) + + while state: + [s.index for s in arr.addressable_shards] + + +@google_benchmark.register +def bench_addressable_shards_replica_id(state): + mesh = create_mesh((32, 16), ('x', 'y'), state) + if mesh is None: + return + shape = (64, 32) + inp = np.arange(np.prod(shape)).reshape(shape) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) + arr = jax.device_put(inp, s) + + while state: + [s.replica_id for s in arr.addressable_shards] + + @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def bench_remat_eager_retracing_overheads(state):