mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Delete GDA serialization tests
PiperOrigin-RevId: 505260343
This commit is contained in:
parent
a2970928b7
commit
ef4c2a38d6
@ -24,7 +24,6 @@ from jax.config import config
|
||||
from jax._src import array
|
||||
from jax._src.sharding import NamedSharding, OpShardingSharding
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
import numpy as np
|
||||
import tensorstore as ts
|
||||
@ -34,68 +33,6 @@ config.parse_flags_with_absl()
|
||||
|
||||
class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
def test_checkpointing_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
# First GDA
|
||||
global_input_data1 = np.arange(num).reshape(global_input_shape)
|
||||
def cb1(index):
|
||||
return global_input_data1[index]
|
||||
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb1)
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
# Second GDA
|
||||
global_input_data2 = np.arange(num, num + num).reshape(global_input_shape)
|
||||
def cb2(index):
|
||||
return global_input_data2[index]
|
||||
gda2 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb2)
|
||||
ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)
|
||||
|
||||
# Third GDA
|
||||
def cb3(_):
|
||||
return np.array([])
|
||||
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
|
||||
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
|
||||
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1, gda2, gda3], tspecs)
|
||||
|
||||
m1, m2, m3 = serialization.run_deserialization(
|
||||
[NamedSharding(global_mesh, mesh_axes),
|
||||
NamedSharding(global_mesh, P('x')),
|
||||
NamedSharding(global_mesh1d, P(None))],
|
||||
tspecs)
|
||||
|
||||
self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
|
||||
np.array([[0], [2]]))
|
||||
self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
|
||||
np.array([[1], [3]]))
|
||||
self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
|
||||
self.assertEqual(m1.dtype, np.int32)
|
||||
|
||||
self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
|
||||
np.array([[16, 17], [18, 19]]))
|
||||
self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
|
||||
np.array([[16, 17], [18, 19]]))
|
||||
self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
|
||||
self.assertEqual(m2.dtype, np.int32)
|
||||
|
||||
for i, s in enumerate(m3.addressable_shards):
|
||||
self.assertEqual(s.index, (slice(None),))
|
||||
self.assertEqual(s.replica_id, i)
|
||||
self.assertArraysEqual(np.asarray(s.data), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
def test_checkpointing_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
@ -159,57 +96,6 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.asarray(s.data), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_with_bigger_shape_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
# First GDA
|
||||
global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape)
|
||||
def cb1(index):
|
||||
return global_input_data1[index]
|
||||
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
P('x', 'y'), cb1)
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[ds],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
[np.float32]
|
||||
)
|
||||
|
||||
expected_data = {
|
||||
0: np.array([[0], [2], [4]], dtype=np.float32),
|
||||
1: np.array([[1], [3], [5]], dtype=np.float32),
|
||||
2: np.array([[6], [8], [10]], dtype=np.float32),
|
||||
3: np.array([[7], [9], [11]], dtype=np.float32),
|
||||
4: np.array([[12], [14], [0]], dtype=np.float32),
|
||||
5: np.array([[13], [15], [0]], dtype=np.float32),
|
||||
6: np.array([[0], [0], [0]], dtype=np.float32),
|
||||
7: np.array([[0], [0], [0]], dtype=np.float32),
|
||||
}
|
||||
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Deserializing a GlobalDeviceArray is only possible with '
|
||||
'a `NamedSharding`'):
|
||||
new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
|
||||
serialization.run_deserialization([new_ds], tspecs, [(12, 2)], [np.float32])
|
||||
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
def test_checkpointing_with_bigger_shape_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
@ -252,45 +138,20 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m2.addressable_shards:
|
||||
self.assertArraysEqual(l.data, global_input_data1.astype('float32'))
|
||||
|
||||
def test_checkpointing_scalar_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
global_input_shape = ()
|
||||
data = np.array(4)
|
||||
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
P(None), lambda idx: data[idx])
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[NamedSharding(jtu.create_global_mesh((2,), ('x')), P(None))],
|
||||
tspecs,
|
||||
[()],
|
||||
[np.float32]
|
||||
)
|
||||
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
def test_checkpointing_scalar_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
global_input_shape = ()
|
||||
data = np.array(4)
|
||||
s = NamedSharding(global_mesh, P(None))
|
||||
gda1 = array.make_array_from_callback(
|
||||
array1 = array.make_array_from_callback(
|
||||
global_input_shape, s, lambda idx: data[idx])
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
serialization.run_serialization([array1], tspecs)
|
||||
ds = NamedSharding(jtu.create_global_mesh((2,), ('x')), P(None))
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
@ -303,19 +164,6 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
|
||||
|
||||
def test_deserialize_tensorstore_array_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
data = np.arange(1024)
|
||||
tspec = ts.array(data).spec()
|
||||
m1, = serialization.run_deserialization(
|
||||
[NamedSharding(global_mesh, P(None))],
|
||||
[tspec]
|
||||
)
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data)
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
def test_deserialize_tensorstore_array_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2,), ('x'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user