Add global_shards to jax.Array as it exists on GDA and is being used in various places.

PiperOrigin-RevId: 485065876
This commit is contained in:
Yash Katariya 2022-10-31 09:07:28 -07:00 committed by jax authors
parent 5b0cb11304
commit 32a0ea80ef
2 changed files with 29 additions and 0 deletions

View File

@ -407,6 +407,28 @@ class ArrayImpl(basearray.Array):
out.append(Shard(db.device(), self.sharding, self.shape, array))
return out
@property
def global_shards(self) -> Sequence[Shard]:
"""Returns list of all `Shard`s of the Array across all devices.
The result includes shards that are not addressable by the current process.
If a `Shard` is not addressable, then its `data` will be `None`.
"""
self._check_if_deleted()
if self.is_fully_addressable: # pylint: disable=using-constant-test
return self.addressable_shards
out = []
device_id_to_buffer = {db.device().id: db for db in self._arrays}
for global_d in self.sharding.device_set:
if device_id_to_buffer.get(global_d.id, None) is not None:
array = _single_device_array_from_buf(
device_id_to_buffer[global_d.id], self._committed)
else:
array = None
out.append(Shard(global_d, self.sharding, self.shape, array))
return out
def delete(self):
if self._arrays is None:
return

View File

@ -139,6 +139,13 @@ class JaxArrayTest(jtu.JaxTestCase):
self.assertArraysEqual(s.data, global_input_data[s.index])
self.assertArraysEqual(s.data, arr.addressable_data(i))
for g, l in safe_zip(arr.global_shards, arr.addressable_shards):
self.assertEqual(g.device, l.device)
self.assertEqual(g.index, l.index)
self.assertEqual(g.replica_id, l.replica_id)
self.assertEqual(g.data.aval, l.data.aval)
self.assertArraysEqual(g.data, l.data)
def test_addressable_data(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape = (8, 2)