Add ndim and size to GDA

PiperOrigin-RevId: 427874829
This commit is contained in:
Yash Katariya 2022-02-10 16:45:45 -08:00 committed by jax authors
parent 2512aed4bd
commit 8df1932100
2 changed files with 19 additions and 1 deletions

View File

@ -209,6 +209,8 @@ class GlobalDeviceArray:
Attributes:
shape : Global shape of the array.
dtype : Dtype of the global array.
ndim : Number of array dimensions in the global shape.
size: Number of elements in the global array.
local_shards : List of :class:`Shard` on the local devices of the current process.
Data is materialized for all local shards.
global_shards : List of all :class:`Shard` of the global array. Data isnt
@ -312,6 +314,14 @@ class GlobalDeviceArray:
def shape(self) -> Shape:
return self._global_shape
@property
def ndim(self):
return len(self.shape)
@property
def size(self):
return prod(self.shape)
@property
def is_fully_replicated(self) -> bool:
return self.shape == self.local_data(0).shape

View File

@ -67,7 +67,7 @@ class GDATest(jtu.JaxTestCase):
[0, 1, 2, 3, 4, 5, 6, 7], True),
)
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids, expected_is_fully_replicated):
expected_replica_ids, expected_is_fully_replicated):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
global_input_data = np.arange(
@ -77,6 +77,8 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.ndim, 2)
self.assertEqual(gda.size, 16)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
@ -127,6 +129,8 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.ndim, 3)
self.assertEqual(gda.size, 64)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
@ -160,6 +164,8 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.ndim, 1)
self.assertEqual(gda.size, 16)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
@ -178,6 +184,8 @@ class GDATest(jtu.JaxTestCase):
return np.array([])
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.ndim, 1)
self.assertEqual(gda.size, 0)
for i, s in enumerate(gda.local_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)