mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

These docstrings do not make the tests any clearer and typically just duplicate the test module name. PiperOrigin-RevId: 737611977
198 lines
6.0 KiB
Python
198 lines
6.0 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import pickle
|
|
import unittest
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
try:
|
|
import cloudpickle
|
|
except ImportError:
|
|
cloudpickle = None
|
|
|
|
import jax
|
|
from jax import numpy as jnp
|
|
from jax.interpreters import pxla
|
|
from jax._src import test_util as jtu
|
|
from jax._src.lib import xla_client as xc
|
|
|
|
import numpy as np
|
|
|
|
# NOTE(review): presumably wires JAX config options into absl flag parsing so
# they can be set on the test command line -- confirm against the jax.config docs.
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
def _get_device_by_id(device_id: int) -> xc.Device:
  # Return the entry of jax.devices() whose `id` equals `device_id`.
  #
  # Raises:
  #   ValueError: if no device reported by jax.devices() has that id.
  match = next(
      (candidate for candidate in jax.devices() if candidate.id == device_id),
      None,
  )
  if match is None:
    raise ValueError(f'Device {device_id} was not found')
  return match
|
|
|
|
|
|
# Make Device objects picklable: only the device id is recorded, and on load
# pickle calls _get_device_by_id(id) to look the device up again.
xc.Device.__reduce__ = lambda d: (_get_device_by_id, (d.id,))
|
|
|
|
|
|
if cloudpickle is not None:
  def _reduce_mesh(mesh):
    # Rebuild a Mesh from its devices and axis names only, so that mesh._hash
    # is kept out of the serialized bytes. Without this the Mesh would be
    # different among the workers.
    constructor_args = (mesh.devices, mesh.axis_names)
    return jax.sharding.Mesh, constructor_args

  # Register the reducer so cloudpickle uses it for every Mesh it serializes.
  cloudpickle.CloudPickler.dispatch_table[jax.sharding.Mesh] = _reduce_mesh
|
|
|
|
|
|
class CloudpickleTest(jtu.JaxTestCase):
  # Serializes jit- and pmap-wrapped functions with cloudpickle, restores them
  # with the standard pickle module, and compares the results of calling them.

  @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
  def testPickleOfJittedFunctions(self):

    @jax.jit
    def f(a, b):
      return a * b

    @jax.jit
    def g(v):
      return f(v, v + 77)  # noqa: F821

    arg = 32
    want = g(arg)
    blob = cloudpickle.dumps(g)
    # NOTE(review): the originals are deleted before loading, presumably so the
    # restored function has to carry its own copy of `f` -- confirm intent.
    del f, g

    restored = pickle.loads(blob)
    self.assertEqual(want, restored(arg))

  @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
  def testPickleOfPmappedFunctions(self):

    @jax.pmap
    def f(a, b):
      return a * b

    @jax.pmap
    def g(v):
      return f(v, v + 77)  # noqa: F821

    want = g(jnp.asarray([[32]]))
    blob = cloudpickle.dumps(g)
    # Same pattern as the jit test: drop the local definitions before loading.
    del f, g

    restored = pickle.loads(blob)
    got = restored(jnp.asarray([[32]]))
    self.assertEqual(want, got)
