From 021fadfcbc7b081461053bc777406804889f0353 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 20 Mar 2023 08:36:25 -0700 Subject: [PATCH] Optimize accessing `index` and `replica_id` of addressable_shards MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Benchmark: ``` name old time/op new time/op delta bench_addressable_shards_index 53.0µs ± 2% 2.6µs ± 4% -95.07% (p=0.008 n=5+5) bench_addressable_shards_replica_id 51.7µs ± 2% 2.6µs ± 2% -94.92% (p=0.008 n=5+5) ``` PiperOrigin-RevId: 517977244 --- jax/_src/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 4f1bf3ea3..5226dc2a0 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -75,7 +75,7 @@ class Shard: except ValueError: return f'Shard(device={repr(self.device)}, data={self.data})' - @property + @functools.cached_property def index(self) -> Index: try: device_indices_map_fn = self._sharding.devices_indices_map @@ -87,7 +87,7 @@ class Shard: assert index is not None return index - @property + @functools.cached_property def replica_id(self) -> int: return device_replica_id_map(self._sharding, self._global_shape)[self.device]