mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add registration handler for TPU v5e in mesh_utils.
PiperOrigin-RevId: 643092629
This commit is contained in:
parent
dd3b0a6981
commit
023bc7856b
@ -13,9 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
bumped to 0.4.0 but this has been rolled back in this release to give users
|
||||
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
|
||||
release.
|
||||
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
|
||||
* jax now depends on jaxlib directly. This change was enabled by the CUDA
|
||||
plugin switch: there are no longer multiple jaxlib variants. You can install
|
||||
a CPU-only jax with `pip install jax`, no extras required.
|
||||
|
||||
* Deprecations
|
||||
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
|
||||
in a future release.
|
||||
|
@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
|
||||
_TPU_V2 = 'TPU v2'
|
||||
_TPU_V3 = 'TPU v3'
|
||||
_TPU_V4 = 'TPU v4'
|
||||
# Device-kind string used as the registry key for the TPU v5e mesh handler.
_TPU_V5_LITE = "TPU v5 lite"
|
||||
|
||||
# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's
|
||||
# famous contiguous mesh trick.
|
||||
@ -64,7 +65,8 @@ _TRANSPOSE_TRICKS: dict[
|
||||
|
||||
# Physical ordering of core IDs in a tray that creates a ring
|
||||
_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)
|
||||
|
||||
# Ring order for a 2x2 slice. The entries are indices into the device list
# after it has been sorted x-minor (x varies fastest), so (0, 1, 3, 2) walks
# (0,0) -> (1,0) -> (1,1) -> (0,1) and wraps back to the start.
_TRAY_2x2_RING_ORDER = (0, 1, 3, 2)
|
||||
# Ring order for a 4x4 slice. The entries are indices into the device list
# after it has been sorted x-minor (index = x + 4 * y). Consecutive entries,
# including last -> first, differ by a single step in x or y, so the sequence
# forms a closed ring over all 16 positions.
_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4)
|
||||
|
||||
def _tpu_v2_v3_create_device_mesh(
|
||||
mesh_shape: Sequence[int],
|
||||
@ -94,6 +96,45 @@ def _tpu_v2_v3_create_device_mesh(
|
||||
return np.asarray(devices).reshape(mesh_shape)
|
||||
|
||||
|
||||
def _vlc_create_device_mesh(
|
||||
mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs
|
||||
) -> np.ndarray | None:
|
||||
"""Creates rotated pincer device assignment for selected topologies.
|
||||
|
||||
Args:
|
||||
mesh_shape: Logical mesh shape used by the model.
|
||||
devices: TPU devices.
|
||||
**unused_kwargs: ...
|
||||
|
||||
Returns:
|
||||
None or reordered devices reshaped as `mesh_shape`.
|
||||
"""
|
||||
max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices)
|
||||
bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1
|
||||
# Our ring re-ordering makes sense only if the passed-in devices are
|
||||
# sequential, which may not always be the case. reversed() changes z-minor to
|
||||
# x-minor.
|
||||
sequential_devices = sorted(
|
||||
devices,
|
||||
key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0)))))
|
||||
|
||||
if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2
|
||||
device_mesh = np.asarray(sequential_devices)
|
||||
device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)]
|
||||
device_mesh = device_mesh.reshape(mesh_shape)
|
||||
return device_mesh
|
||||
|
||||
if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4
|
||||
# Only uses ring order if the whole mesh is a replica group.
|
||||
if max(mesh_shape) == len(devices):
|
||||
device_mesh = np.asarray(sequential_devices)
|
||||
device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)]
|
||||
device_mesh = device_mesh.reshape(mesh_shape)
|
||||
return device_mesh
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Registers functions to create device mesh for specific device kinds. Takes
|
||||
# precedence over the more general logic in create_device_mesh(). Handler may
|
||||
# return None; in that case, it will fall back to using the default logic.
|
||||
@ -103,6 +144,7 @@ device_kind_handler_dict: dict[
|
||||
] = {
|
||||
_TPU_V2: _tpu_v2_v3_create_device_mesh,
|
||||
_TPU_V3: _tpu_v2_v3_create_device_mesh,
|
||||
_TPU_V5_LITE: _vlc_create_device_mesh,
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user