diff --git a/jax/_src/array.py b/jax/_src/array.py index 4f1bf3ea3..5226dc2a0 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -75,7 +75,7 @@ class Shard: except ValueError: return f'Shard(device={repr(self.device)}, data={self.data})' - @property + @functools.cached_property def index(self) -> Index: try: device_indices_map_fn = self._sharding.devices_indices_map @@ -87,7 +87,7 @@ class Shard: assert index is not None return index - @property + @functools.cached_property def replica_id(self) -> int: return device_replica_id_map(self._sharding, self._global_shape)[self.device]