diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index 9a432695d..d45f0dbc9 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -242,12 +242,8 @@ class NamedSharding(XLACompatibleSharding):
     self._parsed_pspec = _parsed_pspec
     self._preprocess()
 
-  @classmethod
-  def unpickle(cls, mesh, spec):
-    return cls(mesh, spec)
-
   def __reduce__(self):
-    return type(self).unpickle, (self.mesh, self.spec)
+    return type(self), (self.mesh, self.spec)
 
   def _preprocess(self):
     # This split exists because you can pass `_parsed_pspec` that has been
@@ -352,12 +348,8 @@ class SingleDeviceSharding(XLACompatibleSharding):
   def __init__(self, device: Device):
     self._device = device
 
-  @classmethod
-  def unpickle(cls, device: Device):
-    return cls(device)
-
   def __reduce__(self):
-    return type(self).unpickle, (self._device,)
+    return type(self), (self._device,)
 
   def __repr__(self):
     return f"SingleDeviceSharding(device={repr(self._device)})"
@@ -396,12 +388,8 @@ class PmapSharding(XLACompatibleSharding):
     # The sharding spec should be pmap's sharding spec.
     self.sharding_spec = sharding_spec
 
-  @classmethod
-  def unpickle(cls, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
-    return cls(devices, sharding_spec)
-
   def __reduce__(self):
-    return type(self).unpickle, (self.devices, self.sharding_spec)
+    return type(self), (self.devices, self.sharding_spec)
 
   def __eq__(self, other):
     if not isinstance(other, PmapSharding):
@@ -608,12 +596,8 @@ class OpShardingSharding(XLACompatibleSharding):
     self._devices = tuple(devices)
     self._op_sharding = op_sharding
 
-  @classmethod
-  def unpickle(cls, devices: Sequence[Device], op_sharding: xc.OpSharding):
-    return cls(devices, op_sharding)
-
   def __reduce__(self):
-    return type(self).unpickle, (self._devices, self._op_sharding)
+    return type(self), (self._devices, self._op_sharding)
 
   @pxla.maybe_cached_property
   def _op_sharding_hash(self):
diff --git a/tests/BUILD b/tests/BUILD
index 88996e5ca..12e4a0fe7 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -548,7 +548,7 @@ jax_test(
     srcs = ["pickle_test.py"],
     deps = [
         "//jax:experimental",
-    ] + py_deps("cloudpickle"),
+    ] + py_deps("cloudpickle") + py_deps("numpy"),
 )
 
 jax_test(
diff --git a/tests/pickle_test.py b/tests/pickle_test.py
index 21250c531..3a9594e9b 100644
--- a/tests/pickle_test.py
+++ b/tests/pickle_test.py
@@ -35,6 +35,8 @@ from jax._src import test_util as jtu
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 
+import numpy as np
+
 config.parse_flags_with_absl()
 
 
@@ -48,6 +50,15 @@ def _get_device_by_id(device_id: int) -> xc.Device:
 xc.Device.__reduce__ = lambda d: (_get_device_by_id, (d.id,))
 
 
+if cloudpickle is not None:
+  def _reduce_mesh(mesh):
+    # Avoid including mesh._hash in the serialized bytes for Mesh. Without this
+    # the Mesh would be different among the workers.
+    return jax.pxla.Mesh, (mesh.devices, mesh.axis_names)
+
+  cloudpickle.CloudPickler.dispatch_table[jax.pxla.Mesh] = _reduce_mesh
+
+
 class CloudpickleTest(jtu.JaxTestCase):
 
   @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@@ -188,5 +199,15 @@ class PickleTest(jtu.JaxTestCase):
     s = sharding.OpShardingSharding(jax.devices(), op_sharding)
     self.assertEqual(s, pickle.loads(pickle.dumps(s)))
 
+  @unittest.skipIf(xla_extension_version < 104,
+                   'NamedSharding pickling requires newer jaxlib.')
+  @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
+  def test_pickle_named_sharding(self):
+    s = jax.sharding.NamedSharding(
+        mesh=pxla.Mesh(np.array(jax.devices()), 'd'),
+        spec=pxla.PartitionSpec('d'))
+    self.assertEqual(s, pickle.loads(pickle.dumps(s)))
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())