From 518fe6656ca2aab66dcfc8cd7866c10f476a17b1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 22 Nov 2022 08:30:38 -0800 Subject: [PATCH] Pickling of Sharding classes: use module level functions when deserializing. This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s. PiperOrigin-RevId: 490249959 --- jax/_src/sharding.py | 24 ++++-------------------- tests/BUILD | 2 +- tests/pickle_test.py | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index 9a432695d..d45f0dbc9 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -242,12 +242,8 @@ class NamedSharding(XLACompatibleSharding): self._parsed_pspec = _parsed_pspec self._preprocess() - @classmethod - def unpickle(cls, mesh, spec): - return cls(mesh, spec) - def __reduce__(self): - return type(self).unpickle, (self.mesh, self.spec) + return type(self), (self.mesh, self.spec) def _preprocess(self): # This split exists because you can pass `_parsed_pspec` that has been @@ -352,12 +348,8 @@ class SingleDeviceSharding(XLACompatibleSharding): def __init__(self, device: Device): self._device = device - @classmethod - def unpickle(cls, device: Device): - return cls(device) - def __reduce__(self): - return type(self).unpickle, (self._device,) + return type(self), (self._device,) def __repr__(self): return f"SingleDeviceSharding(device={repr(self._device)})" @@ -396,12 +388,8 @@ class PmapSharding(XLACompatibleSharding): # The sharding spec should be pmap's sharding spec. self.sharding_spec = sharding_spec - @classmethod - def unpickle(cls, devices: np.ndarray, sharding_spec: pxla.ShardingSpec): - return cls(devices, sharding_spec) - def __reduce__(self): - return type(self).unpickle, (self.devices, self.sharding_spec) + return type(self), (self.devices, self.sharding_spec) def __eq__(self, other): if not isinstance(other, PmapSharding): @@ -608,12 +596,8 @@ class OpShardingSharding(XLACompatibleSharding): self._devices = tuple(devices) self._op_sharding = op_sharding - @classmethod - def unpickle(cls, devices: Sequence[Device], op_sharding: xc.OpSharding): - return cls(devices, op_sharding) - def __reduce__(self): - return type(self).unpickle, (self._devices, self._op_sharding) + return type(self), (self._devices, self._op_sharding) @pxla.maybe_cached_property def _op_sharding_hash(self): diff --git a/tests/BUILD b/tests/BUILD index 88996e5ca..12e4a0fe7 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -548,7 +548,7 @@ jax_test( srcs = ["pickle_test.py"], deps = [ "//jax:experimental", - ] + py_deps("cloudpickle"), + ] + py_deps("cloudpickle") + py_deps("numpy"), ) jax_test( diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 21250c531..3a9594e9b 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -35,6 +35,8 @@ from jax._src import test_util as jtu from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version +import numpy as np + config.parse_flags_with_absl() @@ -48,6 +50,15 @@ def _get_device_by_id(device_id: int) -> xc.Device: xc.Device.__reduce__ = lambda d: (_get_device_by_id, (d.id,)) +if cloudpickle is not None: + def _reduce_mesh(mesh): + # Avoid including mesh._hash in the serialized bytes for Mesh. Without this + # the Mesh would be different among the workers. + return jax.pxla.Mesh, (mesh.devices, mesh.axis_names) + + cloudpickle.CloudPickler.dispatch_table[jax.pxla.Mesh] = _reduce_mesh + + class CloudpickleTest(jtu.JaxTestCase): @unittest.skipIf(cloudpickle is None, "Requires cloudpickle") @@ -188,5 +199,15 @@ class PickleTest(jtu.JaxTestCase): s = sharding.OpShardingSharding(jax.devices(), op_sharding) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + @unittest.skipIf(xla_extension_version < 104, + 'NamedSharding pickling requires newer jaxlib.') + @unittest.skipIf(cloudpickle is None, "Requires cloudpickle") + def test_pickle_named_sharding(self): + s = jax.sharding.NamedSharding( + mesh=pxla.Mesh(np.array(jax.devices()), 'd'), + spec=pxla.PartitionSpec('d')) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())