From 623282715de895cd813a0dc44cf7da2d5630b310 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 10 Mar 2023 10:07:37 -0800 Subject: [PATCH] Split Mesh and ResourceEnv into a new module jax._src.mesh. This work is an effort to reduce cyclic dependencies in JAX internals. Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals. PiperOrigin-RevId: 515667671 --- jax/_src/debugging.py | 4 +- jax/_src/interpreters/pxla.py | 285 +++------------------------- jax/_src/maps.py | 31 +-- jax/_src/mesh.py | 267 ++++++++++++++++++++++++++ jax/_src/pjit.py | 15 +- jax/experimental/maps.py | 10 +- jax/experimental/multihost_utils.py | 11 +- jax/interpreters/pxla.py | 15 +- tests/pjit_test.py | 18 +- 9 files changed, 348 insertions(+), 308 deletions(-) create mode 100644 jax/_src/mesh.py diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 96590aa46..3dc7b2a44 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -27,12 +27,12 @@ from jax import lax from jax._src import core from jax._src import effects from jax._src import linear_util as lu +from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -310,7 +310,7 @@ class ShardingCallbackInfo: def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, callback): - mesh = pxla.thread_resources.env.physical_mesh + mesh = mesh_lib.thread_resources.env.physical_mesh axis_context = ctx.module_context.axis_context if isinstance(axis_context, mlir.ShardingContext): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9b755c3a2..f430b8aa5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -31,7 +31,7 @@ from __future__ import annotations import enum -from contextlib import contextmanager, ContextDecorator +from contextlib import contextmanager from collections import defaultdict, OrderedDict, namedtuple import dataclasses from functools import partial, lru_cache, cached_property @@ -60,6 +60,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu +from jax._src import mesh from jax._src import profiler from jax._src import sharding as sharding_internal from jax._src import source_info_util @@ -68,7 +69,6 @@ from jax._src import util from jax._src import xla_bridge as xb from jax._src.abstract_arrays import array_types from jax._src.config import config -from jax._src import config as jax_config from jax._src.config import flags from jax._src.core import ConcreteArray, ShapedArray from jax._src.interpreters import ad @@ -113,10 +113,11 @@ Replicated = pmap_lib.Replicated _UNSHARDED_INSTANCE = NoSharding() AvalDimSharding = Union[Unstacked, Chunked, NoSharding] +Mesh = jax._src.mesh.Mesh +MeshAxisName = mesh.MeshAxisName MeshDimAssignment = Union[ShardedAxis, Replicated] ShardingSpec = pmap_lib.ShardingSpec -MeshAxisName = Any OpShardingType = Any PartitionSpec = sharding_internal.PartitionSpec @@ -191,7 +192,7 @@ def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType # specially over some mesh axes. 
if replicated_maxes: last_tile_dims = [] - axes_by_type: Dict[OpShardingType, List[MeshAxisName]] = {} + axes_by_type: Dict[OpShardingType, List[jax._src.mesh.MeshAxisName]] = {} size_by_type: Dict[OpShardingType, int] = defaultdict(lambda: 1) assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys())) for axis, size in replicated_maxes: @@ -541,7 +542,7 @@ that would mean that a flat list of chunks would get assigned to a flattened lis mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the mesh devices ndarray would have to be transposed before flattening and assignment. """ -ArrayMapping = OrderedDictType[MeshAxisName, int] +ArrayMapping = OrderedDictType[mesh.MeshAxisName, int] ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource, UnspecifiedValue] @@ -691,7 +692,7 @@ def make_sharded_device_array( sharding_spec = _create_pmap_sharding_spec(aval) if jax.config.jax_array: - mesh = thread_resources.env.physical_mesh + mesh = jax._src.mesh.thread_resources.env.physical_mesh if mesh.empty: sharding = sharding_internal.PmapSharding( np.asarray([d.device() for d in device_buffers]), sharding_spec) @@ -1993,8 +1994,9 @@ def global_avals_to_results_handler( else: # This path is taken when the outputs are SDAs. assert all(isinstance(s, sharding_internal.NamedSharding) for s in shardings) - local_out_avals = [s.mesh._global_to_local(get_array_mapping(s.spec), aval) - for aval, s in safe_zip(global_out_avals, shardings)] + local_out_avals = [ + mesh_global_to_local(s.mesh, get_array_mapping(s.spec), aval) # type: ignore + for aval, s in safe_zip(global_out_avals, shardings)] local_shardings = [sharding_internal.NamedSharding(s.mesh.local_mesh, s.spec) # type: ignore for s in shardings] return local_avals_to_results_handler(local_out_avals, local_shardings) @@ -2388,250 +2390,6 @@ mlir.register_lowering(xla_pmap_p, _pmap_lowering) # ------------------- xmap ------------------- -class Mesh(ContextDecorator): - """Declare the hardware resources available in the scope of this manager. - - In particular, all ``axis_names`` become valid resource names inside the - managed block and can be used e.g. in the ``in_axis_resources`` argument of - :py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming - model (https://jax.readthedocs.io/en/latest/multi_process.html) - and the Distributed arrays and automatic parallelization tutorial - (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) - - If you are compiling in multiple threads, make sure that the - ``with Mesh`` context manager is inside the function that the threads will - execute. - - Args: - devices: A NumPy ndarray object containing JAX device objects (as - obtained e.g. from :py:func:`jax.devices`). - axis_names: A sequence of resource axis names to be assigned to the - dimensions of the ``devices`` argument. Its length should match the - rank of ``devices``. - - Example: - - >>> from jax.experimental.pjit import pjit - >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P - >>> import numpy as np - ... - >>> inp = np.arange(16).reshape((8, 2)) - >>> devices = np.array(jax.devices()).reshape(4, 2) - ... - >>> # Declare a 2D mesh with axes `x` and `y`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> # Use the mesh object directly as a context manager. - >>> with global_mesh: - ... 
out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Initialize the Mesh and use the mesh as the context manager. - >>> with Mesh(devices, ('x', 'y')) as global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Also you can use it as `with ... as ...`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> with global_mesh as m: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # You can also use it as `with Mesh(...)`. - >>> with Mesh(devices, ('x', 'y')): - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - """ - - devices: np.ndarray - axis_names: Tuple[MeshAxisName, ...] - - def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]], - axis_names: Union[str, Sequence[MeshAxisName]]): - if not isinstance(devices, np.ndarray): - devices = np.array(devices) - if isinstance(axis_names, str): - axis_names = (axis_names,) - assert devices.ndim == len(axis_names) - # TODO: Make sure that devices are unique? At least with the quick and - # dirty check that the array size is not larger than the number of - # available devices? - self.devices = devices.copy() - self.devices.flags.writeable = False - self.axis_names = tuple(axis_names) - - def __eq__(self, other): - if not isinstance(other, Mesh): - return False - # This is a performance optimization. Comparing thousands of devices - # can be expensive. - if id(self) == id(other): - return True - return (self.axis_names == other.axis_names and - np.array_equal(self.devices, other.devices)) - - def __hash__(self): - if not hasattr(self, '_hash'): - self._hash = hash( - (self.axis_names, tuple(self.devices.flat), self.devices.shape)) - return self._hash - - def __setattr__(self, name, value): - if hasattr(self, name): - raise RuntimeError("Cannot reassign attributes of immutable mesh objects") - super().__setattr__(name, value) - - def __enter__(self): - new_env = thread_resources.stack[-1].with_mesh(self) - thread_resources.stack.append(new_env) - thread_resources.env = new_env - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) - return self - - def __exit__(self, exc_type, exc_value, traceback): - thread_resources.stack.pop() - thread_resources.env = thread_resources.stack[-1] - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) - return False - - @property - def shape(self): - return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape)) - - @property - def size(self): - return np.prod(list(self.shape.values())) - - @property - def empty(self): - return self.devices.ndim == 0 - - @property - def is_multi_process(self): - return self.devices.size != len(self.local_devices) - - @cached_property - def local_mesh(self): - return self._local_mesh(xb.process_index()) - - def _local_mesh(self, process_index): - if self.empty: - return self - is_local_device = np.vectorize( - lambda d: d.process_index == process_index, otypes=[bool])(self.devices) - subcube_indices = [] - # We take the smallest slice of each dimension that doesn't skip any local device. 
- for axis in range(self.devices.ndim): - other_axes = tuple_delete(tuple(range(self.devices.ndim)), axis) - # NOTE: This re-reduces over many axes multiple times, so we could definitely - # optimize it, but I hope it won't be a bottleneck anytime soon. - local_slices = is_local_device.any(other_axes, keepdims=False) - nonzero_indices = np.flatnonzero(local_slices) - start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices)) - subcube_indices.append(slice(start, end + 1)) - subcube_indices = tuple(subcube_indices) - # We only end up with all conditions being true if the local devices formed a - # subcube of the full array. This is because we were biased towards taking a - # "hull" spanned by the devices, and in case the local devices don't form a - # subcube that hull will contain non-local devices. - if not is_local_device[subcube_indices].all(): - raise ValueError( - "When passing host local inputs to pjit or xmap, devices " - "connected to a single host must form a contiguous subcube of the " - "global device mesh") - return Mesh(self.devices[subcube_indices], self.axis_names) - - @property - def device_ids(self): - assert not self.empty - return np.vectorize(lambda d: d.id, otypes=[int])(self.devices) - - def __repr__(self): - if self.empty: - return "Mesh(device_ids=[], axis_names=())" - return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})" - - @cached_property - def local_devices(self): - return [d for d in self.devices.flat - if d.process_index == d.client.process_index()] - - def _local_to_global(self, axes: ArrayMapping, aval): - return untile_aval_nd(self.shape, axes, - tile_aval_nd(self.local_mesh.shape, axes, aval)) - - def _global_to_local(self, axes: ArrayMapping, aval): - return untile_aval_nd(self.local_mesh.shape, axes, - tile_aval_nd(self.shape, axes, aval)) - - -ResourceAxisName = core.AxisName - -class Loop(NamedTuple): - name: ResourceAxisName - length: int - - -def show_axes(axes): - return ", ".join(sorted(f"`{a}`" for a in axes)) - - -class ResourceEnv(NamedTuple): - physical_mesh: Mesh - loops: Tuple[Loop, ...] - - def with_mesh(self, mesh: Mesh): - overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names)) - if overlap: - raise ValueError(f"Cannot update the mesh of the current resource " - f"environment. The new mesh shadows already defined axes " - f"{show_axes(overlap)}") - return self._replace(physical_mesh=mesh) - - def with_extra_loop(self, loop: Loop): - if loop.name in self.resource_axes: - raise ValueError(f"Cannot extend the resource environment with loop named " - f"`{loop.name}`. 
An axis of this name is already defined!") - return self._replace(loops=self.loops + (loop,)) - - @property - def physical_resource_axes(self) -> Set[ResourceAxisName]: - return set(self.physical_mesh.axis_names) - - @property - def loop_resource_axes(self) -> Set[ResourceAxisName]: - return {loop.name for loop in self.loops} - - @property - def resource_axes(self) -> Set[ResourceAxisName]: - return self.physical_resource_axes | self.loop_resource_axes - - @property - def shape(self): - shape = self.physical_mesh.shape - shape.update(self.loops) - return shape - - @property - def local_shape(self): - shape = self.physical_mesh.local_mesh.shape - shape.update(self.loops) - return shape - - def __repr__(self): - return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})" - -EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ()) - -class _ThreadResourcesLocalState(threading.local): - - def __init__(self): - self.stack = [EMPTY_ENV] - self.env = self.stack[-1] - -thread_resources = _ThreadResourcesLocalState() - - def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval): assert isinstance(aval, ShapedArray) shape = list(aval.shape) @@ -2653,6 +2411,15 @@ def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval): return aval.update(shape=tuple(shape), named_shape=named_shape) +def mesh_local_to_global(mesh, axes: ArrayMapping, aval): + return untile_aval_nd(mesh.shape, axes, + tile_aval_nd(mesh.local_mesh.shape, axes, aval)) + +def mesh_global_to_local(mesh, axes: ArrayMapping, aval): + return untile_aval_nd(mesh.local_mesh.shape, axes, + tile_aval_nd(mesh.shape, axes, aval)) + + class SPMDBatchTrace(batching.BatchTrace): def get_axis_primitive_batcher(self, primitive, frame): if primitive in spmd_primitive_batchers: @@ -2688,7 +2455,7 @@ def _full_to_shard_abstract_eval(x, axes, mesh, **_): # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes! return tile_aval_nd(mesh.shape, axes, x) -def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName], mesh: Mesh): +def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxisName], mesh: Mesh): """Create an OpSharding proto that declares all mesh axes from `axes` as manual and all others as replicated. """ @@ -2714,7 +2481,8 @@ def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName return proto @partial(mlir.register_lowering, full_to_shard_p) -def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]): +def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, + manual_axes: FrozenSet[mesh.MeshAxisName]): # TODO: Can we short-circuit for replicated values? Probably not. 
aval_in, = ctx.avals_in aval_out, = ctx.avals_out @@ -2733,7 +2501,8 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_): return untile_aval_nd(mesh.shape, axes, x) @partial(mlir.register_lowering, shard_to_full_p) -def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]): +def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, + manual_axes: FrozenSet[mesh.MeshAxisName]): aval_in, = ctx.avals_in aval_out, = ctx.avals_out proto = manual_proto(aval_in, manual_axes, mesh) @@ -2744,7 +2513,7 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_ax return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims), @lu.transformation -def vtile_manual(manual_axes: FrozenSet[MeshAxisName], +def vtile_manual(manual_axes: FrozenSet[mesh.MeshAxisName], mesh: Mesh, in_axes: Sequence[ArrayMapping], out_axes: Sequence[ArrayMapping], @@ -2763,7 +2532,7 @@ class TileVectorize: @dataclasses.dataclass(frozen=True) class TileManual: - manual_axes: FrozenSet[MeshAxisName] + manual_axes: FrozenSet[mesh.MeshAxisName] TilingMethod = Union[TileVectorize, TileManual] @@ -3415,7 +3184,7 @@ def _get_normalized_avals_and_shardings( in_sharding = i else: assert isinstance(i, sharding_internal.NamedSharding) - aval = i.mesh._global_to_local( + aval = mesh_global_to_local(i.mesh, cast(ArrayMapping, get_array_mapping(i.spec)), gaval) # pylint: disable=g-bare-generic in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec) avals.append(aval) diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 9675febdd..dc8ab620c 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -22,6 +22,7 @@ from functools import wraps, partial, partialmethod, lru_cache from jax import numpy as jnp from jax._src import core +from jax._src import mesh from jax._src import linear_util as lu from jax import stages from jax._src import dispatch @@ -85,13 +86,12 @@ class FrozenDict(abc.Mapping): # Multi-dimensional generalized map AxisName = core.AxisName -ResourceAxisName = AxisName # Different name just for documentation purposes -# TODO(https://github.com/google/jax/issues/13487): Remove Mesh in -# 3 months from `jax.experimental.maps.Mesh`. 
-Mesh = pxla.Mesh -ResourceEnv = pxla.ResourceEnv -EMPTY_ENV = pxla.EMPTY_ENV -thread_resources = pxla.thread_resources +ResourceAxisName = mesh.ResourceAxisName # Different name just for documentation purposes +Mesh = mesh.Mesh +MeshAxisName = mesh.MeshAxisName +ResourceEnv = mesh.ResourceEnv +EMPTY_ENV = mesh.EMPTY_ENV +thread_resources = mesh.thread_resources class SerialLoop: @@ -161,7 +161,7 @@ def serial_loop(name: ResourceAxisName, length: int): axis_resources={'i': 'l'})(x) """ old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV) - thread_resources.env = old_env.with_extra_loop(pxla.Loop(name, length)) + thread_resources.env = old_env.with_extra_loop(mesh.Loop(name, length)) try: yield finally: @@ -686,7 +686,7 @@ def make_xmap_callable(fun: lu.WrappedFun, mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes) mesh = resource_env.physical_mesh global_in_avals = [ - av if ips == _PositionalSemantics.GLOBAL else mesh._local_to_global(ax, av) + av if ips == _PositionalSemantics.GLOBAL else pxla.mesh_local_to_global(mesh, ax, av) for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics) ] in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia @@ -964,7 +964,7 @@ def _resource_typing_xmap(avals, raise JAXTypeError( f"Detected disallowed xmap axis name shadowing at " f"{source_info_util.summarize(source_info)} " - f"(shadowed axes: {pxla.show_axes(overlap)})") + f"(shadowed axes: {mesh.show_axes(overlap)})") if resource_env.physical_mesh != params['resource_env'].physical_mesh: raise RuntimeError("Changing the physical mesh is not allowed inside xmap.") @@ -992,9 +992,9 @@ def _resource_typing_xmap(avals, raise JAXTypeError( f"One of xmapped function ({params['name']}) outputs is broadcast " f"along axis `{baxis}` which is assigned to resources " - f"{pxla.show_axes(baxis_resources)}, but the output is already " - f"partitioned along {pxla.show_axes(overlap)}, because its " - f"named shape contains {pxla.show_axes(partitioning_axes)}") + f"{mesh.show_axes(baxis_resources)}, but the output is already " + f"partitioned along {mesh.show_axes(overlap)}, because its " + f"named shape contains {mesh.show_axes(partitioning_axes)}") pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap @@ -1419,8 +1419,9 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, f = pxla.vtile_by_mesh(f, mesh, mesh_in_axes, mesh_out_axes) # XXX: We modify mesh_in_axes and mesh_out_axes here - def add_spmd_axes(flat_mesh_axes: Sequence[pxla.ArrayMapping], - flat_extra_axes: Optional[Sequence[Sequence[Sequence[pxla.MeshAxisName]]]]): + def add_spmd_axes( + flat_mesh_axes: Sequence[pxla.ArrayMapping], + flat_extra_axes: Optional[Sequence[Sequence[Sequence[MeshAxisName]]]]): if flat_extra_axes is None: return for axes, extra in zip(flat_mesh_axes, flat_extra_axes): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py new file mode 100644 index 000000000..dec7a5f53 --- /dev/null +++ b/jax/_src/mesh.py @@ -0,0 +1,267 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Definitions of Mesh and ResourceEnv.""" + +from __future__ import annotations + +import collections +import contextlib +import functools +import threading +from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union + +import numpy as np + +from jax._src import config as jax_config +from jax._src import xla_bridge as xb +from jax._src import util +from jax._src.lib import xla_client as xc + +MeshAxisName = Any +ResourceAxisName = Hashable + +class Loop(NamedTuple): + name: ResourceAxisName + length: int + + +def show_axes(axes): + return ", ".join(sorted(f"`{a}`" for a in axes)) + + +class ResourceEnv(NamedTuple): + physical_mesh: Mesh + loops: Tuple[Loop, ...] + + def with_mesh(self, mesh: Mesh): + overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names)) + if overlap: + raise ValueError(f"Cannot update the mesh of the current resource " + f"environment. The new mesh shadows already defined axes " + f"{show_axes(overlap)}") + return self._replace(physical_mesh=mesh) + + def with_extra_loop(self, loop: Loop): + if loop.name in self.resource_axes: + raise ValueError(f"Cannot extend the resource environment with loop named " + f"`{loop.name}`. An axis of this name is already defined!") + return self._replace(loops=self.loops + (loop,)) + + @property + def physical_resource_axes(self) -> Set[ResourceAxisName]: + return set(self.physical_mesh.axis_names) + + @property + def loop_resource_axes(self) -> Set[ResourceAxisName]: + return {loop.name for loop in self.loops} + + @property + def resource_axes(self) -> Set[ResourceAxisName]: + return self.physical_resource_axes | self.loop_resource_axes + + @property + def shape(self): + shape = self.physical_mesh.shape + shape.update(self.loops) + return shape + + @property + def local_shape(self): + shape = self.physical_mesh.local_mesh.shape + shape.update(self.loops) + return shape + + def __repr__(self): + return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})" + +class Mesh(contextlib.ContextDecorator): + """Declare the hardware resources available in the scope of this manager. + + In particular, all ``axis_names`` become valid resource names inside the + managed block and can be used e.g. in the ``in_axis_resources`` argument of + :py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming + model (https://jax.readthedocs.io/en/latest/multi_process.html) + and the Distributed arrays and automatic parallelization tutorial + (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + + If you are compiling in multiple threads, make sure that the + ``with Mesh`` context manager is inside the function that the threads will + execute. + + Args: + devices: A NumPy ndarray object containing JAX device objects (as + obtained e.g. from :py:func:`jax.devices`). + axis_names: A sequence of resource axis names to be assigned to the + dimensions of the ``devices`` argument. Its length should match the + rank of ``devices``. + + Example: + + >>> from jax.experimental.pjit import pjit + >>> from jax.sharding import Mesh + >>> from jax.sharding import PartitionSpec as P + >>> import numpy as np + ... + >>> inp = np.arange(16).reshape((8, 2)) + >>> devices = np.array(jax.devices()).reshape(4, 2) + ... + >>> # Declare a 2D mesh with axes `x` and `y`. 
+ >>> global_mesh = Mesh(devices, ('x', 'y')) + >>> # Use the mesh object directly as a context manager. + >>> with global_mesh: + ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + + >>> # Initialize the Mesh and use the mesh as the context manager. + >>> with Mesh(devices, ('x', 'y')) as global_mesh: + ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + + >>> # Also you can use it as `with ... as ...`. + >>> global_mesh = Mesh(devices, ('x', 'y')) + >>> with global_mesh as m: + ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + + >>> # You can also use it as `with Mesh(...)`. + >>> with Mesh(devices, ('x', 'y')): + ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + """ + + devices: np.ndarray + axis_names: Tuple[MeshAxisName, ...] + + def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]], + axis_names: Union[str, Sequence[MeshAxisName]]): + if not isinstance(devices, np.ndarray): + devices = np.array(devices) + if isinstance(axis_names, str): + axis_names = (axis_names,) + assert devices.ndim == len(axis_names) + # TODO: Make sure that devices are unique? At least with the quick and + # dirty check that the array size is not larger than the number of + # available devices? + self.devices = devices.copy() + self.devices.flags.writeable = False + self.axis_names = tuple(axis_names) + + def __eq__(self, other): + if not isinstance(other, Mesh): + return False + # This is a performance optimization. Comparing thousands of devices + # can be expensive. + if id(self) == id(other): + return True + return (self.axis_names == other.axis_names and + np.array_equal(self.devices, other.devices)) + + def __hash__(self): + if not hasattr(self, '_hash'): + self._hash = hash( + (self.axis_names, tuple(self.devices.flat), self.devices.shape)) + return self._hash + + def __setattr__(self, name, value): + if hasattr(self, name): + raise RuntimeError("Cannot reassign attributes of immutable mesh objects") + super().__setattr__(name, value) + + def __enter__(self): + new_env = thread_resources.stack[-1].with_mesh(self) + thread_resources.stack.append(new_env) + thread_resources.env = new_env + jax_config.update_thread_local_jit_state( + mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + thread_resources.stack.pop() + thread_resources.env = thread_resources.stack[-1] + jax_config.update_thread_local_jit_state( + mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) + return False + + @property + def shape(self): + return collections.OrderedDict( + (name, size) + for name, size in util.safe_zip(self.axis_names, self.devices.shape)) + + @property + def size(self): + return np.prod(list(self.shape.values())) + + @property + def empty(self): + return self.devices.ndim == 0 + + @property + def is_multi_process(self): + return self.devices.size != len(self.local_devices) + + @functools.cached_property + def local_mesh(self): + return self._local_mesh(xb.process_index()) + + def _local_mesh(self, process_index): + if self.empty: + return self + is_local_device = np.vectorize( + lambda d: d.process_index == process_index, otypes=[bool])(self.devices) + subcube_indices = [] + # We take the smallest slice of each dimension that doesn't skip any local device. 
+ for axis in range(self.devices.ndim): + other_axes = util.tuple_delete(tuple(range(self.devices.ndim)), axis) + # NOTE: This re-reduces over many axes multiple times, so we could definitely + # optimize it, but I hope it won't be a bottleneck anytime soon. + local_slices = is_local_device.any(other_axes, keepdims=False) + nonzero_indices = np.flatnonzero(local_slices) + start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices)) + subcube_indices.append(slice(start, end + 1)) + subcube_indices = tuple(subcube_indices) + # We only end up with all conditions being true if the local devices formed a + # subcube of the full array. This is because we were biased towards taking a + # "hull" spanned by the devices, and in case the local devices don't form a + # subcube that hull will contain non-local devices. + if not is_local_device[subcube_indices].all(): + raise ValueError( + "When passing host local inputs to pjit or xmap, devices " + "connected to a single host must form a contiguous subcube of the " + "global device mesh") + return Mesh(self.devices[subcube_indices], self.axis_names) + + @property + def device_ids(self): + assert not self.empty + return np.vectorize(lambda d: d.id, otypes=[int])(self.devices) + + def __repr__(self): + if self.empty: + return "Mesh(device_ids=[], axis_names=())" + return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})" + + @functools.cached_property + def local_devices(self): + return [d for d in self.devices.flat + if d.process_index == d.client.process_index()] + + +EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ()) + +class _ThreadResourcesLocalState(threading.local): + + def __init__(self): + self.stack = [EMPTY_ENV] + self.env = self.stack[-1] + +thread_resources = _ThreadResourcesLocalState() diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b8a449043..a03288686 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -41,6 +41,7 @@ from jax._src.sharding import ( XLADeviceAssignment, SingleDeviceSharding, PmapSharding) from jax._src import array from jax._src import dispatch +from jax._src import mesh from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import traceback_util @@ -714,7 +715,7 @@ def pjit( def infer_params(*args, **kwargs): # Putting this outside of wrapped would make resources lexically scoped - resource_env = pxla.thread_resources.env + resource_env = mesh.thread_resources.env pjit_info_args = PjitInfo( fun=fun, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, @@ -1173,7 +1174,7 @@ def _check_unique_resources(axis_resources, arg_name): if multiple_uses: raise ValueError(f"A single {arg_name} specification can map every mesh axis " f"to at most one positional dimension, but {arg_axis_resources.user_spec} " - f"has duplicate entries for {pxla.show_axes(multiple_uses)}") + f"has duplicate entries for {mesh.show_axes(multiple_uses)}") # -------------------- pjit rules -------------------- @@ -1916,7 +1917,7 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} " f"that uses one or more mesh axes already used by xmap to partition " f"a named axis appearing in its named_shape (both use mesh axes " - f"{pxla.show_axes(overlap)})") + f"{mesh.show_axes(overlap)})") def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources): jaxpr = params["jaxpr"] @@ -2024,7 +2025,7 @@ def 
with_sharding_constraint(x, axis_resources=_UNSPECIFIED, flatten_axes("with_sharding_constraint shardings", tree, user_shardings)) del user_shardings - resource_env = pxla.thread_resources.env + resource_env = jax._src.mesh.thread_resources.env mesh = resource_env.physical_mesh if config.jax_array: @@ -2171,7 +2172,8 @@ def global_to_local(positional_semantics, avals, shardings, mesh): # replicated avals don't go through this code path. To convert global # avals to host local avals, round trip it via NamedSharding. parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0] - out.append(mesh._global_to_local(get_array_mapping(parsed_pspec), aval)) + out.append(pxla.mesh_global_to_local( + mesh, get_array_mapping(parsed_pspec), aval)) return out @@ -2188,7 +2190,8 @@ def local_to_global(positional_semantics, avals, shardings, mesh): # replicated avals don't go through this code path. To convert host local # avals to global avals, round trip it via NamedSharding. parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0] - out.append(mesh._local_to_global(get_array_mapping(parsed_pspec), aval)) + out.append(pxla.mesh_local_to_global( + mesh, get_array_mapping(parsed_pspec), aval)) return out diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 7cb662bbd..26fa5be7f 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -14,21 +14,21 @@ from jax._src.maps import ( AxisName as AxisName, - EMPTY_ENV as EMPTY_ENV, - FrozenDict as FrozenDict, - ResourceEnv as ResourceEnv, ResourceSet as ResourceSet, SerialLoop as SerialLoop, _PositionalSemantics as _PositionalSemantics, make_xmap_callable as make_xmap_callable, serial_loop as serial_loop, - thread_resources as thread_resources, xmap as xmap, xmap_p as xmap_p, _positional_semantics as _positional_semantics, _prepare_axes as _prepare_axes, ) - +from jax._src.mesh import ( + EMPTY_ENV as EMPTY_ENV, + ResourceEnv as ResourceEnv, + thread_resources as thread_resources, +) # Deprecations diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index e5d804f0b..0a467056b 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -115,8 +115,8 @@ def _handle_array_process_allgather(inp, tiled): host_np_arr = np.expand_dims(host_np_arr, axis=0) aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype) - global_aval = global_mesh._local_to_global( - pxla.get_array_mapping(pspec), aval) + global_aval = pxla.mesh_local_to_global( + global_mesh, pxla.get_array_mapping(pspec), aval) bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( @@ -243,12 +243,13 @@ def _flatten_pspecs(name, in_tree, pspecs_thunk): @functools.lru_cache() def _local_to_global_aval(local_aval, mesh, pspec): - return mesh._local_to_global(pxla.get_array_mapping(pspec), local_aval) + return pxla.mesh_local_to_global(mesh, pxla.get_array_mapping(pspec), + local_aval) @functools.lru_cache() def _global_to_local_aval(global_aval, mesh, pspec): - return mesh._global_to_local( - pxla.get_array_mapping(pspec), global_aval) + return pxla.mesh_global_to_local(mesh, pxla.get_array_mapping(pspec), + global_aval) def _device_put(x, device): try: diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 47c45af6c..d2487e48e 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -18,17 +18,14 @@ from jax._src.interpreters.pxla import ( ArrayMappingOrAutoOrUnspecified as 
ArrayMappingOrAutoOrUnspecified, AvalDimSharding as AvalDimSharding, Chunked as Chunked, - ContextDecorator as ContextDecorator, DynamicAxisEnv as DynamicAxisEnv, DynamicAxisEnvFrame as DynamicAxisEnvFrame, - EMPTY_ENV as EMPTY_ENV, EmapInfo as EmapInfo, ExecuteReplicated as ExecuteReplicated, Index as Index, InputsHandler as InputsHandler, MapTrace as MapTrace, MapTracer as MapTracer, - MeshAxisName as MeshAxisName, MeshComputation as MeshComputation, MeshDimAssignment as MeshDimAssignment, MeshExecutable as MeshExecutable, @@ -44,8 +41,6 @@ from jax._src.interpreters.pxla import ( PxlaResultHandler as PxlaResultHandler, ReplicaInfo as ReplicaInfo, Replicated as Replicated, - ResourceAxisName as ResourceAxisName, - ResourceEnv as ResourceEnv, ResultsHandler as ResultsHandler, SPMDBatchTrace as SPMDBatchTrace, ShardInfo as ShardInfo, @@ -109,11 +104,9 @@ from jax._src.interpreters.pxla import ( shard_to_full_p as shard_to_full_p, sharding_internal as sharding_internal, sharding_spec_sharding_proto as sharding_spec_sharding_proto, - show_axes as show_axes, spec_to_indices as spec_to_indices, spmd_primitive_batchers as spmd_primitive_batchers, stage_parallel_callable as stage_parallel_callable, - thread_resources as thread_resources, tile_aval_nd as tile_aval_nd, untile_aval_nd as untile_aval_nd, vtile_by_mesh as vtile_by_mesh, @@ -125,19 +118,23 @@ from jax._src.interpreters.pxla import ( xla_pmap_impl_lazy as xla_pmap_impl_lazy, xla_pmap_p as xla_pmap_p, ) +from jax._src.mesh import ( + MeshAxisName as MeshAxisName, + thread_resources as thread_resources, +) # Deprecations +from jax._src.mesh import Mesh as _deprecated_Mesh from jax._src.interpreters.pxla import ( - Mesh as _deprecated_Mesh, PartitionSpec as _deprecated_PartitionSpec, make_sharded_device_array as _deprecated_make_sharded_device_array, ) import typing if typing.TYPE_CHECKING: + from jax._src.mesh import Mesh as Mesh from jax._src.interpreters.pxla import ( - Mesh as Mesh, PartitionSpec as PartitionSpec, make_sharded_device_array as make_sharded_device_array, ) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 23241fa96..035d784d3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -48,6 +48,7 @@ from jax._src import array from jax._src.sharding import NamedSharding, Sharding, GSPMDSharding import jax._src.pjit as pjit_lib from jax._src.pjit import (pjit, pjit_p, FROM_GDA, AUTO) +from jax._src import mesh from jax._src.interpreters import pxla from jax.interpreters import mlir from jax._src import xla_bridge @@ -265,19 +266,20 @@ class PJitTest(jtu.BufferDonationTestCase): def testDifferentNestedMesh(self): with jtu.create_global_mesh((2, 1), ("x", "y")) as m1: with jtu.create_global_mesh((2, 2), ("a", "b")) as m2: - self.assertEqual(pxla.thread_resources.env.physical_mesh, m2) - self.assertEqual(pxla.thread_resources.env.physical_mesh, m1) - self.assertEqual(pxla.thread_resources.env.physical_mesh, - pxla.EMPTY_ENV.physical_mesh) + self.assertEqual(mesh.thread_resources.env.physical_mesh, m2) + self.assertEqual(mesh.thread_resources.env.physical_mesh, m1) + self.assertEqual(mesh.thread_resources.env.physical_mesh, + mesh.EMPTY_ENV.physical_mesh) def testSameNestedMesh(self): mesh = jtu.create_global_mesh((2, 1), ("a", "b")) + thread_resources = jax._src.mesh.thread_resources with mesh as m1: with mesh as m2: - self.assertEqual(pxla.thread_resources.env.physical_mesh, m2) - self.assertEqual(pxla.thread_resources.env.physical_mesh, m1) - self.assertEqual(pxla.thread_resources.env.physical_mesh, - 
pxla.EMPTY_ENV.physical_mesh) + self.assertEqual(thread_resources.env.physical_mesh, m2) + self.assertEqual(thread_resources.env.physical_mesh, m1) + self.assertEqual(thread_resources.env.physical_mesh, + jax._src.mesh.EMPTY_ENV.physical_mesh) def testMeshDecorator(self): x = jnp.arange(8)
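
Usage note for downstream code (a sketch, not part of the diff above): callers that relied on the removed private methods Mesh._global_to_local / Mesh._local_to_global should switch to the new free functions pxla.mesh_global_to_local / pxla.mesh_local_to_global, and lookups of the ambient mesh should go through jax._src.mesh.thread_resources (still re-exported from jax.interpreters.pxla and jax.experimental.maps). The snippet below assumes a single-process runtime with at least one device; the mesh shape, axis name, aval, and ArrayMapping values are made up for illustration and are not taken from the patch.

    # Minimal migration sketch, assuming a single-process JAX runtime with at
    # least one device. The shapes, axis name, and the ArrayMapping below are
    # illustrative; only the pxla/mesh entry points come from this patch, and
    # they are private internals that may change in later revisions.
    import collections

    import numpy as np

    import jax
    from jax._src import core
    from jax._src import mesh as mesh_lib
    from jax._src.interpreters import pxla

    # A trivial 1-axis mesh over the first local device.
    devices = np.array(jax.devices()[:1])
    m = mesh_lib.Mesh(devices, ('x',))

    # ArrayMapping is an OrderedDict from mesh axis name to the positional
    # array dimension that the axis shards.
    mapping = collections.OrderedDict([('x', 0)])
    global_aval = core.ShapedArray((8, 2), np.float32)

    # Before this change:
    #   local_aval = m._global_to_local(mapping, global_aval)
    #   round_trip = m._local_to_global(mapping, local_aval)
    # After this change the same conversions are free functions in pxla:
    local_aval = pxla.mesh_global_to_local(m, mapping, global_aval)
    round_trip = pxla.mesh_local_to_global(m, mapping, local_aval)
    assert round_trip.shape == global_aval.shape

    # The ambient-mesh state also moved out of pxla into jax._src.mesh.
    with m:
      assert mesh_lib.thread_resources.env.physical_mesh == m

Since ArrayMapping is just an OrderedDict from mesh axis name to the array dimension it shards, the sketch builds it by hand rather than going through get_array_mapping, which keeps the example independent of that helper's expected input type.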