Add local_data API for GSDA

PiperOrigin-RevId: 411188164
This commit is contained in:
Yash Katariya 2021-11-19 18:55:46 -08:00 committed by jax authors
parent d1de309410
commit 65a99dba7c
2 changed files with 24 additions and 16 deletions

View File

@ -152,6 +152,9 @@ class GlobalShardedDeviceArray:
def global_shards(self) -> Sequence[Shard]:
return self._global_shards
def local_data(self, index) -> DeviceArray:
return self.local_shards[index].data
@classmethod
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes, data_callback: Callable[[Index],

View File

@ -68,6 +68,10 @@ class GSDATest(jtu.JaxTestCase):
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
(1, 2),
[0, 0, 0, 0, 0, 0, 0, 0]),
("mesh_fully_replicated", [],
((slice(None), slice(None)), (slice(None), slice(None))),
(8, 2),
[0, 1, 2, 3, 4, 5, 6, 7]),
)
def test_gsda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids):
@ -81,12 +85,12 @@ class GSDATest(jtu.JaxTestCase):
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
self.assertArraysEqual(gsda.local_data(0),
global_input_data[expected_index[0]])
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
self.assertArraysEqual(gsda.local_shards[1].data,
self.assertArraysEqual(gsda.local_data(1),
global_input_data[expected_index[1]])
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
replica_ids = [i.replica_id for i in gsda.local_shards]
self.assertListEqual(replica_ids, expected_replica_ids)
self.assertListEqual([i.device.id for i in gsda.local_shards],
@ -97,6 +101,7 @@ class GSDATest(jtu.JaxTestCase):
self.assertEqual(g.replica_id, l.replica_id)
self.assertArraysEqual(g.data, l.data)
@parameterized.named_parameters(
("mesh_x_y_z", ["x", "y", "z"],
# There are more slices but for convienient purposes, checking for only
@ -125,12 +130,12 @@ class GSDATest(jtu.JaxTestCase):
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
self.assertArraysEqual(gsda.local_data(0),
global_input_data[expected_index[0]])
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
self.assertArraysEqual(gsda.local_shards[1].data,
self.assertArraysEqual(gsda.local_data(1),
global_input_data[expected_index[1]])
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
replica_ids = [i.replica_id for i in gsda.local_shards]
self.assertListEqual(replica_ids, expected_replica_ids)
@ -158,12 +163,12 @@ class GSDATest(jtu.JaxTestCase):
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
self.assertArraysEqual(gsda.local_data(0),
global_input_data[expected_index[0]])
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
self.assertArraysEqual(gsda.local_shards[1].data,
self.assertArraysEqual(gsda.local_data(1),
global_input_data[expected_index[1]])
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
replica_ids = [i.replica_id for i in gsda.local_shards]
self.assertListEqual(replica_ids, expected_replica_ids)
@ -187,12 +192,12 @@ class GSDATest(jtu.JaxTestCase):
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
self.assertArraysEqual(gsda.local_data(0),
global_input_data[expected_index[0]])
self.assertEqual(gsda.local_shards[1].index, expected_index[1])
self.assertArraysEqual(gsda.local_shards[1].data,
self.assertArraysEqual(gsda.local_data(1),
global_input_data[expected_index[1]])
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
self.assertEqual(gsda.local_data(0).shape, expected_shard_shape)
replica_ids = [i.replica_id for i in gsda.local_shards]
self.assertListEqual(replica_ids, expected_replica_ids)
for g, l in safe_zip(gsda.global_shards, gsda.local_shards):
@ -215,10 +220,10 @@ class GSDATest(jtu.JaxTestCase):
gsda = GlobalShardedDeviceArray.from_batched_callback(
global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1]])
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
self.assertArraysEqual(gsda.local_data(0).to_py(),
expected_first_shard_value)
expected_second_shard_value = np.array([[2, 3]])
self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
self.assertArraysEqual(gsda.local_data(1).to_py(),
expected_second_shard_value)
def test_gsda_batched_callback_with_devices(self):
@ -241,10 +246,10 @@ class GSDATest(jtu.JaxTestCase):
gsda = GlobalShardedDeviceArray.from_batched_callback_with_devices(
global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
self.assertArraysEqual(gsda.local_data(0).to_py(),
expected_first_shard_value)
expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
self.assertArraysEqual(gsda.local_data(1).to_py(),
expected_second_shard_value)