Add Array counterparts to the serialization_test.py and disable the GDA tests if jax_array is enabled.

PiperOrigin-RevId: 474944400
This commit is contained in:
Yash Katariya 2022-09-16 18:37:07 -07:00 committed by jax authors
parent e6bdb00d31
commit 590b5b5d7f

View File

@ -34,7 +34,9 @@ config.parse_flags_with_absl()
class CheckpointTest(jtu.JaxTestCase):
def test_checkpointing(self):
def test_checkpointing_gda(self):
if config.jax_array:
self.skipTest('GDA and Array cannot be enabled together.')
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
@ -94,7 +96,7 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertEqual(m3.dtype, np.float32)
@jax_config.jax_array(True)
def test_checkpointing_with_array(self):
def test_checkpointing_jax_array(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp_shape = (8, 2)
pspec = P('x', 'y')
@ -155,7 +157,9 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertArraysEqual(np.asarray(s.data), np.array([]))
self.assertEqual(m3.dtype, np.float32)
def test_checkpointing_with_bigger_shape(self):
def test_checkpointing_with_bigger_shape_gda(self):
if config.jax_array:
self.skipTest('GDA and Array cannot be enabled together.')
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
num = util.prod(global_input_shape)
@ -195,7 +199,49 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
def test_checkpointing_scalar(self):
@jax_config.jax_array(True)
def test_checkpointing_with_bigger_shape_jax_array(self):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
num = util.prod(global_input_shape)
global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape)
def cb1(index):
return global_input_data1[index]
arr = array.make_array_from_callback(
global_input_shape, MeshPspecSharding(global_mesh, P('x', 'y')), cb1)
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
ckpt_paths = [str(ckpt_dir1)]
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([arr], tspecs)
m1, = serialization.run_deserialization(
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
[P('x', 'y')],
tspecs,
[(12, 2)],
[np.float32]
)
expected_data = {
0: np.array([[0], [2], [4]], dtype=np.float32),
1: np.array([[1], [3], [5]], dtype=np.float32),
2: np.array([[6], [8], [10]], dtype=np.float32),
3: np.array([[7], [9], [11]], dtype=np.float32),
4: np.array([[12], [14], [0]], dtype=np.float32),
5: np.array([[13], [15], [0]], dtype=np.float32),
6: np.array([[0], [0], [0]], dtype=np.float32),
7: np.array([[0], [0], [0]], dtype=np.float32),
}
for l in m1.addressable_shards:
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
def test_checkpointing_scalar_gda(self):
if config.jax_array:
self.skipTest('GDA and Array cannot be enabled together.')
global_mesh = jtu.create_global_mesh((2,), ('x'))
global_input_shape = ()
data = np.array(4)
@ -219,7 +265,35 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
def test_deserialize_tensorstore_array(self):
@jax_config.jax_array(True)
def test_checkpointing_scalar_jax_array(self):
global_mesh = jtu.create_global_mesh((2,), ('x'))
global_input_shape = ()
data = np.array(4)
s = MeshPspecSharding(global_mesh, P(None))
gda1 = array.make_array_from_callback(
global_input_shape, s, lambda idx: data[idx])
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
ckpt_paths = [str(ckpt_dir1)]
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([gda1], tspecs)
m1, = serialization.run_deserialization(
[jtu.create_global_mesh((2,), ('x'))],
[P(None)],
tspecs,
[()],
[np.float32]
)
for l in m1.addressable_shards:
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
def test_deserialize_tensorstore_array_gda(self):
if config.jax_array:
self.skipTest('GDA and Array cannot be enabled together.')
global_mesh = jtu.create_global_mesh((2,), ('x'))
data = np.arange(1024)
tspec = ts.array(data).spec()
@ -231,6 +305,19 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(np.asarray(l.data), data)
@jax_config.jax_array(True)
def test_deserialize_tensorstore_array_jax_array(self):
global_mesh = jtu.create_global_mesh((2,), ('x'))
data = np.arange(1024)
tspec = ts.array(data).spec()
m1, = serialization.run_deserialization(
[global_mesh],
[P(None)],
[tspec]
)
for l in m1.addressable_shards:
self.assertArraysEqual(np.asarray(l.data), data)
def test_spec_has_metadata(self):
spec = {
'a': {