__repr__ if an Array is fully replicated. Its the same for _value so it makes sense to do the same for __repr__.

PiperOrigin-RevId: 469892350
This commit is contained in:
Yash Katariya 2022-08-24 20:41:48 -07:00 committed by jax authors
parent 8e2d1be0a5
commit fd3a72dd1f

View File

@ -118,7 +118,10 @@ class Array:
self._committed = committed
self._npy_value = None
if config.jax_enable_checks:
# TODO(yashkatariya): Add a check here which checks if the expected shard
# shape matches the shape of _arrays. A similar check exists for GDA.
if not _skip_checks or config.jax_enable_checks:
assert all(db.dtype == self.dtype for db in self._arrays), (
"Input arrays to `Array` must have matching dtypes, "
f"got: {[db.dtype for db in self._arrays]}, aval type: {self.dtype}")
@ -235,6 +238,9 @@ class Array:
else:
raise TypeError(self.dtype)
def is_fully_replicated(self) -> bool:
return self.shape == self._arrays[0].shape
def __repr__(self):
prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
if self.aval is not None and self.aval.weak_type:
@ -242,7 +248,7 @@ class Array:
else:
dtype_str = f'dtype={self.dtype.name})'
if self.is_fully_addressable():
if self.is_fully_addressable() or self.is_fully_replicated():
line_width = np.get_printoptions()["linewidth"]
s = np.array2string(self._value, prefix=prefix, suffix=',',
separator=', ', max_line_width=line_width)
@ -340,15 +346,9 @@ class Array:
self._check_if_deleted()
if self._npy_value is None:
if isinstance(self.sharding, XLACompatibleSharding):
try:
op_sharding = self.sharding._to_xla_op_sharding(self.ndim)
assert op_sharding is not None
if pxla.is_op_sharding_replicated(op_sharding):
self._npy_value = self._arrays[0].to_py() # type: ignore
if self.is_fully_replicated():
self._npy_value = np.asarray(self._arrays[0]) # type: ignore
return cast(np.ndarray, self._npy_value)
except NotImplementedError:
pass
if not self.is_fully_addressable():
raise RuntimeError("Fetching value for `jax.Array` that spans "