Pickling of Sharding classes: use module level functions when deserializing.

This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.

PiperOrigin-RevId: 490249959
This commit is contained in:
jax authors 2022-11-22 08:30:38 -08:00
parent d9383fc80d
commit 518fe6656c
3 changed files with 26 additions and 21 deletions

View File

@ -242,12 +242,8 @@ class NamedSharding(XLACompatibleSharding):
self._parsed_pspec = _parsed_pspec
self._preprocess()
@classmethod
def unpickle(cls, mesh, spec):
return cls(mesh, spec)
def __reduce__(self):
return type(self).unpickle, (self.mesh, self.spec)
return type(self), (self.mesh, self.spec)
def _preprocess(self):
# This split exists because you can pass `_parsed_pspec` that has been
@ -352,12 +348,8 @@ class SingleDeviceSharding(XLACompatibleSharding):
def __init__(self, device: Device):
self._device = device
@classmethod
def unpickle(cls, device: Device):
return cls(device)
def __reduce__(self):
return type(self).unpickle, (self._device,)
return type(self), (self._device,)
def __repr__(self):
return f"SingleDeviceSharding(device={repr(self._device)})"
@ -396,12 +388,8 @@ class PmapSharding(XLACompatibleSharding):
# The sharding spec should be pmap's sharding spec.
self.sharding_spec = sharding_spec
@classmethod
def unpickle(cls, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
return cls(devices, sharding_spec)
def __reduce__(self):
return type(self).unpickle, (self.devices, self.sharding_spec)
return type(self), (self.devices, self.sharding_spec)
def __eq__(self, other):
if not isinstance(other, PmapSharding):
@ -608,12 +596,8 @@ class OpShardingSharding(XLACompatibleSharding):
self._devices = tuple(devices)
self._op_sharding = op_sharding
@classmethod
def unpickle(cls, devices: Sequence[Device], op_sharding: xc.OpSharding):
return cls(devices, op_sharding)
def __reduce__(self):
return type(self).unpickle, (self._devices, self._op_sharding)
return type(self), (self._devices, self._op_sharding)
@pxla.maybe_cached_property
def _op_sharding_hash(self):

View File

@ -548,7 +548,7 @@ jax_test(
srcs = ["pickle_test.py"],
deps = [
"//jax:experimental",
] + py_deps("cloudpickle"),
] + py_deps("cloudpickle") + py_deps("numpy"),
)
jax_test(

View File

@ -35,6 +35,8 @@ from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
import numpy as np
config.parse_flags_with_absl()
@ -48,6 +50,15 @@ def _get_device_by_id(device_id: int) -> xc.Device:
xc.Device.__reduce__ = lambda d: (_get_device_by_id, (d.id,))
if cloudpickle is not None:
def _reduce_mesh(mesh):
# Avoid including mesh._hash in the serialized bytes for Mesh. Without this
# the Mesh would be different among the workers.
return jax.pxla.Mesh, (mesh.devices, mesh.axis_names)
cloudpickle.CloudPickler.dispatch_table[jax.pxla.Mesh] = _reduce_mesh
class CloudpickleTest(jtu.JaxTestCase):
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@ -188,5 +199,15 @@ class PickleTest(jtu.JaxTestCase):
s = sharding.OpShardingSharding(jax.devices(), op_sharding)
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
@unittest.skipIf(xla_extension_version < 104,
'NamedSharding pickling requires newer jaxlib.')
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
def test_pickle_named_sharding(self):
s = jax.sharding.NamedSharding(
mesh=pxla.Mesh(np.array(jax.devices()), 'd'),
spec=pxla.PartitionSpec('d'))
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())