Error if memory_kind is not correct for the devices in Shardings during initialization.

PiperOrigin-RevId: 549410478
This commit is contained in:
Yash Katariya 2023-07-19 13:32:06 -07:00 committed by jax authors
parent 7df3477926
commit cd2dc2f2fa

View File

@ -228,6 +228,10 @@ class NamedSharding(XLACompatibleSharding):
return type(self), (self.mesh, self.spec)
def _preprocess(self):
if xla_extension_version >= 170 and self.memory_kind is not None:
# Will error if memory_kind does not exist on the device.
self.mesh.devices.flat[0].memory(self.memory_kind)
# This split exists because you can pass `_parsed_pspec` that has been
# modified from the original. For example: Adding extra dimension to
# axis_resources for vmap handlers. In such cases you need to preserve the
@ -602,6 +606,9 @@ class PositionalSharding(XLACompatibleSharding):
name = self._devices[0].platform.upper()
self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
dtype='object').reshape(devices.shape)
if self._memory_kind is not None:
# Will error if memory_kind does not exist on the device.
self._devices[0].memory(self._memory_kind)
shape = property(op.attrgetter('_ids.shape'))
ndim = property(op.attrgetter('_ids.ndim'))