mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix GDA error message formatting
PiperOrigin-RevId: 486724647
This commit is contained in:
parent
b127b70e30
commit
218305964f
@ -288,7 +288,7 @@ class GlobalDeviceArray:
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
assert all(db.shape == ss for db in self._device_buffers), (
|
||||
f"Expected shard shape {ss} doesn't match the device buffer "
|
||||
f"shape, got: {[db.shape for db in device_buffers]}")
|
||||
f"shape, got: {[db.shape for db in self._device_buffers]}")
|
||||
|
||||
if self._sharded_buffer is None:
|
||||
dtype = device_buffers[0].dtype # type: ignore
|
||||
@ -297,7 +297,7 @@ class GlobalDeviceArray:
|
||||
if _enable_checks or config.jax_enable_checks:
|
||||
assert all(db.dtype == dtype for db in self._device_buffers), (
|
||||
"Input arrays to GlobalDeviceArray must have matching dtypes, "
|
||||
f"got: {[db.dtype for db in device_buffers]}")
|
||||
f"got: {[db.dtype for db in self._device_buffers]}")
|
||||
self.dtype = dtype
|
||||
|
||||
def _init_buffers(self, device_buffers):
|
||||
|
Loading…
x
Reference in New Issue
Block a user