mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Pickling of Sharding classes: use module level functions when deserializing.
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.

PiperOrigin-RevId: 490249959
This commit is contained in:
parent
d9383fc80d
commit
518fe6656c
@ -242,12 +242,8 @@ class NamedSharding(XLACompatibleSharding):
|
||||
self._parsed_pspec = _parsed_pspec
|
||||
self._preprocess()
|
||||
|
||||
@classmethod
|
||||
def unpickle(cls, mesh, spec):
|
||||
return cls(mesh, spec)
|
||||
|
||||
def __reduce__(self):
|
||||
return type(self).unpickle, (self.mesh, self.spec)
|
||||
return type(self), (self.mesh, self.spec)
|
||||
|
||||
def _preprocess(self):
|
||||
# This split exists because you can pass `_parsed_pspec` that has been
|
||||
@ -352,12 +348,8 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
def __init__(self, device: Device):
|
||||
self._device = device
|
||||
|
||||
@classmethod
|
||||
def unpickle(cls, device: Device):
|
||||
return cls(device)
|
||||
|
||||
def __reduce__(self):
|
||||
return type(self).unpickle, (self._device,)
|
||||
return type(self), (self._device,)
|
||||
|
||||
def __repr__(self):
|
||||
return f"SingleDeviceSharding(device={repr(self._device)})"
|
||||
@ -396,12 +388,8 @@ class PmapSharding(XLACompatibleSharding):
|
||||
# The sharding spec should be pmap's sharding spec.
|
||||
self.sharding_spec = sharding_spec
|
||||
|
||||
@classmethod
|
||||
def unpickle(cls, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
|
||||
return cls(devices, sharding_spec)
|
||||
|
||||
def __reduce__(self):
|
||||
return type(self).unpickle, (self.devices, self.sharding_spec)
|
||||
return type(self), (self.devices, self.sharding_spec)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, PmapSharding):
|
||||
@ -608,12 +596,8 @@ class OpShardingSharding(XLACompatibleSharding):
|
||||
self._devices = tuple(devices)
|
||||
self._op_sharding = op_sharding
|
||||
|
||||
@classmethod
|
||||
def unpickle(cls, devices: Sequence[Device], op_sharding: xc.OpSharding):
|
||||
return cls(devices, op_sharding)
|
||||
|
||||
def __reduce__(self):
|
||||
return type(self).unpickle, (self._devices, self._op_sharding)
|
||||
return type(self), (self._devices, self._op_sharding)
|
||||
|
||||
@pxla.maybe_cached_property
|
||||
def _op_sharding_hash(self):
|
||||
|
@ -548,7 +548,7 @@ jax_test(
|
||||
srcs = ["pickle_test.py"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
] + py_deps("cloudpickle"),
|
||||
] + py_deps("cloudpickle") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
@ -35,6 +35,8 @@ from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@ -48,6 +50,15 @@ def _get_device_by_id(device_id: int) -> xc.Device:
|
||||
xc.Device.__reduce__ = lambda d: (_get_device_by_id, (d.id,))
|
||||
|
||||
|
||||
if cloudpickle is not None:
|
||||
def _reduce_mesh(mesh):
|
||||
# Avoid including mesh._hash in the serialized bytes for Mesh. Without this
|
||||
# the Mesh would be different among the workers.
|
||||
return jax.pxla.Mesh, (mesh.devices, mesh.axis_names)
|
||||
|
||||
cloudpickle.CloudPickler.dispatch_table[jax.pxla.Mesh] = _reduce_mesh
|
||||
|
||||
|
||||
class CloudpickleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
|
||||
@ -188,5 +199,15 @@ class PickleTest(jtu.JaxTestCase):
|
||||
s = sharding.OpShardingSharding(jax.devices(), op_sharding)
|
||||
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 104,
|
||||
'NamedSharding pickling requires newer jaxlib.')
|
||||
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
|
||||
def test_pickle_named_sharding(self):
|
||||
s = jax.sharding.NamedSharding(
|
||||
mesh=pxla.Mesh(np.array(jax.devices()), 'd'),
|
||||
spec=pxla.PartitionSpec('d'))
|
||||
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user