|
|
|
|
|
|
class PickleTest(jtu.JaxTestCase):
  # Round-trip tests (pickle.dumps followed by pickle.loads) for JAX arrays,
  # PRNG keys, partition specs and the various sharding objects.

  def testPickleOfArray(self):
    # A float array must come back equal, with the same Python type and aval.
    x = jnp.arange(10.0)
    s = pickle.dumps(x)
    y = pickle.loads(s)
    self.assertArraysEqual(x, y)
    self.assertIsInstance(y, type(x))
    self.assertEqual(x.aval, y.aval)

  def testPickleOfArrayWeakType(self):
    # The array is first checked to be weakly typed; comparing avals after the
    # round trip then verifies the restored array matches it.
    x = jnp.array(4.0)
    self.assertEqual(x.aval.weak_type, True)
    s = pickle.dumps(x)
    y = pickle.loads(s)
    self.assertArraysEqual(x, y)
    self.assertIsInstance(y, type(x))
    self.assertEqual(x.aval, y.aval)

  @jtu.sample_product(prng_name=['threefry2x32', 'rbg', 'unsafe_rbg'])
  def testPickleOfKeyArray(self, prng_name):
    # For each PRNG implementation, a key must keep its dtype and key data.
    with jax.default_prng_impl(prng_name):
      k1 = jax.random.PRNGKey(72)
      s = pickle.dumps(k1)
      k2 = pickle.loads(s)
      self.assertEqual(k1.dtype, k2.dtype)
      with jax.legacy_prng_key('allow'):
        self.assertArraysEqual(jax.random.key_data(k1),
                               jax.random.key_data(k2))

  @parameterized.parameters(
      (jax.sharding.PartitionSpec(),),
      (jax.sharding.PartitionSpec(None),),
      (jax.sharding.PartitionSpec('x', None),),
      (jax.sharding.PartitionSpec(None, 'y'),),
      (jax.sharding.PartitionSpec('x', 'y'),),
      (jax.sharding.PartitionSpec(('x', 'y'),),),
  )
  def testPickleOfPartitionSpecs(self, partition_spec):
    # The restored object must still be a PartitionSpec and compare equal.
    restored_partition_spec = pickle.loads(pickle.dumps(partition_spec))
    self.assertIsInstance(restored_partition_spec, jax.sharding.PartitionSpec)
    self.assertEqual(partition_spec, restored_partition_spec)

  def testPickleX64(self):
    # Serialize a float64 array with x64 enabled, then load it with x64
    # disabled: the restored array is expected to be float32.
    with jax.experimental.enable_x64():
      x = jnp.array(4.0, dtype='float64')
      s = pickle.dumps(x)

    with jax.experimental.disable_x64():
      y = pickle.loads(s)

    self.assertEqual(x.dtype, jnp.float64)
    self.assertArraysEqual(x, y, check_dtypes=False)
    self.assertEqual(y.dtype, jnp.float32)
    self.assertEqual(y.aval.dtype, jnp.float32)
    self.assertIsInstance(y, type(x))

  def testPickleTracerError(self):
    # Calling pickle.dumps under jit must raise ConcretizationTypeError.
    with self.assertRaises(jax.errors.ConcretizationTypeError):
      jax.jit(pickle.dumps)(0)

  def testPickleSharding(self):
    # A ShardingSpec must compare equal to its round-tripped copy.
    sharding = pxla.ShardingSpec(
        (pxla.NoSharding(), pxla.Chunked((2, 2)), pxla.Unstacked(3)),
        (pxla.ShardedAxis(0), pxla.ShardedAxis(1),
         pxla.ShardedAxis(2), pxla.Replicated(4)))
    self.assertEqual(pickle.loads(pickle.dumps(sharding)), sharding)

  def testPickleOpSharding(self):
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [4, 2]
    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
    # Use assertEqual here: assertTrue(a, b) would treat `b` as the failure
    # message and only check that `a` is truthy, so the restored sharding
    # would never actually be compared with the original.
    self.assertEqual(
        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op))),
        xc.HloSharding.from_proto(op))

  def test_pickle_single_device_sharding(self):
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))

  def test_pickle_pmap_sharding(self):
    # PmapSharding over all devices with a single Unstacked(8) axis.
    ss = pxla.ShardingSpec(
        sharding=(pxla.Unstacked(8),),
        mesh_mapping=(pxla.ShardedAxis(0),))
    s = jax.sharding.PmapSharding(jax.devices(), ss)
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))

  def test_pickle_gspmd_sharding(self):
    s = jax.sharding.GSPMDSharding.get_replicated(jax.devices())
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))

  @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
  def test_pickle_named_sharding(self):
    # NamedSharding over a one-axis mesh ('d') built from all devices.
    s = jax.sharding.NamedSharding(
        mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'),
        spec=jax.sharding.PartitionSpec('d'))
    self.assertEqual(s, pickle.loads(pickle.dumps(s)))
|
|
|
|
|
|
if __name__ == "__main__":
  # Run the tests in this module through absltest using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|