mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Enable more mesh shape assignment
We now sort the mesh dims by size first. Smaller dims have fewer choices so they should be assigned first. PiperOrigin-RevId: 520942700
This commit is contained in:
parent
dfbbc2551c
commit
6e00ba8bad
@ -109,16 +109,11 @@ def _create_device_mesh_for_nd_torus(
|
||||
# Map each logical axis to a subset of physical axes.
|
||||
assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
|
||||
|
||||
def sort_key(i):
|
||||
# Sort from smaller dims to larger dims, because smaller dims have fewer
|
||||
# choices. Use the index to break ties: `mesh_shape` is assumed to be ordered
|
||||
# by lowest network intensity first, so larger i comes earlier.
|
||||
return (mesh_shape[i], -i)
|
||||
|
||||
sorted_dims = sorted(list(range(len(mesh_shape))), key=sort_key)
|
||||
|
||||
for logical_axis_index in sorted_dims:
|
||||
logical_axis_size = mesh_shape[logical_axis_index]
|
||||
# Assign logical axes from highest network intensity to lowest.
|
||||
# `mesh_shape` is assumed to be ordered by lowest network intensity first, so
|
||||
# reverse it first.
|
||||
for logical_axis_index, logical_axis_size in reversed(
|
||||
list(enumerate(mesh_shape))):
|
||||
# Preferentially map to more physical axes first for higher bandwidth.
|
||||
for num_axes in range(3, 0, -1):
|
||||
# Try assign to any subset of size num_axes. Generate all candidates.
|
||||
|
@ -140,10 +140,6 @@ def mock_8x8x16_devices(one_device_per_chip):
|
||||
"""Hard-coded reproduction of jax.devices() output on 8x8x16."""
|
||||
return mock_tpu_devices(8, 8, 16, 'TPU v4', one_device_per_chip)
|
||||
|
||||
def mock_4x8x32_devices(one_device_per_chip):
|
||||
"""Hard-coded reproduction of jax.devices() output on 4x8x32."""
|
||||
return mock_tpu_devices(4, 8, 32, 'TPU v4', one_device_per_chip)
|
||||
|
||||
|
||||
class MeshUtilsTest(test_util.JaxTestCase):
|
||||
|
||||
@ -177,7 +173,7 @@ class MeshUtilsTest(test_util.JaxTestCase):
|
||||
self.assertEqual(normalized[i, j, k].coords, (i, j, k))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('2x2x1', mock_2x2x1_devices, [1, 1, 4], [(), (2,), (0, 1)]),
|
||||
('2x2x1', mock_2x2x1_devices, [1, 1, 4], [(), (), (0, 1, 2)]),
|
||||
('2x2x4', mock_2x2x4_devices, [1, 4, 4], [(), (2,), (0, 1)]),
|
||||
('4x4x4', mock_4x4x4_devices, [1, 16, 4], [(), (1, 2), (0,)]),
|
||||
('4x4x8a', mock_4x4x8_devices, [1, 16, 8], [(), (0, 1), (2,)]),
|
||||
@ -186,8 +182,7 @@ class MeshUtilsTest(test_util.JaxTestCase):
|
||||
('4x8x8', mock_4x8x8_devices, [1, 32, 8], [(), (0, 2), (1,)]),
|
||||
('8x8x8', mock_8x8x8_devices, [1, 64, 8], [(), (1, 2), (0,)]),
|
||||
('8x8x16', mock_8x8x16_devices, [1, 64, 16], [(), (0, 1), (2,)]),
|
||||
('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)]),
|
||||
('4x8x32', mock_4x8x32_devices, [8, 4, 32], [(1,), (0,), (2,)]),
|
||||
('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)])
|
||||
)
|
||||
def test_create_device_mesh_for_nd_torus(self, devices, mesh_shape,
|
||||
expected_assignment):
|
||||
|
Loading…
x
Reference in New Issue
Block a user