Use the new mesh property instead of the private _global_mesh attribute.

PiperOrigin-RevId: 431815802
This commit is contained in:
Yash Katariya 2022-03-01 17:43:11 -08:00 committed by jax authors
parent d9f82f7b9b
commit 72cc567c05
5 changed files with 16 additions and 16 deletions

View File

@ -2023,10 +2023,10 @@ def _check_gda_xmap_partitioning(axis_resources, resource_env,
in_positional_semantics).to_mesh_axes(in_axes_flat)
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
if isinstance(arg, GlobalDeviceArray):
if arg._global_mesh != resource_env.physical_mesh:
if arg.mesh != resource_env.physical_mesh:
raise ValueError("xmap's mesh and GDA's mesh should be equal. Got Xmap "
f"mesh: {resource_env.physical_mesh},\n"
f"GDA mesh: {arg._global_mesh}")
f"GDA mesh: {arg.mesh}")
gda_array_mapping = _get_array_mapping(arg._mesh_axes)
if gda_array_mapping != xmap_array_mapping:

View File

@ -102,7 +102,7 @@ def process_allgather(in_tree: PyTreeDef, titled: bool = False) -> PyTreeDef:
if isinstance(inp, GlobalDeviceArray):
if inp.is_fully_replicated:
return inp.local_data(0).to_py()
global_mesh = inp._global_mesh
global_mesh = inp.mesh
in_axis_resources = FROM_GDA
else:
# DA/SDA/np.array will be sharded based on global_mesh.local_mesh.

View File

@ -1076,9 +1076,9 @@ def gda_mesh_axes_to_canonicalized_parsed_pspec(mesh_axes) -> CanonicalizedParse
def _maybe_check_pjit_gda_mesh(args, mesh):
for x in args:
if isinstance(x, GDA) and x._global_mesh != mesh:
if isinstance(x, GDA) and x.mesh != mesh:
raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit "
f"mesh: {mesh},\n GDA mesh: {x._global_mesh}")
f"mesh: {mesh},\n GDA mesh: {x.mesh}")
# -------------------- XLA OpSharding to PartitionSpec --------------------
# Note that OpSharding is more expressive than PartitionSpecs, so it's not

View File

@ -804,7 +804,7 @@ class GDAPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
@ -839,7 +839,7 @@ class GDAPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
@ -949,14 +949,14 @@ class GDAPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out1._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
for s in out1.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (8, 8))
self.assertEqual(out2.local_shards[0].data.shape, (1, 8))
self.assertDictEqual(out2._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
for s in out2.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
@ -978,14 +978,14 @@ class GDAPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out1._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
for s in out1.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (8, 8))
self.assertEqual(out2.local_shards[0].data.shape, (1, 8))
self.assertDictEqual(out2._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
for s in out2.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

View File

@ -920,7 +920,7 @@ class XMapGDATest(XMapTestCase):
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.local_shards[0].data.shape, (2, 1))
self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
for s in out.local_shards:
self.assertArraysEqual(s.data, input_data[s.index])
@ -950,14 +950,14 @@ class XMapGDATest(XMapTestCase):
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8,))
self.assertEqual(out1.local_shards[0].data.shape, (2,))
self.assertDictEqual(out1._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
for s in out1.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (8,))
self.assertEqual(out2.local_shards[0].data.shape, (2,))
self.assertDictEqual(out2._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
for s in out2.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
@ -991,14 +991,14 @@ class XMapGDATest(XMapTestCase):
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8,))
self.assertEqual(out1.local_shards[0].data.shape, (2,))
self.assertDictEqual(out1._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
for s in out1.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (8,))
self.assertEqual(out2.local_shards[0].data.shape, (4,))
self.assertDictEqual(out2._global_mesh.shape, {'x': 4, 'y': 2})
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
for s in out2.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])