mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add ndim
and size
to GDA
PiperOrigin-RevId: 427874829
This commit is contained in:
parent
2512aed4bd
commit
8df1932100
@ -209,6 +209,8 @@ class GlobalDeviceArray:
|
||||
Attributes:
|
||||
shape : Global shape of the array.
|
||||
dtype : Dtype of the global array.
|
||||
ndim : Number of array dimensions in the global shape.
|
||||
size: Number of elements in the global array.
|
||||
local_shards : List of :class:`Shard` on the local devices of the current process.
|
||||
Data is materialized for all local shards.
|
||||
global_shards : List of all :class:`Shard` of the global array. Data isn’t
|
||||
@ -312,6 +314,14 @@ class GlobalDeviceArray:
|
||||
def shape(self) -> Shape:
|
||||
return self._global_shape
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
return len(self.shape)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return prod(self.shape)
|
||||
|
||||
@property
|
||||
def is_fully_replicated(self) -> bool:
|
||||
return self.shape == self.local_data(0).shape
|
||||
|
@ -67,7 +67,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
[0, 1, 2, 3, 4, 5, 6, 7], True),
|
||||
)
|
||||
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
|
||||
expected_replica_ids, expected_is_fully_replicated):
|
||||
expected_replica_ids, expected_is_fully_replicated):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
global_input_data = np.arange(
|
||||
@ -77,6 +77,8 @@ class GDATest(jtu.JaxTestCase):
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 2)
|
||||
self.assertEqual(gda.size, 16)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
@ -127,6 +129,8 @@ class GDATest(jtu.JaxTestCase):
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 3)
|
||||
self.assertEqual(gda.size, 64)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
@ -160,6 +164,8 @@ class GDATest(jtu.JaxTestCase):
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 1)
|
||||
self.assertEqual(gda.size, 16)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
@ -178,6 +184,8 @@ class GDATest(jtu.JaxTestCase):
|
||||
return np.array([])
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 1)
|
||||
self.assertEqual(gda.size, 0)
|
||||
for i, s in enumerate(gda.local_shards):
|
||||
self.assertEqual(s.index, (slice(None),))
|
||||
self.assertEqual(s.replica_id, i)
|
||||
|
Loading…
x
Reference in New Issue
Block a user