diff --git a/jax/experimental/array.py b/jax/experimental/array.py index 59e5aef22..a165d8695 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -118,7 +118,10 @@ class Array: self._committed = committed self._npy_value = None - if config.jax_enable_checks: + # TODO(yashkatariya): Add a check here which checks if the expected shard + # shape matches the shape of _arrays. A similar check exists for GDA. + + if not _skip_checks or config.jax_enable_checks: assert all(db.dtype == self.dtype for db in self._arrays), ( "Input arrays to `Array` must have matching dtypes, " f"got: {[db.dtype for db in self._arrays]}, aval type: {self.dtype}") @@ -235,6 +238,9 @@ class Array: else: raise TypeError(self.dtype) + def is_fully_replicated(self) -> bool: + return self.shape == self._arrays[0].shape + def __repr__(self): prefix = '{}('.format(self.__class__.__name__.lstrip('_')) if self.aval is not None and self.aval.weak_type: @@ -242,7 +248,7 @@ class Array: else: dtype_str = f'dtype={self.dtype.name})' - if self.is_fully_addressable(): + if self.is_fully_addressable() or self.is_fully_replicated(): line_width = np.get_printoptions()["linewidth"] s = np.array2string(self._value, prefix=prefix, suffix=',', separator=', ', max_line_width=line_width) @@ -340,15 +346,9 @@ class Array: self._check_if_deleted() if self._npy_value is None: - if isinstance(self.sharding, XLACompatibleSharding): - try: - op_sharding = self.sharding._to_xla_op_sharding(self.ndim) - assert op_sharding is not None - if pxla.is_op_sharding_replicated(op_sharding): - self._npy_value = self._arrays[0].to_py() # type: ignore - return cast(np.ndarray, self._npy_value) - except NotImplementedError: - pass + if self.is_fully_replicated(): + self._npy_value = np.asarray(self._arrays[0]) # type: ignore + return cast(np.ndarray, self._npy_value) if not self.is_fully_addressable(): raise RuntimeError("Fetching value for `jax.Array` that spans "