Add registration handler for TPU v5e in mesh_utils.

PiperOrigin-RevId: 643092629
This commit is contained in:
Yash Katariya 2024-06-13 12:51:43 -07:00 committed by jax authors
parent dd3b0a6981
commit 023bc7856b
2 changed files with 45 additions and 1 deletions

View File

@ -13,9 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
bumped to 0.4.0 but this has been rolled back in this release to give users
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
release.
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release.

View File

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
# TPU device-kind names. `_TPU_V2`, `_TPU_V3` and `_TPU_V5_LITE` are used as
# keys of `device_kind_handler_dict` to select a kind-specific mesh builder.
# NOTE(review): "TPU v5 lite" is presumably the device kind reported by
# TPU v5e hardware -- confirm against the runtime.
_TPU_V2 = 'TPU v2'
_TPU_V3 = 'TPU v3'
_TPU_V4 = 'TPU v4'
_TPU_V5_LITE = "TPU v5 lite"
# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's
# famous contiguous mesh trick.
@ -64,7 +65,8 @@ _TRANSPOSE_TRICKS: dict[
# Physical ordering of core IDs in a tray that creates a ring
_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)
# Ring order for 4 devices laid out 2x2. Used by _vlc_create_device_mesh to
# index devices that were first sorted x-minor (position = y * 2 + x).
_TRAY_2x2_RING_ORDER = (0, 1, 3, 2)
# Ring order for 16 devices laid out 4x4, applied to the same x-minor sort
# (position = y * 4 + x). Consecutive entries, including last -> first, are
# neighbours on the 4x4 grid, so the sequence forms a closed ring.
_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4)
def _tpu_v2_v3_create_device_mesh(
mesh_shape: Sequence[int],
@ -94,6 +96,45 @@ def _tpu_v2_v3_create_device_mesh(
return np.asarray(devices).reshape(mesh_shape)
def _vlc_create_device_mesh(
    mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs
) -> np.ndarray | None:
  """Creates rotated pincer device assignment for selected topologies.

  The physical layout is inferred from each device's `coords` attribute
  (devices without one are treated as sitting at (0, 0, 0)). Two layouts are
  handled:

  * 4 devices whose coordinates span a 2x2x1 box: the devices are always
    reordered with `_TRAY_2x2_RING_ORDER`.
  * 16 devices whose coordinates span a 4x4x1 box: the devices are reordered
    with `_TRAY_4x4_RING_ORDER`, but only when the largest entry of
    `mesh_shape` equals the number of devices.

  Args:
    mesh_shape: Logical mesh shape used by the model.
    devices: TPU devices.
    **unused_kwargs: Accepted and ignored, so this handler can be called with
      the same keyword arguments as the other device-kind handlers.

  Returns:
    None if the topology is not one of the handled cases (the caller is
    expected to fall back to its default logic); otherwise the reordered
    devices reshaped as `mesh_shape`.
  """
  all_coords = [getattr(d, "coords", (0, 0, 0)) for d in devices]
  # Take the maximum of every axis independently. Calling max() directly on the
  # coordinate tuples compares them lexicographically and only yields the true
  # extents when the devices fill the whole bounding box.
  bound_x, bound_y, bound_z = (max(axis) + 1 for axis in zip(*all_coords))
  bounds = (bound_x, bound_y, bound_z)
  num_devices = len(devices)

  if bounds == (2, 2, 1) and num_devices == 4:  # VLC2x2
    ring_order = _TRAY_2x2_RING_ORDER
  elif (
      bounds == (4, 4, 1)
      and num_devices == 16  # VLP4x4
      # Only uses ring order if the whole mesh is a replica group.
      and max(mesh_shape) == num_devices
  ):
    ring_order = _TRAY_4x4_RING_ORDER
  else:
    return None

  # Our ring re-ordering makes sense only if the passed-in devices are
  # sequential, which may not always be the case. reversed() changes z-minor to
  # x-minor.
  sequential_devices = sorted(
      devices,
      key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0)))))
  device_mesh = np.asarray(sequential_devices)[np.array(ring_order)]
  return device_mesh.reshape(mesh_shape)
# Registers functions to create device mesh for specific device kinds. Takes
# precedence over the more general logic in create_device_mesh(). Handler may
# return None; in that case, it will fall back to using the default logic.
@ -103,6 +144,7 @@ device_kind_handler_dict: dict[
] = {
_TPU_V2: _tpu_v2_v3_create_device_mesh,
_TPU_V3: _tpu_v2_v3_create_device_mesh,
_TPU_V5_LITE: _vlc_create_device_mesh,
}