mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add local_data API for GSDA
PiperOrigin-RevId: 411188164
This commit is contained in:
parent
d1de309410
commit
65a99dba7c
@ -152,6 +152,9 @@ class GlobalShardedDeviceArray:
|
||||
def global_shards(self) -> Sequence[Shard]:
|
||||
return self._global_shards
|
||||
|
||||
def local_data(self, index) -> DeviceArray:
|
||||
return self.local_shards[index].data
|
||||
|
||||
@classmethod
|
||||
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes, data_callback: Callable[[Index],
|
||||
|
@ -68,6 +68,10 @@ class GSDATest(jtu.JaxTestCase):
|
||||
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
|
||||
(1, 2),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
("mesh_fully_replicated", [],
|
||||
((slice(None), slice(None)), (slice(None), slice(None))),
|
||||
(8, 2),
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
)
|
||||
def test_gsda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
|
||||
expected_replica_ids):
|
||||
@ -81,12 +85,12 @@ class GSDATest(jtu.JaxTestCase):
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gsda.local_shards[0].data,
|
||||
self.assertArraysEqual(gsda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
|
||||
self.assertArraysEqual(gsda.local_shards[1].data,
|
||||
self.assertArraysEqual(gsda.local_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
|
||||
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
|
||||
replica_ids = [i.replica_id for i in gsda.local_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
self.assertListEqual([i.device.id for i in gsda.local_shards],
|
||||
@ -97,6 +101,7 @@ class GSDATest(jtu.JaxTestCase):
|
||||
self.assertEqual(g.replica_id, l.replica_id)
|
||||
self.assertArraysEqual(g.data, l.data)
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y_z", ["x", "y", "z"],
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
@ -125,12 +130,12 @@ class GSDATest(jtu.JaxTestCase):
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gsda.local_shards[0].data,
|
||||
self.assertArraysEqual(gsda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
|
||||
self.assertArraysEqual(gsda.local_shards[1].data,
|
||||
self.assertArraysEqual(gsda.local_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
|
||||
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
|
||||
|
||||
replica_ids = [i.replica_id for i in gsda.local_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
@ -158,12 +163,12 @@ class GSDATest(jtu.JaxTestCase):
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gsda.local_shards[0].data,
|
||||
self.assertArraysEqual(gsda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
|
||||
self.assertArraysEqual(gsda.local_shards[1].data,
|
||||
self.assertArraysEqual(gsda.local_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
|
||||
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
|
||||
replica_ids = [i.replica_id for i in gsda.local_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
|
||||
@ -187,12 +192,12 @@ class GSDATest(jtu.JaxTestCase):
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gsda.local_shards[0].data,
|
||||
self.assertArraysEqual(gsda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
|
||||
self.assertArraysEqual(gsda.local_shards[1].data,
|
||||
self.assertArraysEqual(gsda.local_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
|
||||
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
|
||||
replica_ids = [i.replica_id for i in gsda.local_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
for g, l in safe_zip(gsda.global_shards, gsda.local_shards):
|
||||
@ -215,10 +220,10 @@ class GSDATest(jtu.JaxTestCase):
|
||||
gsda = GlobalShardedDeviceArray.from_batched_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1]])
|
||||
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
|
||||
self.assertArraysEqual(gsda.local_data(0).to_py(),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[2, 3]])
|
||||
self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
|
||||
self.assertArraysEqual(gsda.local_data(1).to_py(),
|
||||
expected_second_shard_value)
|
||||
|
||||
def test_gsda_batched_callback_with_devices(self):
|
||||
@ -241,10 +246,10 @@ class GSDATest(jtu.JaxTestCase):
|
||||
gsda = GlobalShardedDeviceArray.from_batched_callback_with_devices(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
|
||||
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
|
||||
self.assertArraysEqual(gsda.local_data(0).to_py(),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
|
||||
self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
|
||||
self.assertArraysEqual(gsda.local_data(1).to_py(),
|
||||
expected_second_shard_value)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user