2023-03-10 10:07:37 -08:00
|
|
|
# Copyright 2018 The JAX Authors.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""Definitions of Mesh and ResourceEnv."""
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import collections
|
|
|
|
import contextlib
|
|
|
|
import functools
|
2023-04-13 11:48:11 -07:00
|
|
|
import math
|
2023-03-10 10:07:37 -08:00
|
|
|
import threading
|
|
|
|
from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from jax._src import config as jax_config
|
|
|
|
from jax._src import xla_bridge as xb
|
|
|
|
from jax._src import util
|
|
|
|
from jax._src.lib import xla_client as xc
|
|
|
|
|
|
|
|
MeshAxisName = Any
|
|
|
|
ResourceAxisName = Hashable
|
|
|
|
|
|
|
|
class Loop(NamedTuple):
|
|
|
|
name: ResourceAxisName
|
|
|
|
length: int
|
|
|
|
|
|
|
|
|
|
|
|
def show_axes(axes):
|
|
|
|
return ", ".join(sorted(f"`{a}`" for a in axes))
|
|
|
|
|
|
|
|
|
|
|
|
class ResourceEnv(NamedTuple):
|
|
|
|
physical_mesh: Mesh
|
|
|
|
loops: Tuple[Loop, ...]
|
|
|
|
|
|
|
|
def with_mesh(self, mesh: Mesh):
|
|
|
|
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
|
|
|
|
if overlap:
|
|
|
|
raise ValueError(f"Cannot update the mesh of the current resource "
|
|
|
|
f"environment. The new mesh shadows already defined axes "
|
|
|
|
f"{show_axes(overlap)}")
|
|
|
|
return self._replace(physical_mesh=mesh)
|
|
|
|
|
|
|
|
def with_extra_loop(self, loop: Loop):
|
|
|
|
if loop.name in self.resource_axes:
|
|
|
|
raise ValueError(f"Cannot extend the resource environment with loop named "
|
|
|
|
f"`{loop.name}`. An axis of this name is already defined!")
|
|
|
|
return self._replace(loops=self.loops + (loop,))
|
|
|
|
|
|
|
|
@property
|
|
|
|
def physical_resource_axes(self) -> Set[ResourceAxisName]:
|
|
|
|
return set(self.physical_mesh.axis_names)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def loop_resource_axes(self) -> Set[ResourceAxisName]:
|
|
|
|
return {loop.name for loop in self.loops}
|
|
|
|
|
|
|
|
@property
|
|
|
|
def resource_axes(self) -> Set[ResourceAxisName]:
|
|
|
|
return self.physical_resource_axes | self.loop_resource_axes
|
|
|
|
|
|
|
|
@property
|
|
|
|
def shape(self):
|
|
|
|
shape = self.physical_mesh.shape
|
|
|
|
shape.update(self.loops)
|
|
|
|
return shape
|
|
|
|
|
|
|
|
@property
|
|
|
|
def local_shape(self):
|
|
|
|
shape = self.physical_mesh.local_mesh.shape
|
|
|
|
shape.update(self.loops)
|
|
|
|
return shape
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
|
|
|
|
|
|
|
|
class Mesh(contextlib.ContextDecorator):
|
|
|
|
"""Declare the hardware resources available in the scope of this manager.
|
|
|
|
|
|
|
|
In particular, all ``axis_names`` become valid resource names inside the
|
|
|
|
managed block and can be used e.g. in the ``in_axis_resources`` argument of
|
|
|
|
:py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming
|
|
|
|
model (https://jax.readthedocs.io/en/latest/multi_process.html)
|
|
|
|
and the Distributed arrays and automatic parallelization tutorial
|
|
|
|
(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
|
|
|
|
|
|
|
|
If you are compiling in multiple threads, make sure that the
|
|
|
|
``with Mesh`` context manager is inside the function that the threads will
|
|
|
|
execute.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
devices: A NumPy ndarray object containing JAX device objects (as
|
|
|
|
obtained e.g. from :py:func:`jax.devices`).
|
|
|
|
axis_names: A sequence of resource axis names to be assigned to the
|
|
|
|
dimensions of the ``devices`` argument. Its length should match the
|
|
|
|
rank of ``devices``.
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
>>> from jax.experimental.pjit import pjit
|
|
|
|
>>> from jax.sharding import Mesh
|
|
|
|
>>> from jax.sharding import PartitionSpec as P
|
|
|
|
>>> import numpy as np
|
|
|
|
...
|
|
|
|
>>> inp = np.arange(16).reshape((8, 2))
|
|
|
|
>>> devices = np.array(jax.devices()).reshape(4, 2)
|
|
|
|
...
|
|
|
|
>>> # Declare a 2D mesh with axes `x` and `y`.
|
|
|
|
>>> global_mesh = Mesh(devices, ('x', 'y'))
|
|
|
|
>>> # Use the mesh object directly as a context manager.
|
|
|
|
>>> with global_mesh:
|
|
|
|
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
|
|
|
|
|
|
|
|
>>> # Initialize the Mesh and use the mesh as the context manager.
|
|
|
|
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
|
|
|
|
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
|
|
|
|
|
|
|
|
>>> # Also you can use it as `with ... as ...`.
|
|
|
|
>>> global_mesh = Mesh(devices, ('x', 'y'))
|
|
|
|
>>> with global_mesh as m:
|
|
|
|
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
|
|
|
|
|
|
|
|
>>> # You can also use it as `with Mesh(...)`.
|
|
|
|
>>> with Mesh(devices, ('x', 'y')):
|
|
|
|
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
|
|
|
|
"""
|
|
|
|
|
|
|
|
devices: np.ndarray
|
|
|
|
axis_names: Tuple[MeshAxisName, ...]
|
|
|
|
|
|
|
|
def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]],
|
|
|
|
axis_names: Union[str, Sequence[MeshAxisName]]):
|
|
|
|
if not isinstance(devices, np.ndarray):
|
|
|
|
devices = np.array(devices)
|
|
|
|
if isinstance(axis_names, str):
|
|
|
|
axis_names = (axis_names,)
|
|
|
|
assert devices.ndim == len(axis_names)
|
|
|
|
# TODO: Make sure that devices are unique? At least with the quick and
|
|
|
|
# dirty check that the array size is not larger than the number of
|
|
|
|
# available devices?
|
|
|
|
self.devices = devices.copy()
|
|
|
|
self.devices.flags.writeable = False
|
|
|
|
self.axis_names = tuple(axis_names)
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
if not isinstance(other, Mesh):
|
|
|
|
return False
|
|
|
|
# This is a performance optimization. Comparing thousands of devices
|
|
|
|
# can be expensive.
|
|
|
|
if id(self) == id(other):
|
|
|
|
return True
|
|
|
|
return (self.axis_names == other.axis_names and
|
|
|
|
np.array_equal(self.devices, other.devices))
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
if not hasattr(self, '_hash'):
|
|
|
|
self._hash = hash(
|
|
|
|
(self.axis_names, tuple(self.devices.flat), self.devices.shape))
|
|
|
|
return self._hash
|
|
|
|
|
|
|
|
def __setattr__(self, name, value):
|
|
|
|
if hasattr(self, name):
|
|
|
|
raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
|
|
|
|
super().__setattr__(name, value)
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
new_env = thread_resources.stack[-1].with_mesh(self)
|
|
|
|
thread_resources.stack.append(new_env)
|
|
|
|
thread_resources.env = new_env
|
|
|
|
jax_config.update_thread_local_jit_state(
|
|
|
|
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
|
|
|
|
if not t.physical_mesh.empty))
|
|
|
|
return self
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
|
|
thread_resources.stack.pop()
|
|
|
|
thread_resources.env = thread_resources.stack[-1]
|
|
|
|
jax_config.update_thread_local_jit_state(
|
|
|
|
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
|
|
|
|
if not t.physical_mesh.empty))
|
|
|
|
return False
|
|
|
|
|
|
|
|
@property
|
|
|
|
def shape(self):
|
|
|
|
return collections.OrderedDict(
|
|
|
|
(name, size)
|
|
|
|
for name, size in util.safe_zip(self.axis_names, self.devices.shape))
|
|
|
|
|
|
|
|
@property
|
|
|
|
def size(self):
|
2023-04-13 11:48:11 -07:00
|
|
|
return math.prod(self.shape.values())
|
2023-03-10 10:07:37 -08:00
|
|
|
|
|
|
|
@property
|
|
|
|
def empty(self):
|
|
|
|
return self.devices.ndim == 0
|
|
|
|
|
2023-04-06 19:45:59 -07:00
|
|
|
@functools.cached_property
|
2023-03-10 10:07:37 -08:00
|
|
|
def is_multi_process(self):
|
|
|
|
return self.devices.size != len(self.local_devices)
|
|
|
|
|
|
|
|
@functools.cached_property
|
|
|
|
def local_mesh(self):
|
|
|
|
return self._local_mesh(xb.process_index())
|
|
|
|
|
|
|
|
def _local_mesh(self, process_index):
|
|
|
|
if self.empty:
|
|
|
|
return self
|
|
|
|
is_local_device = np.vectorize(
|
|
|
|
lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
|
|
|
|
subcube_indices = []
|
|
|
|
# We take the smallest slice of each dimension that doesn't skip any local device.
|
|
|
|
for axis in range(self.devices.ndim):
|
|
|
|
other_axes = util.tuple_delete(tuple(range(self.devices.ndim)), axis)
|
|
|
|
# NOTE: This re-reduces over many axes multiple times, so we could definitely
|
|
|
|
# optimize it, but I hope it won't be a bottleneck anytime soon.
|
|
|
|
local_slices = is_local_device.any(other_axes, keepdims=False)
|
|
|
|
nonzero_indices = np.flatnonzero(local_slices)
|
|
|
|
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
|
|
|
|
subcube_indices.append(slice(start, end + 1))
|
|
|
|
subcube_indices = tuple(subcube_indices)
|
|
|
|
# We only end up with all conditions being true if the local devices formed a
|
|
|
|
# subcube of the full array. This is because we were biased towards taking a
|
|
|
|
# "hull" spanned by the devices, and in case the local devices don't form a
|
|
|
|
# subcube that hull will contain non-local devices.
|
|
|
|
if not is_local_device[subcube_indices].all():
|
|
|
|
raise ValueError(
|
|
|
|
"When passing host local inputs to pjit or xmap, devices "
|
|
|
|
"connected to a single host must form a contiguous subcube of the "
|
|
|
|
"global device mesh")
|
|
|
|
return Mesh(self.devices[subcube_indices], self.axis_names)
|
|
|
|
|
2023-04-07 12:05:31 -07:00
|
|
|
@functools.cached_property
|
2023-03-10 10:07:37 -08:00
|
|
|
def device_ids(self):
|
|
|
|
assert not self.empty
|
|
|
|
return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)
|
|
|
|
|
2023-04-07 13:54:04 -07:00
|
|
|
@functools.cached_property
|
|
|
|
def _local_devices_set(self):
|
|
|
|
return set(self.local_devices)
|
|
|
|
|
|
|
|
@functools.cached_property
|
2023-04-13 08:02:53 -07:00
|
|
|
def _flat_devices_tuple(self):
|
|
|
|
return tuple(self.devices.flat)
|
2023-04-07 13:54:04 -07:00
|
|
|
|
|
|
|
@functools.cached_property
|
|
|
|
def _flat_devices_set(self):
|
|
|
|
return set(self.devices.flat)
|
|
|
|
|
|
|
|
@functools.cached_property
|
|
|
|
def _repr(self):
|
2023-03-10 10:07:37 -08:00
|
|
|
if self.empty:
|
|
|
|
return "Mesh(device_ids=[], axis_names=())"
|
|
|
|
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})"
|
|
|
|
|
2023-04-07 13:54:04 -07:00
|
|
|
def __repr__(self):
|
|
|
|
return self._repr
|
|
|
|
|
2023-03-10 10:07:37 -08:00
|
|
|
@functools.cached_property
|
|
|
|
def local_devices(self):
|
|
|
|
return [d for d in self.devices.flat
|
|
|
|
if d.process_index == d.client.process_index()]
|
|
|
|
|
|
|
|
|
|
|
|
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())
|
|
|
|
|
|
|
|
class _ThreadResourcesLocalState(threading.local):
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
self.stack = [EMPTY_ENV]
|
|
|
|
self.env = self.stack[-1]
|
|
|
|
|
|
|
|
thread_resources = _ThreadResourcesLocalState()
|