mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Optimize accessing index
and replica_id
of
addressable_shards Benchmark: ``` name old time/op new time/op delta bench_addressable_shards_index 53.0µs ± 2% 2.6µs ± 4% -95.07% (p=0.008 n=5+5) bench_addressable_shards_replica_id 51.7µs ± 2% 2.6µs ± 2% -94.92% (p=0.008 n=5+5) ``` PiperOrigin-RevId: 517977244
This commit is contained in:
parent
1faa7a8edd
commit
021fadfcbc
@ -75,7 +75,7 @@ class Shard:
|
||||
except ValueError:
|
||||
return f'Shard(device={repr(self.device)}, data={self.data})'
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def index(self) -> Index:
|
||||
try:
|
||||
device_indices_map_fn = self._sharding.devices_indices_map
|
||||
@ -87,7 +87,7 @@ class Shard:
|
||||
assert index is not None
|
||||
return index
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def replica_id(self) -> int:
|
||||
return device_replica_id_map(self._sharding, self._global_shape)[self.device]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user