mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add Array
counterparts to the serialization_test.py and disable the GDA tests if jax_array is enabled.
PiperOrigin-RevId: 474944400
This commit is contained in:
parent
e6bdb00d31
commit
590b5b5d7f
@ -34,7 +34,9 @@ config.parse_flags_with_absl()
|
||||
|
||||
class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
def test_checkpointing(self):
|
||||
def test_checkpointing_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
@ -94,7 +96,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
def test_checkpointing_with_array(self):
|
||||
def test_checkpointing_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
inp_shape = (8, 2)
|
||||
pspec = P('x', 'y')
|
||||
@ -155,7 +157,9 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.asarray(s.data), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_with_bigger_shape(self):
|
||||
def test_checkpointing_with_bigger_shape_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
num = util.prod(global_input_shape)
|
||||
@ -195,7 +199,49 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
|
||||
|
||||
def test_checkpointing_scalar(self):
|
||||
@jax_config.jax_array(True)
|
||||
def test_checkpointing_with_bigger_shape_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape)
|
||||
def cb1(index):
|
||||
return global_input_data1[index]
|
||||
arr = array.make_array_from_callback(
|
||||
global_input_shape, MeshPspecSharding(global_mesh, P('x', 'y')), cb1)
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([arr], tspecs)
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
|
||||
[P('x', 'y')],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
[np.float32]
|
||||
)
|
||||
|
||||
expected_data = {
|
||||
0: np.array([[0], [2], [4]], dtype=np.float32),
|
||||
1: np.array([[1], [3], [5]], dtype=np.float32),
|
||||
2: np.array([[6], [8], [10]], dtype=np.float32),
|
||||
3: np.array([[7], [9], [11]], dtype=np.float32),
|
||||
4: np.array([[12], [14], [0]], dtype=np.float32),
|
||||
5: np.array([[13], [15], [0]], dtype=np.float32),
|
||||
6: np.array([[0], [0], [0]], dtype=np.float32),
|
||||
7: np.array([[0], [0], [0]], dtype=np.float32),
|
||||
}
|
||||
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
|
||||
|
||||
def test_checkpointing_scalar_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
global_input_shape = ()
|
||||
data = np.array(4)
|
||||
@ -219,7 +265,35 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
|
||||
|
||||
def test_deserialize_tensorstore_array(self):
|
||||
@jax_config.jax_array(True)
|
||||
def test_checkpointing_scalar_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
global_input_shape = ()
|
||||
data = np.array(4)
|
||||
s = MeshPspecSharding(global_mesh, P(None))
|
||||
gda1 = array.make_array_from_callback(
|
||||
global_input_shape, s, lambda idx: data[idx])
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[jtu.create_global_mesh((2,), ('x'))],
|
||||
[P(None)],
|
||||
tspecs,
|
||||
[()],
|
||||
[np.float32]
|
||||
)
|
||||
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
|
||||
|
||||
def test_deserialize_tensorstore_array_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
data = np.arange(1024)
|
||||
tspec = ts.array(data).spec()
|
||||
@ -231,6 +305,19 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data)
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
def test_deserialize_tensorstore_array_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
data = np.arange(1024)
|
||||
tspec = ts.array(data).spec()
|
||||
m1, = serialization.run_deserialization(
|
||||
[global_mesh],
|
||||
[P(None)],
|
||||
[tspec]
|
||||
)
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data)
|
||||
|
||||
def test_spec_has_metadata(self):
|
||||
spec = {
|
||||
'a': {
|
||||
|
Loading…
x
Reference in New Issue
Block a user