mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 10:06:07 +00:00

* Fixes a bug when a non-3D mesh was requested * Adds new logic when requesting a single-host mesh * Extends logic to v2 as well as v3
286 lines
11 KiB
Python
286 lines
11 KiB
Python
# Lint as: python3
|
|
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for mesh utils."""
|
|
|
|
import collections
|
|
import dataclasses
|
|
from typing import Sequence
|
|
|
|
import numpy as np
|
|
|
|
from absl import logging
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
from jax.experimental import mesh_utils
|
|
from jax.experimental.maps import Mesh
|
|
from jax._src import test_util
|
|
|
|
|
|
@dataclasses.dataclass
class MockTpuDevice:
  """Plain-data stand-in for a TPU device object, used by the tests below.

  Instances are deliberately mutable: `mock_devices` creates them with a
  placeholder `id` and assigns the real value afterwards.
  """
  # Device id (set to -1 on construction by `mock_devices`, then overwritten).
  id: int
  # Platform name, e.g. 'tpu'.
  platform: str
  # Device kind string, e.g. 'TPU v3' or 'TPU v4'.
  device_kind: str
  # Index of the process (host) that owns this device.
  process_index: int
  # Chip coordinates; `mock_devices` fills in an (x, y, z) tuple.
  coords: Sequence[int]
  # Which core of the chip this device represents (0 or 1 in this file).
  core_on_chip: int
|
|
|
|
|
|
def mock_devices(x, y, z, dev_kind, two_cores_per_chip):
  """Builds a mocked jax.devices() list for an x * y * z grid of chips.

  Chips are grouped into hosts of 2x2x1 chips and every host gets its own
  process index, numbered in (z, y, x) loop order.

  Args:
    x: Grid extent along the first coordinate; hosts are laid out in steps of
      two along this axis.
    y: Grid extent along the second coordinate; also stepped by two.
    z: Grid extent along the third coordinate.
    dev_kind: Value stored in every device's `device_kind` field.
    two_cores_per_chip: When true, only the `core_on_chip == 0` device of each
      chip is emitted (one device per chip); otherwise both cores are emitted.

  Returns:
    A list of `MockTpuDevice`, host by host, with `id` filled in.
  """
  # Chip offsets inside one host's 2x2 subgrid, in emission order.
  chip_offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
  # Cores listed for every chip.
  cores = (0,) if two_cores_per_chip else (0, 1)

  host_origins = (
      (i, j, k)
      for k in range(z)
      for j in range(0, y, 2)
      for i in range(0, x, 2)
  )

  devices = []
  # Simulate one process per host (1 host = 2x2x1 slice).
  for process_index, (i, j, k) in enumerate(host_origins):
    devices.extend(
        MockTpuDevice(-1, 'tpu', dev_kind, process_index,
                      (i + dx, j + dy, k), core)
        for dx, dy in chip_offsets
        for core in cores)

  # Ids grow in (z, y, x) major order; when both cores are present each chip
  # takes two consecutive ids, ordered by core.
  for dev in devices:
    cx, cy, cz = dev.coords
    chip_id = cz * x * y + cy * x + cx
    if two_cores_per_chip:
      dev.id = chip_id
    else:
      dev.id = chip_id * 2 + dev.core_on_chip

  _validate_mocked_process_indices(devices, two_cores_per_chip)
  return devices
|
|
|
|
# If this function raises, it's a bug in the test code!
def _validate_mocked_process_indices(devices, two_cores_per_chip):
  """Sanity-checks the process layout of a mocked device list.

  Every process must own exactly the devices of one 2x2 block of chips that
  share a z coordinate: 4 devices when `two_cores_per_chip` is true, 8
  otherwise.

  Args:
    devices: Iterable of objects exposing `process_index` and `coords`.
    two_cores_per_chip: Selects the expected per-process device count.

  Raises:
    AssertionError: If any process violates the expected layout.
  """
  by_process = collections.defaultdict(list)
  for dev in devices:
    by_process[dev.process_index].append(dev)

  # 4 devices per process with one device per chip, 8 with both cores listed.
  devices_per_process = 4 if two_cores_per_chip else 8

  for local_devices in by_process.values():
    assert len(local_devices) == devices_per_process, local_devices

    # All devices have same z coord
    z_values = {dev.coords[2] for dev in local_devices}
    assert len(z_values) == 1, local_devices

    # All devices in a 2x2 subgrid anchored at the smallest coordinates.
    origin = min(dev.coords for dev in local_devices)
    expected = {
        (origin[0] + dx, origin[1] + dy, origin[2])
        for dx in (0, 1)
        for dy in (0, 1)
    }
    actual = {dev.coords for dev in local_devices}
    assert actual == expected, local_devices
|
|
|
|
|
|
def mock_2x2_devices():
  """Mocks the jax.devices() list of a v3-2x2 slice.

  Returns:
    The `mock_devices` result for a 2x2x1 grid of 'TPU v3' chips with both
    cores of every chip listed.
  """
  return mock_devices(x=2, y=2, z=1, dev_kind='TPU v3',
                      two_cores_per_chip=False)
|
|
|
|
|
|
def mock_4x4_devices():
  """Mocks the jax.devices() list of a v3-4x4 slice.

  Returns:
    The `mock_devices` result for a 4x4x1 grid of 'TPU v3' chips with both
    cores of every chip listed.
  """
  return mock_devices(x=4, y=4, z=1, dev_kind='TPU v3',
                      two_cores_per_chip=False)
|
|
|
|
|
|
def mock_8x8_devices():
  """Mocks the jax.devices() list of a v3-8x8 slice.

  Returns:
    The `mock_devices` result for an 8x8x1 grid of 'TPU v3' chips with both
    cores of every chip listed.
  """
  return mock_devices(x=8, y=8, z=1, dev_kind='TPU v3',
                      two_cores_per_chip=False)
|
|
|
|
|
|
def mock_2x2x1_devices(two_cores_per_chip):
  """Mocks the jax.devices() list of a 2x2x1 'TPU v4' topology.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of mock devices built by `mock_devices`.
  """
  return mock_devices(x=2, y=2, z=1, dev_kind='TPU v4',
                      two_cores_per_chip=two_cores_per_chip)
|
|
|
|
|
|
def mock_2x2x4_devices(two_cores_per_chip):
  """Mocks the jax.devices() list of a 2x2x4 'TPU v4' topology.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of mock devices built by `mock_devices`.
  """
  return mock_devices(x=2, y=2, z=4, dev_kind='TPU v4',
                      two_cores_per_chip=two_cores_per_chip)
|
|
|
|
|
|
def mock_4x4x4_devices(two_cores_per_chip):
  """Mocks the jax.devices() list of a 4x4x4 'TPU v4' topology.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of mock devices built by `mock_devices`.
  """
  return mock_devices(x=4, y=4, z=4, dev_kind='TPU v4',
                      two_cores_per_chip=two_cores_per_chip)
|
|
|
|
|
|
def mock_4x4x8_devices(two_cores_per_chip):
  """Hard-coded reproduction of jax.devices() output on 4x4x8.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of 'TPU v4' mock devices built by `mock_devices`.
  """
  return mock_devices(4, 4, 8, 'TPU v4', two_cores_per_chip)
|
|
|
|
|
|
def mock_8x8x8_devices(two_cores_per_chip):
  """Mocks the jax.devices() list of an 8x8x8 'TPU v4' topology.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of mock devices built by `mock_devices`.
  """
  return mock_devices(x=8, y=8, z=8, dev_kind='TPU v4',
                      two_cores_per_chip=two_cores_per_chip)
|
|
|
|
|
|
def mock_4x8x8_devices(two_cores_per_chip):
  """Mocks the jax.devices() list of a 4x8x8 'TPU v4' topology.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of mock devices built by `mock_devices`.
  """
  return mock_devices(x=4, y=8, z=8, dev_kind='TPU v4',
                      two_cores_per_chip=two_cores_per_chip)
|
|
|
|
|
|
def mock_4x8x16_devices(two_cores_per_chip):
  """Mocks the jax.devices() list of a 4x8x16 'TPU v4' topology.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of mock devices built by `mock_devices`.
  """
  return mock_devices(x=4, y=8, z=16, dev_kind='TPU v4',
                      two_cores_per_chip=two_cores_per_chip)
|
|
|
|
|
|
def mock_8x8x16_devices(two_cores_per_chip):
  """Mocks the jax.devices() list of an 8x8x16 'TPU v4' topology.

  Args:
    two_cores_per_chip: Forwarded to `mock_devices`; when true only the
      `core_on_chip == 0` device of each chip is returned.

  Returns:
    The list of mock devices built by `mock_devices`.
  """
  return mock_devices(x=8, y=8, z=16, dev_kind='TPU v4',
                      two_cores_per_chip=two_cores_per_chip)
|
|
|
|
|
|
class PartitioningTest(test_util.JaxTestCase):
  # Exercises mesh_utils' private mesh-building helpers against the mocked
  # device lists defined in this file.

  @parameterized.named_parameters(
      ('2x2x1_t', mock_2x2x1_devices, True, (2, 2, 1, 1)),
      ('2x2x1_f', mock_2x2x1_devices, False, (2, 2, 1, 2)),
      ('8x8x16_t', mock_8x8x16_devices, True, (8, 8, 16, 1)),
      ('8x8x16_f', mock_8x8x16_devices, False, (8, 8, 16, 2)),
  )
  def test_bounds_from_last_device(self, devices, two_cores_per_chip,
                                   expected_bounds):
    # The bounds are derived from the final entry of the device list only.
    last_device = devices(two_cores_per_chip)[-1]
    bounds = mesh_utils._bounds_from_last_device(last_device)
    self.assertEqual(bounds, expected_bounds)

  @parameterized.named_parameters(
      ('4x4x4', mock_4x4x4_devices, (4, 4, 4)),
      ('4x4x8', mock_4x4x8_devices, (4, 4, 8)),
      ('8x8x8', mock_8x8x8_devices, (8, 8, 8)),
      ('8x8x16', mock_8x8x16_devices, (8, 8, 16)),
  )
  def test_jax_devices_order_normalized(self, devices, expected_shape):
    local_devices = mock_2x2x1_devices(True)
    all_devices = devices(True)
    normalized = mesh_utils._jax_devices_order_normalized(
        local_devices, all_devices)
    self.assertEqual(normalized.shape, expected_shape)
    # Walk every (x, y, z) index, x major and z minor: the device stored at
    # an index must report that index as its coordinates.
    for index in np.ndindex(*expected_shape):
      self.assertEqual(normalized[index].coords, index)

  @parameterized.named_parameters(
      ('2x2x1', mock_2x2x1_devices, [1, 1, 4], ((), (2,), (0, 1))),
      ('2x2x4', mock_2x2x4_devices, [1, 4, 4], ((), (2,), (0, 1))),
      ('4x4x4', mock_4x4x4_devices, [1, 16, 4], ((), (1, 2), (0,))),
      ('4x4x8a', mock_4x4x8_devices, [1, 16, 8], ((), (0, 1), (2,))),
      ('4x4x8b', mock_4x4x8_devices, [1, 8, 16], ((), (2,), (0, 1))),
      ('4x4x8c', mock_4x4x8_devices, [16, 8, 1], ((0, 1), (2,), ())),
      ('4x8x8', mock_4x8x8_devices, [1, 32, 8], ((), (0, 2), (1,))),
      ('8x8x8', mock_8x8x8_devices, [1, 64, 8], ((), (1, 2), (0,))),
      ('8x8x16', mock_8x8x16_devices, [1, 64, 16], ((), (0, 1), (2,))),
  )
  def test_create_device_mesh_for_tpu_v4(self, devices, mesh_shape,
                                         expected_assignment):
    local_devices = mock_2x2x1_devices(True)
    all_devices = devices(True)
    physical_mesh = mesh_utils._jax_devices_order_normalized(
        local_devices, all_devices)
    # Only the second element of the result (the assignment) is checked.
    result = mesh_utils._create_device_mesh_for_tpu_v4(
        physical_mesh, mesh_shape)
    _, assignment = result
    self.assertEqual(assignment, expected_assignment)

  @parameterized.named_parameters(
      # Physical ring order over tray
      ('2x2_1d', mock_2x2_devices, [8], [0, 1, 2, 3, 6, 7, 4, 5]),
      # Reshaped physical ring order over tray
      ('2x2_2d', mock_2x2_devices, [2, 4], [[0, 1, 2, 3],
                                            [6, 7, 4, 5]]),
      # 4 per-tray rings
      ('4x4_2d', mock_4x4_devices, [4, 8], [[ 0,  1,  2,  3, 10, 11,  8,  9],
                                            [ 4,  5,  6,  7, 14, 15, 12, 13],
                                            [16, 17, 18, 19, 26, 27, 24, 25],
                                            [20, 21, 22, 23, 30, 31, 28, 29]]),
  )
  def test_v3_create_device_mesh(self, devices, mesh_shape,
                                 expected_device_id_mesh):
    local_devices = mock_2x2_devices()
    global_devices = devices()
    device_kind = global_devices[0].device_kind
    mesh = mesh_utils._create_device_mesh(
        local_devices, global_devices, device_kind, mesh_shape,
        contiguous_submeshes=False)
    # Compare by device id so the expectation can be written as plain ints.
    ids_of = np.vectorize(lambda device: device.id)
    self.assertAllClose(np.array(expected_device_id_mesh), ids_of(mesh))

  def _assert_contiguous_submeshes(self, global_device_mesh):
    # Axis names are simply 0..ndim-1.
    axis_names = list(range(global_device_mesh.ndim))
    global_mesh = Mesh(global_device_mesh, axis_names)
    process_count = 1 + max(
        device.process_index for device in global_device_mesh.flatten())
    for process_index in range(process_count):
      # Raises an error if non-contiguous
      global_mesh._local_mesh(process_index)

  def test_create_contiguous_submeshes_for_tpu_v4(self):
    v4 = mesh_utils._TPU_V4
    process_0_devices = mock_2x2x1_devices(True)
    for topology, mesh_shapes in mesh_utils._TRANSPOSE_TRICKS.items():
      logging.vlog(1, "topology: %s", topology)
      size_x, size_y, size_z = topology[0], topology[1], topology[2]
      devices = mock_devices(size_x, size_y, size_z, v4,
                             two_cores_per_chip=True)
      for mesh_shape in mesh_shapes:
        logging.vlog(1, " mesh_shape: %s", mesh_shape)
        mesh = mesh_utils._create_device_mesh(
            process_0_devices, devices, v4, mesh_shape,
            contiguous_submeshes=True)
        self._assert_contiguous_submeshes(mesh)

  def test_create_contiguous_submeshes_errors(self):
    v4 = mesh_utils._TPU_V4
    process_0_devices = mock_2x2x1_devices(True)

    # (physical topology, requested mesh shape, exact ValueError message)
    failing_cases = (
        ((4, 4, 8), (1, 16, 8),
         "create_device_mesh cannot create contiguous submeshes for "
         "physical mesh topology (4, 4, 8)"),
        ((4, 8, 8), (1, 128, 2),
         "create_device_mesh cannot create contiguous submeshes for mesh_shape "
         "(1, 128, 2) and physical mesh topology (4, 8, 8). "
         "Available mesh_shapes: [(1, 64, 4), (1, 4, 64), (64, 4), (4, 64)]"),
    )
    for topology, mesh_shape, message in failing_cases:
      devices = mock_devices(*topology, v4, two_cores_per_chip=True)
      with self.assertRaisesWithLiteralMatch(ValueError, message):
        mesh_utils._create_device_mesh(
            process_0_devices, devices, v4, mesh_shape,
            contiguous_submeshes=True)
|
|
|
|
|
|
# Script entry point: hand control to absltest's test runner.
if __name__ == "__main__":
  absltest.main()
|