rocm_jax/tests/pickle_test.py
Sergei Lebedev 0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any clearer and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00

198 lines
6.0 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import unittest
from absl.testing import absltest
from absl.testing import parameterized
# cloudpickle is an optional dependency: when it cannot be imported the name is
# bound to None, and the tests that need it are skipped via unittest.skipIf.
try:
  import cloudpickle
except ImportError:
  cloudpickle = None
import jax
from jax import numpy as jnp
from jax.interpreters import pxla
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
import numpy as np
# Connect JAX's config options to absl flag parsing (presumably so they can be
# set on the command line of this test binary) — runs once at import time.
jax.config.parse_flags_with_absl()
def _get_device_by_id(device_id: int) -> xc.Device:
  """Return the device from `jax.devices()` whose `id` equals `device_id`.

  This is the reconstruction callable used when unpickling an `xc.Device`.

  Raises:
    ValueError: if no device returned by `jax.devices()` has that id.
  """
  found = next((dev for dev in jax.devices() if dev.id == device_id), None)
  if found is None:
    raise ValueError(f'Device {device_id} was not found')
  return found


def _reduce_device(device):
  """Pickle a device by its id so it is re-resolved via `_get_device_by_id`."""
  return _get_device_by_id, (device.id,)


# Make device objects picklable: they are serialized as their integer id.
xc.Device.__reduce__ = _reduce_device
if cloudpickle is not None:

  def _reduce_mesh(m):
    """Reduce a Mesh to `(Mesh, (devices, axis_names))` for cloudpickle.

    Rebuilding from constructor arguments keeps `mesh._hash` out of the
    serialized bytes; without this the Mesh would differ among the workers.
    NOTE(review): only `devices` and `axis_names` are forwarded — confirm no
    other Mesh constructor arguments need to survive the round trip.
    """
    ctor_args = (m.devices, m.axis_names)
    return jax.sharding.Mesh, ctor_args

  # Register the reducer in cloudpickle's CloudPickler dispatch table.
  cloudpickle.CloudPickler.dispatch_table[jax.sharding.Mesh] = _reduce_mesh
class CloudpickleTest(jtu.JaxTestCase):
  """Round-trips transformed closures through cloudpickle.

  Each test pickles `g` (which calls the transformed `f`), deletes both
  originals, and checks the unpickled `g` still produces the same result.
  """

  @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
  def testPickleOfJittedFunctions(self):

    @jax.jit
    def f(x, y):
      return x * y

    @jax.jit
    def g(z):
      return f(z, z + 77)  # noqa: F821

    expected = g(32)
    s = cloudpickle.dumps(g)
    # Drop the originals so the unpickled function cannot fall back on them.
    del f, g

    # Bytes produced by cloudpickle are loaded with the stdlib pickle module.
    g_unpickled = pickle.loads(s)
    actual = g_unpickled(32)
    # assertArraysEqual compares values, shape and dtype of the arrays rather
    # than relying on the truthiness of an elementwise `==`.
    self.assertArraysEqual(expected, actual)

  @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
  def testPickleOfPmappedFunctions(self):

    @jax.pmap
    def f(x, y):
      return x * y

    @jax.pmap
    def g(z):
      return f(z, z + 77)  # noqa: F821

    # Input has shape (1, 1): both pmap levels map over a size-1 axis.
    expected = g(jnp.asarray([[32]]))
    s = cloudpickle.dumps(g)
    del f, g

    g_unpickled = pickle.loads(s)
    actual = g_unpickled(jnp.asarray([[32]]))
    self.assertArraysEqual(expected, actual)
