Allow a single logical mesh dim to take all devices.

PiperOrigin-RevId: 435240241
This commit is contained in:
Yuanzhong Xu 2022-03-16 20:57:29 -07:00 committed by jax authors
parent e309fb98de
commit 3a949acccb
2 changed files with 2 additions and 2 deletions

View File

@ -116,7 +116,7 @@ def _create_device_mesh_for_tpu_v4(
for logical_axis_index, logical_axis_size in reversed(
list(enumerate(mesh_shape))):
# Preferentially map to 2D subplane first for higher bandwidth.
for num_axes in range(2, 0, -1):
for num_axes in range(3, 0, -1):
# Try assign to any subset of size num_axes. Generate all candidates.
axes = itertools.combinations(assignable_physical_mesh, num_axes)
indices = itertools.combinations(

View File

@ -187,7 +187,7 @@ class PartitioningTest(test_util.JaxTestCase):
self.assertEqual(normalized[i, j, k].coords, (i, j, k))
@parameterized.named_parameters(
('2x2x1', mock_2x2x1_devices, [1, 1, 4], ((), (2,), (0, 1))),
('2x2x1', mock_2x2x1_devices, [1, 1, 4], ((), (), (0, 1, 2))),
('2x2x4', mock_2x2x4_devices, [1, 4, 4], ((), (2,), (0, 1))),
('4x4x4', mock_4x4x4_devices, [1, 16, 4], ((), (1, 2), (0,))),
('4x4x8a', mock_4x4x8_devices, [1, 16, 8], ((), (0, 1), (2,))),