From 32a0ea80ef0663f606c14bec2f44ec2845bc2c17 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 31 Oct 2022 09:07:28 -0700 Subject: [PATCH] Add `global_shards` to `jax.Array` as it exists on GDA and is being used in various places. PiperOrigin-RevId: 485065876 --- jax/_src/array.py | 22 ++++++++++++++++++++++ tests/array_test.py | 7 +++++++ 2 files changed, 29 insertions(+) diff --git a/jax/_src/array.py b/jax/_src/array.py index ee1ca13a8..85957dbeb 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -407,6 +407,28 @@ class ArrayImpl(basearray.Array): out.append(Shard(db.device(), self.sharding, self.shape, array)) return out + @property + def global_shards(self) -> Sequence[Shard]: + """Returns list of all `Shard`s of the Array across all devices. + + The result includes shards that are not addressable by the current process. + If a `Shard` is not addressable, then its `data` will be `None`. + """ + self._check_if_deleted() + if self.is_fully_addressable: # pylint: disable=using-constant-test + return self.addressable_shards + + out = [] + device_id_to_buffer = {db.device().id: db for db in self._arrays} + for global_d in self.sharding.device_set: + if device_id_to_buffer.get(global_d.id, None) is not None: + array = _single_device_array_from_buf( + device_id_to_buffer[global_d.id], self._committed) + else: + array = None + out.append(Shard(global_d, self.sharding, self.shape, array)) + return out + def delete(self): if self._arrays is None: return diff --git a/tests/array_test.py b/tests/array_test.py index dbe40d8b8..48c5594cf 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -139,6 +139,13 @@ class JaxArrayTest(jtu.JaxTestCase): self.assertArraysEqual(s.data, global_input_data[s.index]) self.assertArraysEqual(s.data, arr.addressable_data(i)) + for g, l in safe_zip(arr.global_shards, arr.addressable_shards): + self.assertEqual(g.device, l.device) + self.assertEqual(g.index, l.index) + self.assertEqual(g.replica_id, l.replica_id) + self.assertEqual(g.data.aval, l.data.aval) + self.assertArraysEqual(g.data, l.data) + def test_addressable_data(self): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) shape = (8, 2)