From 1faa7a8eddc2683687722e9a561233523d61f3f6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 20 Mar 2023 08:21:55 -0700 Subject: [PATCH] Add benchmarks for accessing index and replica id in addressable_shards PiperOrigin-RevId: 517974091 --- benchmarks/api_benchmark.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 2562b82ce..5a125667d 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -595,6 +595,34 @@ def bench_pjit_check_aval_sharding(state): pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False) +@google_benchmark.register +def bench_addressable_shards_index(state): + mesh = create_mesh((4, 2), ('x', 'y'), state) + if mesh is None: + return + shape = (8, 2) + inp = np.arange(np.prod(shape)).reshape(shape) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) + arr = jax.device_put(inp, s) + + while state: + [s.index for s in arr.addressable_shards] + + +@google_benchmark.register +def bench_addressable_shards_replica_id(state): + mesh = create_mesh((32, 16), ('x', 'y'), state) + if mesh is None: + return + shape = (64, 32) + inp = np.arange(np.prod(shape)).reshape(shape) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) + arr = jax.device_put(inp, s) + + while state: + [s.replica_id for s in arr.addressable_shards] + + @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def bench_remat_eager_retracing_overheads(state):