mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
__repr__
if an Array is fully replicated. Its the same for _value
so it makes sense to do the same for __repr__
.
PiperOrigin-RevId: 469892350
This commit is contained in:
parent
8e2d1be0a5
commit
fd3a72dd1f
@ -118,7 +118,10 @@ class Array:
|
||||
self._committed = committed
|
||||
self._npy_value = None
|
||||
|
||||
if config.jax_enable_checks:
|
||||
# TODO(yashkatariya): Add a check here which checks if the expected shard
|
||||
# shape matches the shape of _arrays. A similar check exists for GDA.
|
||||
|
||||
if not _skip_checks or config.jax_enable_checks:
|
||||
assert all(db.dtype == self.dtype for db in self._arrays), (
|
||||
"Input arrays to `Array` must have matching dtypes, "
|
||||
f"got: {[db.dtype for db in self._arrays]}, aval type: {self.dtype}")
|
||||
@ -235,6 +238,9 @@ class Array:
|
||||
else:
|
||||
raise TypeError(self.dtype)
|
||||
|
||||
def is_fully_replicated(self) -> bool:
|
||||
return self.shape == self._arrays[0].shape
|
||||
|
||||
def __repr__(self):
|
||||
prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
|
||||
if self.aval is not None and self.aval.weak_type:
|
||||
@ -242,7 +248,7 @@ class Array:
|
||||
else:
|
||||
dtype_str = f'dtype={self.dtype.name})'
|
||||
|
||||
if self.is_fully_addressable():
|
||||
if self.is_fully_addressable() or self.is_fully_replicated():
|
||||
line_width = np.get_printoptions()["linewidth"]
|
||||
s = np.array2string(self._value, prefix=prefix, suffix=',',
|
||||
separator=', ', max_line_width=line_width)
|
||||
@ -340,15 +346,9 @@ class Array:
|
||||
self._check_if_deleted()
|
||||
|
||||
if self._npy_value is None:
|
||||
if isinstance(self.sharding, XLACompatibleSharding):
|
||||
try:
|
||||
op_sharding = self.sharding._to_xla_op_sharding(self.ndim)
|
||||
assert op_sharding is not None
|
||||
if pxla.is_op_sharding_replicated(op_sharding):
|
||||
self._npy_value = self._arrays[0].to_py() # type: ignore
|
||||
if self.is_fully_replicated():
|
||||
self._npy_value = np.asarray(self._arrays[0]) # type: ignore
|
||||
return cast(np.ndarray, self._npy_value)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
if not self.is_fully_addressable():
|
||||
raise RuntimeError("Fetching value for `jax.Array` that spans "
|
||||
|
Loading…
x
Reference in New Issue
Block a user