mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add benchmarks for accessing index and replica id in addressable_shards
PiperOrigin-RevId: 517974091
This commit is contained in:
parent
f4abde222a
commit
1faa7a8edd
@ -595,6 +595,34 @@ def bench_pjit_check_aval_sharding(state):
|
||||
pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def bench_addressable_shards_index(state):
|
||||
mesh = create_mesh((4, 2), ('x', 'y'), state)
|
||||
if mesh is None:
|
||||
return
|
||||
shape = (8, 2)
|
||||
inp = np.arange(np.prod(shape)).reshape(shape)
|
||||
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
|
||||
arr = jax.device_put(inp, s)
|
||||
|
||||
while state:
|
||||
[s.index for s in arr.addressable_shards]
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def bench_addressable_shards_replica_id(state):
|
||||
mesh = create_mesh((32, 16), ('x', 'y'), state)
|
||||
if mesh is None:
|
||||
return
|
||||
shape = (64, 32)
|
||||
inp = np.arange(np.prod(shape)).reshape(shape)
|
||||
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
|
||||
arr = jax.device_put(inp, s)
|
||||
|
||||
while state:
|
||||
[s.replica_id for s in arr.addressable_shards]
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
||||
def bench_remat_eager_retracing_overheads(state):
|
||||
|
Loading…
x
Reference in New Issue
Block a user