mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Error if memory_kind is not correct for the devices in Shardings during initialization.
PiperOrigin-RevId: 549410478
This commit is contained in:
parent
7df3477926
commit
cd2dc2f2fa
@ -228,6 +228,10 @@ class NamedSharding(XLACompatibleSharding):
|
||||
return type(self), (self.mesh, self.spec)
|
||||
|
||||
def _preprocess(self):
|
||||
if xla_extension_version >= 170 and self.memory_kind is not None:
|
||||
# Will error if memory_kind does not exist on the device.
|
||||
self.mesh.devices.flat[0].memory(self.memory_kind)
|
||||
|
||||
# This split exists because you can pass `_parsed_pspec` that has been
|
||||
# modified from the original. For example: Adding extra dimension to
|
||||
# axis_resources for vmap handlers. In such cases you need to preserve the
|
||||
@ -602,6 +606,9 @@ class PositionalSharding(XLACompatibleSharding):
|
||||
name = self._devices[0].platform.upper()
|
||||
self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
|
||||
dtype='object').reshape(devices.shape)
|
||||
if self._memory_kind is not None:
|
||||
# Will error if memory_kind does not exist on the device.
|
||||
self._devices[0].memory(self._memory_kind)
|
||||
|
||||
shape = property(op.attrgetter('_ids.shape'))
|
||||
ndim = property(op.attrgetter('_ids.ndim'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user