Optimize accessing index and replica_id of

addressable_shards

Benchmark:

```
name                                 old time/op  new time/op  delta
bench_addressable_shards_index       53.0µs ± 2%   2.6µs ± 4%  -95.07%  (p=0.008 n=5+5)
bench_addressable_shards_replica_id  51.7µs ± 2%   2.6µs ± 2%  -94.92%  (p=0.008 n=5+5)
```

PiperOrigin-RevId: 517977244
This commit is contained in:
Yash Katariya 2023-03-20 08:36:25 -07:00 committed by jax authors
parent 1faa7a8edd
commit 021fadfcbc

View File

@ -75,7 +75,7 @@ class Shard:
except ValueError:
return f'Shard(device={repr(self.device)}, data={self.data})'
@property
@functools.cached_property
def index(self) -> Index:
try:
device_indices_map_fn = self._sharding.devices_indices_map
@ -87,7 +87,7 @@ class Shard:
assert index is not None
return index
@property
@functools.cached_property
def replica_id(self) -> int:
return device_replica_id_map(self._sharding, self._global_shape)[self.device]