class PickleTest(jtu.JaxTestCase):
  """Round-trips JAX arrays, keys and sharding objects through stdlib pickle."""

  def testPickleOfArray(self):
    x = jnp.arange(10.0)
    s = pickle.dumps(x)
    y = pickle.loads(s)
    self.assertArraysEqual(x, y)
    self.assertIsInstance(y, type(x))
    self.assertEqual(x.aval, y.aval)

  def testPickleOfArrayWeakType(self):
    x = jnp.array(4.0)
    self.assertTrue(x.aval.weak_type)
    s = pickle.dumps(x)
    y = pickle.loads(s)
    self.assertArraysEqual(x, y)
    self.assertIsInstance(y, type(x))
    self.assertEqual(x.aval, y.aval)
    # The weak type itself must survive the round trip.
    self.assertTrue(y.aval.weak_type)

  @jtu.sample_product(prng_name=['threefry2x32', 'rbg', 'unsafe_rbg'])
  def testPickleOfKeyArray(self, prng_name):
    # Everything runs under the chosen PRNG implementation so key_data
    # interprets both keys the same way.
    with jax.default_prng_impl(prng_name):
      k1 = jax.random.PRNGKey(72)
      s = pickle.dumps(k1)
      k2 = pickle.loads(s)
      self.assertEqual(k1.dtype, k2.dtype)
      with jax.legacy_prng_key('allow'):
        self.assertArraysEqual(jax.random.key_data(k1),
                               jax.random.key_data(k2))

  @parameterized.parameters(
      (jax.sharding.PartitionSpec(),),
      (jax.sharding.PartitionSpec(None),),
      (jax.sharding.PartitionSpec('x', None),),
      (jax.sharding.PartitionSpec(None, 'y'),),
      (jax.sharding.PartitionSpec('x', 'y'),),
      (jax.sharding.PartitionSpec(('x', 'y'),),),
  )
  def testPickleOfPartitionSpecs(self, partition_spec):
    restored_partition_spec = pickle.loads(pickle.dumps(partition_spec))
    self.assertIsInstance(restored_partition_spec, jax.sharding.PartitionSpec)
    self.assertEqual(partition_spec, restored_partition_spec)

  def testPickleX64(self):
    # Pickle a float64 array with x64 enabled, then unpickle with x64
    # disabled: the restored array is expected to be float32.
    with jax.experimental.enable_x64():
      x = jnp.array(4.0, dtype='float64')
      s = pickle.dumps(x)

    with jax.experimental.disable_x64():
      y = pickle.loads(s)

    self.assertEqual(x.dtype, jnp.float64)
    self.assertArraysEqual(x, y, check_dtypes=False)
    self.assertEqual(y.dtype, jnp.float32)
    self.assertEqual(y.aval.dtype, jnp.float32)
    self.assertIsInstance(y, type(x))

  def testPickleTracerError(self):
    # Pickling the traced argument inside jit must raise rather than succeed.
    with self.assertRaises(jax.errors.ConcretizationTypeError):
      jax.jit(pickle.dumps)(0)

  def testPickleSharding(self):
    sharding = pxla.ShardingSpec(
        (pxla.NoSharding(), pxla.Chunked((2, 2)), pxla.Unstacked(3)),
        (pxla.ShardedAxis(0), pxla.ShardedAxis(1), pxla.ShardedAxis(2),
         pxla.Replicated(4)))
    self.assertEqual(pickle.loads(pickle.dumps(sharding)), sharding)

  def testPickleOpSharding(self):
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [4, 2]
    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
    # Compare as HloSharding. (The previous assertTrue(a, b) treated `b` as
    # the failure message and only checked that `a` was truthy.)
    self.assertEqual(
        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op))),
        xc.HloSharding.from_proto(op))

  def test_pickle_single_device_sharding(self):
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))

  def test_pickle_pmap_sharding(self):
    devices = jax.devices()
    # Size the unstacked axis from the devices actually present instead of
    # assuming exactly 8.
    ss = pxla.ShardingSpec(
        sharding=(pxla.Unstacked(len(devices)),),
        mesh_mapping=(pxla.ShardedAxis(0),))
    s = jax.sharding.PmapSharding(devices, ss)
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))

  def test_pickle_gspmd_sharding(self):
    s = jax.sharding.GSPMDSharding.get_replicated(jax.devices())
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))

  # NOTE(review): this test only uses the stdlib `pickle` module, so the
  # cloudpickle skip looks unnecessary — confirm before removing it.
  @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
  def test_pickle_named_sharding(self):
    s = jax.sharding.NamedSharding(
        mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'),
        spec=jax.sharding.PartitionSpec('d'))
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))
if __name__ == "__main__":
  # Run the suite via absltest, using JAX's own test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)