mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Error if memory_kind is not correct for the devices in Shardings during initialization.
PiperOrigin-RevId: 549410478
This commit is contained in:
parent
7df3477926
commit
cd2dc2f2fa
@ -228,6 +228,10 @@ class NamedSharding(XLACompatibleSharding):
|
|||||||
return type(self), (self.mesh, self.spec)
|
return type(self), (self.mesh, self.spec)
|
||||||
|
|
||||||
def _preprocess(self):
|
def _preprocess(self):
|
||||||
|
if xla_extension_version >= 170 and self.memory_kind is not None:
|
||||||
|
# Will error if memory_kind does not exist on the device.
|
||||||
|
self.mesh.devices.flat[0].memory(self.memory_kind)
|
||||||
|
|
||||||
# This split exists because you can pass `_parsed_pspec` that has been
|
# This split exists because you can pass `_parsed_pspec` that has been
|
||||||
# modified from the original. For example: Adding extra dimension to
|
# modified from the original. For example: Adding extra dimension to
|
||||||
# axis_resources for vmap handlers. In such cases you need to preserve the
|
# axis_resources for vmap handlers. In such cases you need to preserve the
|
||||||
@ -602,6 +606,9 @@ class PositionalSharding(XLACompatibleSharding):
|
|||||||
name = self._devices[0].platform.upper()
|
name = self._devices[0].platform.upper()
|
||||||
self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
|
self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
|
||||||
dtype='object').reshape(devices.shape)
|
dtype='object').reshape(devices.shape)
|
||||||
|
if self._memory_kind is not None:
|
||||||
|
# Will error if memory_kind does not exist on the device.
|
||||||
|
self._devices[0].memory(self._memory_kind)
|
||||||
|
|
||||||
shape = property(op.attrgetter('_ids.shape'))
|
shape = property(op.attrgetter('_ids.shape'))
|
||||||
ndim = property(op.attrgetter('_ids.ndim'))
|
ndim = property(op.attrgetter('_ids.ndim'))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user