From cd2dc2f2fa6d892eb6166fafca79b830398e19bb Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Jul 2023 13:32:06 -0700 Subject: [PATCH] Error if memory_kind is not correct for the devices in Shardings during initialization. PiperOrigin-RevId: 549410478 --- jax/_src/sharding_impls.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index a25b1c58e..5475b24a8 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -228,6 +228,10 @@ class NamedSharding(XLACompatibleSharding): return type(self), (self.mesh, self.spec) def _preprocess(self): + if xla_extension_version >= 170 and self.memory_kind is not None: + # Will error if memory_kind does not exist on the device. + self.mesh.devices.flat[0].memory(self.memory_kind) + # This split exists because you can pass `_parsed_pspec` that has been # modified from the original. For example: Adding extra dimension to # axis_resources for vmap handlers. In such cases you need to preserve the @@ -602,6 +606,9 @@ class PositionalSharding(XLACompatibleSharding): name = self._devices[0].platform.upper() self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)], dtype='object').reshape(devices.shape) + if self._memory_kind is not None: + # Will error if memory_kind does not exist on the device. + self._devices[0].memory(self._memory_kind) shape = property(op.attrgetter('_ids.shape')) ndim = property(op.attrgetter('_ids.ndim'))