Fix GDA error message formatting

PiperOrigin-RevId: 486724647
This commit is contained in:
Hyeontaek Lim 2022-11-07 11:54:30 -08:00 committed by jax authors
parent b127b70e30
commit 218305964f

View File

@ -288,7 +288,7 @@ class GlobalDeviceArray:
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
assert all(db.shape == ss for db in self._device_buffers), (
f"Expected shard shape {ss} doesn't match the device buffer "
f"shape, got: {[db.shape for db in device_buffers]}")
f"shape, got: {[db.shape for db in self._device_buffers]}")
if self._sharded_buffer is None:
dtype = device_buffers[0].dtype # type: ignore
@ -297,7 +297,7 @@ class GlobalDeviceArray:
if _enable_checks or config.jax_enable_checks:
assert all(db.dtype == dtype for db in self._device_buffers), (
"Input arrays to GlobalDeviceArray must have matching dtypes, "
f"got: {[db.dtype for db in device_buffers]}")
f"got: {[db.dtype for db in self._device_buffers]}")
self.dtype = dtype
def _init_buffers(self, device_buffers):