Enable more mesh shape assignment

We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.

PiperOrigin-RevId: 520942700
This commit is contained in:
Zafarali Ahmed 2023-03-31 09:35:19 -07:00 committed by jax authors
parent dfbbc2551c
commit 6e00ba8bad
2 changed files with 7 additions and 17 deletions

View File

@ -109,16 +109,11 @@ def _create_device_mesh_for_nd_torus(
# Map each logical axis to a subset of physical axes.
assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
def sort_key(i):
# Sort from smaller dims to larger dims. Because smaller dims have fewer
# choices. Use the index to break ties: `mesh_shape` is assumed to ordered
# by lowest network intensity first, so larger i comes earlier.
return (mesh_shape[i], -i)
sorted_dims = sorted(list(range(len(mesh_shape))), key=sort_key)
for logical_axis_index in sorted_dims:
logical_axis_size = mesh_shape[logical_axis_index]
# Assign logical axes from highest network intensity to lowest.
# `mesh_shape` is assumed to ordered by lowest network intensity first, so
# reverse it first.
for logical_axis_index, logical_axis_size in reversed(
list(enumerate(mesh_shape))):
# Preferentially map to more physical axes first for higher bandwidth.
for num_axes in range(3, 0, -1):
# Try assign to any subset of size num_axes. Generate all candidates.

View File

@ -140,10 +140,6 @@ def mock_8x8x16_devices(one_device_per_chip):
"""Hard-coded reproduction of jax.devices() output on 8x8x16."""
return mock_tpu_devices(8, 8, 16, 'TPU v4', one_device_per_chip)
def mock_4x8x32_devices(one_device_per_chip):
"""Hard-coded reproduction of jax.devices() output on 4x8x16."""
return mock_tpu_devices(4, 8, 32, 'TPU v4', one_device_per_chip)
class MeshUtilsTest(test_util.JaxTestCase):
@ -177,7 +173,7 @@ class MeshUtilsTest(test_util.JaxTestCase):
self.assertEqual(normalized[i, j, k].coords, (i, j, k))
@parameterized.named_parameters(
('2x2x1', mock_2x2x1_devices, [1, 1, 4], [(), (2,), (0, 1)]),
('2x2x1', mock_2x2x1_devices, [1, 1, 4], [(), (), (0, 1, 2)]),
('2x2x4', mock_2x2x4_devices, [1, 4, 4], [(), (2,), (0, 1)]),
('4x4x4', mock_4x4x4_devices, [1, 16, 4], [(), (1, 2), (0,)]),
('4x4x8a', mock_4x4x8_devices, [1, 16, 8], [(), (0, 1), (2,)]),
@ -186,8 +182,7 @@ class MeshUtilsTest(test_util.JaxTestCase):
('4x8x8', mock_4x8x8_devices, [1, 32, 8], [(), (0, 2), (1,)]),
('8x8x8', mock_8x8x8_devices, [1, 64, 8], [(), (1, 2), (0,)]),
('8x8x16', mock_8x8x16_devices, [1, 64, 16], [(), (0, 1), (2,)]),
('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)]),
('4x8x32', mock_4x8x32_devices, [8, 4, 32], [(1,), (0,), (2,)]),
('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)])
)
def test_create_device_mesh_for_nd_torus(self, devices, mesh_shape,
expected_assignment):