mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Add global_shards
to jax.Array
as it exists on GDA and is being used in various places.
PiperOrigin-RevId: 485065876
This commit is contained in:
parent
5b0cb11304
commit
32a0ea80ef
@ -407,6 +407,28 @@ class ArrayImpl(basearray.Array):
|
||||
out.append(Shard(db.device(), self.sharding, self.shape, array))
|
||||
return out
|
||||
|
||||
@property
|
||||
def global_shards(self) -> Sequence[Shard]:
|
||||
"""Returns list of all `Shard`s of the Array across all devices.
|
||||
|
||||
The result includes shards that are not addressable by the current process.
|
||||
If a `Shard` is not addressable, then its `data` will be `None`.
|
||||
"""
|
||||
self._check_if_deleted()
|
||||
if self.is_fully_addressable: # pylint: disable=using-constant-test
|
||||
return self.addressable_shards
|
||||
|
||||
out = []
|
||||
device_id_to_buffer = {db.device().id: db for db in self._arrays}
|
||||
for global_d in self.sharding.device_set:
|
||||
if device_id_to_buffer.get(global_d.id, None) is not None:
|
||||
array = _single_device_array_from_buf(
|
||||
device_id_to_buffer[global_d.id], self._committed)
|
||||
else:
|
||||
array = None
|
||||
out.append(Shard(global_d, self.sharding, self.shape, array))
|
||||
return out
|
||||
|
||||
def delete(self):
|
||||
if self._arrays is None:
|
||||
return
|
||||
|
@ -139,6 +139,13 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(s.data, global_input_data[s.index])
|
||||
self.assertArraysEqual(s.data, arr.addressable_data(i))
|
||||
|
||||
for g, l in safe_zip(arr.global_shards, arr.addressable_shards):
|
||||
self.assertEqual(g.device, l.device)
|
||||
self.assertEqual(g.index, l.index)
|
||||
self.assertEqual(g.replica_id, l.replica_id)
|
||||
self.assertEqual(g.data.aval, l.data.aval)
|
||||
self.assertArraysEqual(g.data, l.data)
|
||||
|
||||
def test_addressable_data(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
shape = (8, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user