mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Allow a single logical mesh dim to take all devices.
PiperOrigin-RevId: 435240241
This commit is contained in:
parent
e309fb98de
commit
3a949acccb
@ -116,7 +116,7 @@ def _create_device_mesh_for_tpu_v4(
|
||||
for logical_axis_index, logical_axis_size in reversed(
|
||||
list(enumerate(mesh_shape))):
|
||||
# Preferentially map to 2D subplane first for higher bandwidth.
|
||||
for num_axes in range(2, 0, -1):
|
||||
for num_axes in range(3, 0, -1):
|
||||
# Try assign to any subset of size num_axes. Generate all candidates.
|
||||
axes = itertools.combinations(assignable_physical_mesh, num_axes)
|
||||
indices = itertools.combinations(
|
||||
|
@ -187,7 +187,7 @@ class PartitioningTest(test_util.JaxTestCase):
|
||||
self.assertEqual(normalized[i, j, k].coords, (i, j, k))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('2x2x1', mock_2x2x1_devices, [1, 1, 4], ((), (2,), (0, 1))),
|
||||
('2x2x1', mock_2x2x1_devices, [1, 1, 4], ((), (), (0, 1, 2))),
|
||||
('2x2x4', mock_2x2x4_devices, [1, 4, 4], ((), (2,), (0, 1))),
|
||||
('4x4x4', mock_4x4x4_devices, [1, 16, 4], ((), (1, 2), (0,))),
|
||||
('4x4x8a', mock_4x4x8_devices, [1, 16, 8], ((), (0, 1), (2,))),
|
||||
|
Loading…
x
Reference in New Issue
Block a user