Split Mesh and ResourceEnv into a new module jax._src.mesh.

This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
This commit is contained in:
Peter Hawkins 2023-03-10 10:07:37 -08:00 committed by jax authors
parent 00b90e9073
commit 623282715d
9 changed files with 348 additions and 308 deletions

View File

@ -27,12 +27,12 @@ from jax import lax
from jax._src import core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
@ -310,7 +310,7 @@ class ShardingCallbackInfo:
def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
callback):
mesh = pxla.thread_resources.env.physical_mesh
mesh = mesh_lib.thread_resources.env.physical_mesh
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, mlir.ShardingContext):

View File

@ -31,7 +31,7 @@
from __future__ import annotations
import enum
from contextlib import contextmanager, ContextDecorator
from contextlib import contextmanager
from collections import defaultdict, OrderedDict, namedtuple
import dataclasses
from functools import partial, lru_cache, cached_property
@ -60,6 +60,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh
from jax._src import profiler
from jax._src import sharding as sharding_internal
from jax._src import source_info_util
@ -68,7 +69,6 @@ from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src import config as jax_config
from jax._src.config import flags
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
@ -113,10 +113,11 @@ Replicated = pmap_lib.Replicated
_UNSHARDED_INSTANCE = NoSharding()
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = jax._src.mesh.Mesh
MeshAxisName = mesh.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = pmap_lib.ShardingSpec
MeshAxisName = Any
OpShardingType = Any
PartitionSpec = sharding_internal.PartitionSpec
@ -191,7 +192,7 @@ def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType
# specially over some mesh axes.
if replicated_maxes:
last_tile_dims = []
axes_by_type: Dict[OpShardingType, List[MeshAxisName]] = {}
axes_by_type: Dict[OpShardingType, List[jax._src.mesh.MeshAxisName]] = {}
size_by_type: Dict[OpShardingType, int] = defaultdict(lambda: 1)
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
for axis, size in replicated_maxes:
@ -541,7 +542,7 @@ that would mean that a flat list of chunks would get assigned to a flattened lis
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
ArrayMapping = OrderedDictType[mesh.MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
UnspecifiedValue]
@ -691,7 +692,7 @@ def make_sharded_device_array(
sharding_spec = _create_pmap_sharding_spec(aval)
if jax.config.jax_array:
mesh = thread_resources.env.physical_mesh
mesh = jax._src.mesh.thread_resources.env.physical_mesh
if mesh.empty:
sharding = sharding_internal.PmapSharding(
np.asarray([d.device() for d in device_buffers]), sharding_spec)
@ -1993,8 +1994,9 @@ def global_avals_to_results_handler(
else:
# This path is taken when the outputs are SDAs.
assert all(isinstance(s, sharding_internal.NamedSharding) for s in shardings)
local_out_avals = [s.mesh._global_to_local(get_array_mapping(s.spec), aval)
for aval, s in safe_zip(global_out_avals, shardings)]
local_out_avals = [
mesh_global_to_local(s.mesh, get_array_mapping(s.spec), aval) # type: ignore
for aval, s in safe_zip(global_out_avals, shardings)]
local_shardings = [sharding_internal.NamedSharding(s.mesh.local_mesh, s.spec) # type: ignore
for s in shardings]
return local_avals_to_results_handler(local_out_avals, local_shardings)
@ -2388,250 +2390,6 @@ mlir.register_lowering(xla_pmap_p, _pmap_lowering)
# ------------------- xmap -------------------
class Mesh(ContextDecorator):
"""Declare the hardware resources available in the scope of this manager.
In particular, all ``axis_names`` become valid resource names inside the
managed block and can be used e.g. in the ``in_axis_resources`` argument of
:py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming
model (https://jax.readthedocs.io/en/latest/multi_process.html)
and the Distributed arrays and automatic parallelization tutorial
(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
If you are compiling in multiple threads, make sure that the
``with Mesh`` context manager is inside the function that the threads will
execute.
Args:
devices: A NumPy ndarray object containing JAX device objects (as
obtained e.g. from :py:func:`jax.devices`).
axis_names: A sequence of resource axis names to be assigned to the
dimensions of the ``devices`` argument. Its length should match the
rank of ``devices``.
Example:
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
"""
devices: np.ndarray
axis_names: Tuple[MeshAxisName, ...]
def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]],
axis_names: Union[str, Sequence[MeshAxisName]]):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
axis_names = (axis_names,)
assert devices.ndim == len(axis_names)
# TODO: Make sure that devices are unique? At least with the quick and
# dirty check that the array size is not larger than the number of
# available devices?
self.devices = devices.copy()
self.devices.flags.writeable = False
self.axis_names = tuple(axis_names)
def __eq__(self, other):
if not isinstance(other, Mesh):
return False
# This is a performance optimization. Comparing thousands of devices
# can be expensive.
if id(self) == id(other):
return True
return (self.axis_names == other.axis_names and
np.array_equal(self.devices, other.devices))
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash(
(self.axis_names, tuple(self.devices.flat), self.devices.shape))
return self._hash
def __setattr__(self, name, value):
if hasattr(self, name):
raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
super().__setattr__(name, value)
def __enter__(self):
new_env = thread_resources.stack[-1].with_mesh(self)
thread_resources.stack.append(new_env)
thread_resources.env = new_env
jax_config.update_thread_local_jit_state(
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
if not t.physical_mesh.empty))
return self
def __exit__(self, exc_type, exc_value, traceback):
thread_resources.stack.pop()
thread_resources.env = thread_resources.stack[-1]
jax_config.update_thread_local_jit_state(
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
if not t.physical_mesh.empty))
return False
@property
def shape(self):
return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape))
@property
def size(self):
return np.prod(list(self.shape.values()))
@property
def empty(self):
return self.devices.ndim == 0
@property
def is_multi_process(self):
return self.devices.size != len(self.local_devices)
@cached_property
def local_mesh(self):
return self._local_mesh(xb.process_index())
def _local_mesh(self, process_index):
if self.empty:
return self
is_local_device = np.vectorize(
lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
subcube_indices = []
# We take the smallest slice of each dimension that doesn't skip any local device.
for axis in range(self.devices.ndim):
other_axes = tuple_delete(tuple(range(self.devices.ndim)), axis)
# NOTE: This re-reduces over many axes multiple times, so we could definitely
# optimize it, but I hope it won't be a bottleneck anytime soon.
local_slices = is_local_device.any(other_axes, keepdims=False)
nonzero_indices = np.flatnonzero(local_slices)
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
subcube_indices.append(slice(start, end + 1))
subcube_indices = tuple(subcube_indices)
# We only end up with all conditions being true if the local devices formed a
# subcube of the full array. This is because we were biased towards taking a
# "hull" spanned by the devices, and in case the local devices don't form a
# subcube that hull will contain non-local devices.
if not is_local_device[subcube_indices].all():
raise ValueError(
"When passing host local inputs to pjit or xmap, devices "
"connected to a single host must form a contiguous subcube of the "
"global device mesh")
return Mesh(self.devices[subcube_indices], self.axis_names)
@property
def device_ids(self):
assert not self.empty
return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)
def __repr__(self):
if self.empty:
return "Mesh(device_ids=[], axis_names=())"
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})"
@cached_property
def local_devices(self):
return [d for d in self.devices.flat
if d.process_index == d.client.process_index()]
def _local_to_global(self, axes: ArrayMapping, aval):
return untile_aval_nd(self.shape, axes,
tile_aval_nd(self.local_mesh.shape, axes, aval))
def _global_to_local(self, axes: ArrayMapping, aval):
return untile_aval_nd(self.local_mesh.shape, axes,
tile_aval_nd(self.shape, axes, aval))
ResourceAxisName = core.AxisName
class Loop(NamedTuple):
name: ResourceAxisName
length: int
def show_axes(axes):
return ", ".join(sorted(f"`{a}`" for a in axes))
class ResourceEnv(NamedTuple):
physical_mesh: Mesh
loops: Tuple[Loop, ...]
def with_mesh(self, mesh: Mesh):
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
if overlap:
raise ValueError(f"Cannot update the mesh of the current resource "
f"environment. The new mesh shadows already defined axes "
f"{show_axes(overlap)}")
return self._replace(physical_mesh=mesh)
def with_extra_loop(self, loop: Loop):
if loop.name in self.resource_axes:
raise ValueError(f"Cannot extend the resource environment with loop named "
f"`{loop.name}`. An axis of this name is already defined!")
return self._replace(loops=self.loops + (loop,))
@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
return set(self.physical_mesh.axis_names)
@property
def loop_resource_axes(self) -> Set[ResourceAxisName]:
return {loop.name for loop in self.loops}
@property
def resource_axes(self) -> Set[ResourceAxisName]:
return self.physical_resource_axes | self.loop_resource_axes
@property
def shape(self):
shape = self.physical_mesh.shape
shape.update(self.loops)
return shape
@property
def local_shape(self):
shape = self.physical_mesh.local_mesh.shape
shape.update(self.loops)
return shape
def __repr__(self):
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())
class _ThreadResourcesLocalState(threading.local):
def __init__(self):
self.stack = [EMPTY_ENV]
self.env = self.stack[-1]
thread_resources = _ThreadResourcesLocalState()
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
assert isinstance(aval, ShapedArray)
shape = list(aval.shape)
@ -2653,6 +2411,15 @@ def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
return aval.update(shape=tuple(shape), named_shape=named_shape)
def mesh_local_to_global(mesh, axes: ArrayMapping, aval):
return untile_aval_nd(mesh.shape, axes,
tile_aval_nd(mesh.local_mesh.shape, axes, aval))
def mesh_global_to_local(mesh, axes: ArrayMapping, aval):
return untile_aval_nd(mesh.local_mesh.shape, axes,
tile_aval_nd(mesh.shape, axes, aval))
class SPMDBatchTrace(batching.BatchTrace):
def get_axis_primitive_batcher(self, primitive, frame):
if primitive in spmd_primitive_batchers:
@ -2688,7 +2455,7 @@ def _full_to_shard_abstract_eval(x, axes, mesh, **_):
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
return tile_aval_nd(mesh.shape, axes, x)
def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName], mesh: Mesh):
def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxisName], mesh: Mesh):
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
and all others as replicated.
"""
@ -2714,7 +2481,8 @@ def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName
return proto
@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[mesh.MeshAxisName]):
# TODO: Can we short-circuit for replicated values? Probably not.
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
@ -2733,7 +2501,8 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_):
return untile_aval_nd(mesh.shape, axes, x)
@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[mesh.MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
proto = manual_proto(aval_in, manual_axes, mesh)
@ -2744,7 +2513,7 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_ax
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
@lu.transformation
def vtile_manual(manual_axes: FrozenSet[MeshAxisName],
def vtile_manual(manual_axes: FrozenSet[mesh.MeshAxisName],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
@ -2763,7 +2532,7 @@ class TileVectorize:
@dataclasses.dataclass(frozen=True)
class TileManual:
manual_axes: FrozenSet[MeshAxisName]
manual_axes: FrozenSet[mesh.MeshAxisName]
TilingMethod = Union[TileVectorize, TileManual]
@ -3415,7 +3184,7 @@ def _get_normalized_avals_and_shardings(
in_sharding = i
else:
assert isinstance(i, sharding_internal.NamedSharding)
aval = i.mesh._global_to_local(
aval = mesh_global_to_local(i.mesh,
cast(ArrayMapping, get_array_mapping(i.spec)), gaval) # pylint: disable=g-bare-generic
in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
avals.append(aval)

View File

@ -22,6 +22,7 @@ from functools import wraps, partial, partialmethod, lru_cache
from jax import numpy as jnp
from jax._src import core
from jax._src import mesh
from jax._src import linear_util as lu
from jax import stages
from jax._src import dispatch
@ -85,13 +86,12 @@ class FrozenDict(abc.Mapping):
# Multi-dimensional generalized map
AxisName = core.AxisName
ResourceAxisName = AxisName # Different name just for documentation purposes
# TODO(https://github.com/google/jax/issues/13487): Remove Mesh in
# 3 months from `jax.experimental.maps.Mesh`.
Mesh = pxla.Mesh
ResourceEnv = pxla.ResourceEnv
EMPTY_ENV = pxla.EMPTY_ENV
thread_resources = pxla.thread_resources
ResourceAxisName = mesh.ResourceAxisName # Different name just for documentation purposes
Mesh = mesh.Mesh
MeshAxisName = mesh.MeshAxisName
ResourceEnv = mesh.ResourceEnv
EMPTY_ENV = mesh.EMPTY_ENV
thread_resources = mesh.thread_resources
class SerialLoop:
@ -161,7 +161,7 @@ def serial_loop(name: ResourceAxisName, length: int):
axis_resources={'i': 'l'})(x)
"""
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = old_env.with_extra_loop(pxla.Loop(name, length))
thread_resources.env = old_env.with_extra_loop(mesh.Loop(name, length))
try:
yield
finally:
@ -686,7 +686,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
global_in_avals = [
av if ips == _PositionalSemantics.GLOBAL else mesh._local_to_global(ax, av)
av if ips == _PositionalSemantics.GLOBAL else pxla.mesh_local_to_global(mesh, ax, av)
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
]
in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia
@ -964,7 +964,7 @@ def _resource_typing_xmap(avals,
raise JAXTypeError(
f"Detected disallowed xmap axis name shadowing at "
f"{source_info_util.summarize(source_info)} "
f"(shadowed axes: {pxla.show_axes(overlap)})")
f"(shadowed axes: {mesh.show_axes(overlap)})")
if resource_env.physical_mesh != params['resource_env'].physical_mesh:
raise RuntimeError("Changing the physical mesh is not allowed inside xmap.")
@ -992,9 +992,9 @@ def _resource_typing_xmap(avals,
raise JAXTypeError(
f"One of xmapped function ({params['name']}) outputs is broadcast "
f"along axis `{baxis}` which is assigned to resources "
f"{pxla.show_axes(baxis_resources)}, but the output is already "
f"partitioned along {pxla.show_axes(overlap)}, because its "
f"named shape contains {pxla.show_axes(partitioning_axes)}")
f"{mesh.show_axes(baxis_resources)}, but the output is already "
f"partitioned along {mesh.show_axes(overlap)}, because its "
f"named shape contains {mesh.show_axes(partitioning_axes)}")
pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
@ -1419,8 +1419,9 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
f = pxla.vtile_by_mesh(f, mesh, mesh_in_axes, mesh_out_axes)
# XXX: We modify mesh_in_axes and mesh_out_axes here
def add_spmd_axes(flat_mesh_axes: Sequence[pxla.ArrayMapping],
flat_extra_axes: Optional[Sequence[Sequence[Sequence[pxla.MeshAxisName]]]]):
def add_spmd_axes(
flat_mesh_axes: Sequence[pxla.ArrayMapping],
flat_extra_axes: Optional[Sequence[Sequence[Sequence[MeshAxisName]]]]):
if flat_extra_axes is None:
return
for axes, extra in zip(flat_mesh_axes, flat_extra_axes):

267
jax/_src/mesh.py Normal file
View File

@ -0,0 +1,267 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions of Mesh and ResourceEnv."""
from __future__ import annotations
import collections
import contextlib
import functools
import threading
from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union
import numpy as np
from jax._src import config as jax_config
from jax._src import xla_bridge as xb
from jax._src import util
from jax._src.lib import xla_client as xc
MeshAxisName = Any
ResourceAxisName = Hashable
class Loop(NamedTuple):
name: ResourceAxisName
length: int
def show_axes(axes):
return ", ".join(sorted(f"`{a}`" for a in axes))
class ResourceEnv(NamedTuple):
physical_mesh: Mesh
loops: Tuple[Loop, ...]
def with_mesh(self, mesh: Mesh):
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
if overlap:
raise ValueError(f"Cannot update the mesh of the current resource "
f"environment. The new mesh shadows already defined axes "
f"{show_axes(overlap)}")
return self._replace(physical_mesh=mesh)
def with_extra_loop(self, loop: Loop):
if loop.name in self.resource_axes:
raise ValueError(f"Cannot extend the resource environment with loop named "
f"`{loop.name}`. An axis of this name is already defined!")
return self._replace(loops=self.loops + (loop,))
@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
return set(self.physical_mesh.axis_names)
@property
def loop_resource_axes(self) -> Set[ResourceAxisName]:
return {loop.name for loop in self.loops}
@property
def resource_axes(self) -> Set[ResourceAxisName]:
return self.physical_resource_axes | self.loop_resource_axes
@property
def shape(self):
shape = self.physical_mesh.shape
shape.update(self.loops)
return shape
@property
def local_shape(self):
shape = self.physical_mesh.local_mesh.shape
shape.update(self.loops)
return shape
def __repr__(self):
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
class Mesh(contextlib.ContextDecorator):
"""Declare the hardware resources available in the scope of this manager.
In particular, all ``axis_names`` become valid resource names inside the
managed block and can be used e.g. in the ``in_axis_resources`` argument of
:py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming
model (https://jax.readthedocs.io/en/latest/multi_process.html)
and the Distributed arrays and automatic parallelization tutorial
(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
If you are compiling in multiple threads, make sure that the
``with Mesh`` context manager is inside the function that the threads will
execute.
Args:
devices: A NumPy ndarray object containing JAX device objects (as
obtained e.g. from :py:func:`jax.devices`).
axis_names: A sequence of resource axis names to be assigned to the
dimensions of the ``devices`` argument. Its length should match the
rank of ``devices``.
Example:
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
"""
devices: np.ndarray
axis_names: Tuple[MeshAxisName, ...]
def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]],
axis_names: Union[str, Sequence[MeshAxisName]]):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
axis_names = (axis_names,)
assert devices.ndim == len(axis_names)
# TODO: Make sure that devices are unique? At least with the quick and
# dirty check that the array size is not larger than the number of
# available devices?
self.devices = devices.copy()
self.devices.flags.writeable = False
self.axis_names = tuple(axis_names)
def __eq__(self, other):
if not isinstance(other, Mesh):
return False
# This is a performance optimization. Comparing thousands of devices
# can be expensive.
if id(self) == id(other):
return True
return (self.axis_names == other.axis_names and
np.array_equal(self.devices, other.devices))
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash(
(self.axis_names, tuple(self.devices.flat), self.devices.shape))
return self._hash
def __setattr__(self, name, value):
if hasattr(self, name):
raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
super().__setattr__(name, value)
def __enter__(self):
new_env = thread_resources.stack[-1].with_mesh(self)
thread_resources.stack.append(new_env)
thread_resources.env = new_env
jax_config.update_thread_local_jit_state(
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
if not t.physical_mesh.empty))
return self
def __exit__(self, exc_type, exc_value, traceback):
thread_resources.stack.pop()
thread_resources.env = thread_resources.stack[-1]
jax_config.update_thread_local_jit_state(
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
if not t.physical_mesh.empty))
return False
@property
def shape(self):
return collections.OrderedDict(
(name, size)
for name, size in util.safe_zip(self.axis_names, self.devices.shape))
@property
def size(self):
return np.prod(list(self.shape.values()))
@property
def empty(self):
return self.devices.ndim == 0
@property
def is_multi_process(self):
return self.devices.size != len(self.local_devices)
@functools.cached_property
def local_mesh(self):
return self._local_mesh(xb.process_index())
def _local_mesh(self, process_index):
if self.empty:
return self
is_local_device = np.vectorize(
lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
subcube_indices = []
# We take the smallest slice of each dimension that doesn't skip any local device.
for axis in range(self.devices.ndim):
other_axes = util.tuple_delete(tuple(range(self.devices.ndim)), axis)
# NOTE: This re-reduces over many axes multiple times, so we could definitely
# optimize it, but I hope it won't be a bottleneck anytime soon.
local_slices = is_local_device.any(other_axes, keepdims=False)
nonzero_indices = np.flatnonzero(local_slices)
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
subcube_indices.append(slice(start, end + 1))
subcube_indices = tuple(subcube_indices)
# We only end up with all conditions being true if the local devices formed a
# subcube of the full array. This is because we were biased towards taking a
# "hull" spanned by the devices, and in case the local devices don't form a
# subcube that hull will contain non-local devices.
if not is_local_device[subcube_indices].all():
raise ValueError(
"When passing host local inputs to pjit or xmap, devices "
"connected to a single host must form a contiguous subcube of the "
"global device mesh")
return Mesh(self.devices[subcube_indices], self.axis_names)
@property
def device_ids(self):
assert not self.empty
return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)
def __repr__(self):
if self.empty:
return "Mesh(device_ids=[], axis_names=())"
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})"
@functools.cached_property
def local_devices(self):
return [d for d in self.devices.flat
if d.process_index == d.client.process_index()]
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())
class _ThreadResourcesLocalState(threading.local):
def __init__(self):
self.stack = [EMPTY_ENV]
self.env = self.stack[-1]
thread_resources = _ThreadResourcesLocalState()

View File

@ -41,6 +41,7 @@ from jax._src.sharding import (
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
from jax._src import array
from jax._src import dispatch
from jax._src import mesh
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
@ -714,7 +715,7 @@ def pjit(
def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
resource_env = pxla.thread_resources.env
resource_env = mesh.thread_resources.env
pjit_info_args = PjitInfo(
fun=fun, in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
@ -1173,7 +1174,7 @@ def _check_unique_resources(axis_resources, arg_name):
if multiple_uses:
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {pxla.show_axes(multiple_uses)}")
f"has duplicate entries for {mesh.show_axes(multiple_uses)}")
# -------------------- pjit rules --------------------
@ -1916,7 +1917,7 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax
f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} "
f"that uses one or more mesh axes already used by xmap to partition "
f"a named axis appearing in its named_shape (both use mesh axes "
f"{pxla.show_axes(overlap)})")
f"{mesh.show_axes(overlap)})")
def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):
jaxpr = params["jaxpr"]
@ -2024,7 +2025,7 @@ def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
del user_shardings
resource_env = pxla.thread_resources.env
resource_env = jax._src.mesh.thread_resources.env
mesh = resource_env.physical_mesh
if config.jax_array:
@ -2171,7 +2172,8 @@ def global_to_local(positional_semantics, avals, shardings, mesh):
# replicated avals don't go through this code path. To convert global
# avals to host local avals, round trip it via NamedSharding.
parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]
out.append(mesh._global_to_local(get_array_mapping(parsed_pspec), aval))
out.append(pxla.mesh_global_to_local(
mesh, get_array_mapping(parsed_pspec), aval))
return out
@ -2188,7 +2190,8 @@ def local_to_global(positional_semantics, avals, shardings, mesh):
# replicated avals don't go through this code path. To convert host local
# avals to global avals, round trip it via NamedSharding.
parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]
out.append(mesh._local_to_global(get_array_mapping(parsed_pspec), aval))
out.append(pxla.mesh_local_to_global(
mesh, get_array_mapping(parsed_pspec), aval))
return out

View File

@ -14,21 +14,21 @@
from jax._src.maps import (
AxisName as AxisName,
EMPTY_ENV as EMPTY_ENV,
FrozenDict as FrozenDict,
ResourceEnv as ResourceEnv,
ResourceSet as ResourceSet,
SerialLoop as SerialLoop,
_PositionalSemantics as _PositionalSemantics,
make_xmap_callable as make_xmap_callable,
serial_loop as serial_loop,
thread_resources as thread_resources,
xmap as xmap,
xmap_p as xmap_p,
_positional_semantics as _positional_semantics,
_prepare_axes as _prepare_axes,
)
from jax._src.mesh import (
EMPTY_ENV as EMPTY_ENV,
ResourceEnv as ResourceEnv,
thread_resources as thread_resources,
)
# Deprecations

View File

@ -115,8 +115,8 @@ def _handle_array_process_allgather(inp, tiled):
host_np_arr = np.expand_dims(host_np_arr, axis=0)
aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
global_aval = global_mesh._local_to_global(
pxla.get_array_mapping(pspec), aval)
global_aval = pxla.mesh_local_to_global(
global_mesh, pxla.get_array_mapping(pspec), aval)
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
@ -243,12 +243,13 @@ def _flatten_pspecs(name, in_tree, pspecs_thunk):
@functools.lru_cache()
def _local_to_global_aval(local_aval, mesh, pspec):
return mesh._local_to_global(pxla.get_array_mapping(pspec), local_aval)
return pxla.mesh_local_to_global(mesh, pxla.get_array_mapping(pspec),
local_aval)
@functools.lru_cache()
def _global_to_local_aval(global_aval, mesh, pspec):
return mesh._global_to_local(
pxla.get_array_mapping(pspec), global_aval)
return pxla.mesh_global_to_local(mesh, pxla.get_array_mapping(pspec),
global_aval)
def _device_put(x, device):
try:

View File

@ -18,17 +18,14 @@ from jax._src.interpreters.pxla import (
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
AvalDimSharding as AvalDimSharding,
Chunked as Chunked,
ContextDecorator as ContextDecorator,
DynamicAxisEnv as DynamicAxisEnv,
DynamicAxisEnvFrame as DynamicAxisEnvFrame,
EMPTY_ENV as EMPTY_ENV,
EmapInfo as EmapInfo,
ExecuteReplicated as ExecuteReplicated,
Index as Index,
InputsHandler as InputsHandler,
MapTrace as MapTrace,
MapTracer as MapTracer,
MeshAxisName as MeshAxisName,
MeshComputation as MeshComputation,
MeshDimAssignment as MeshDimAssignment,
MeshExecutable as MeshExecutable,
@ -44,8 +41,6 @@ from jax._src.interpreters.pxla import (
PxlaResultHandler as PxlaResultHandler,
ReplicaInfo as ReplicaInfo,
Replicated as Replicated,
ResourceAxisName as ResourceAxisName,
ResourceEnv as ResourceEnv,
ResultsHandler as ResultsHandler,
SPMDBatchTrace as SPMDBatchTrace,
ShardInfo as ShardInfo,
@ -109,11 +104,9 @@ from jax._src.interpreters.pxla import (
shard_to_full_p as shard_to_full_p,
sharding_internal as sharding_internal,
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
show_axes as show_axes,
spec_to_indices as spec_to_indices,
spmd_primitive_batchers as spmd_primitive_batchers,
stage_parallel_callable as stage_parallel_callable,
thread_resources as thread_resources,
tile_aval_nd as tile_aval_nd,
untile_aval_nd as untile_aval_nd,
vtile_by_mesh as vtile_by_mesh,
@ -125,19 +118,23 @@ from jax._src.interpreters.pxla import (
xla_pmap_impl_lazy as xla_pmap_impl_lazy,
xla_pmap_p as xla_pmap_p,
)
from jax._src.mesh import (
MeshAxisName as MeshAxisName,
thread_resources as thread_resources,
)
# Deprecations
from jax._src.mesh import Mesh as _deprecated_Mesh
from jax._src.interpreters.pxla import (
Mesh as _deprecated_Mesh,
PartitionSpec as _deprecated_PartitionSpec,
make_sharded_device_array as _deprecated_make_sharded_device_array,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.mesh import Mesh as Mesh
from jax._src.interpreters.pxla import (
Mesh as Mesh,
PartitionSpec as PartitionSpec,
make_sharded_device_array as make_sharded_device_array,
)

View File

@ -48,6 +48,7 @@ from jax._src import array
from jax._src.sharding import NamedSharding, Sharding, GSPMDSharding
import jax._src.pjit as pjit_lib
from jax._src.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
from jax._src import mesh
from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src import xla_bridge
@ -265,19 +266,20 @@ class PJitTest(jtu.BufferDonationTestCase):
def testDifferentNestedMesh(self):
with jtu.create_global_mesh((2, 1), ("x", "y")) as m1:
with jtu.create_global_mesh((2, 2), ("a", "b")) as m2:
self.assertEqual(pxla.thread_resources.env.physical_mesh, m2)
self.assertEqual(pxla.thread_resources.env.physical_mesh, m1)
self.assertEqual(pxla.thread_resources.env.physical_mesh,
pxla.EMPTY_ENV.physical_mesh)
self.assertEqual(mesh.thread_resources.env.physical_mesh, m2)
self.assertEqual(mesh.thread_resources.env.physical_mesh, m1)
self.assertEqual(mesh.thread_resources.env.physical_mesh,
mesh.EMPTY_ENV.physical_mesh)
def testSameNestedMesh(self):
mesh = jtu.create_global_mesh((2, 1), ("a", "b"))
thread_resources = jax._src.mesh.thread_resources
with mesh as m1:
with mesh as m2:
self.assertEqual(pxla.thread_resources.env.physical_mesh, m2)
self.assertEqual(pxla.thread_resources.env.physical_mesh, m1)
self.assertEqual(pxla.thread_resources.env.physical_mesh,
pxla.EMPTY_ENV.physical_mesh)
self.assertEqual(thread_resources.env.physical_mesh, m2)
self.assertEqual(thread_resources.env.physical_mesh, m1)
self.assertEqual(thread_resources.env.physical_mesh,
jax._src.mesh.EMPTY_ENV.physical_mesh)
def testMeshDecorator(self):
x = jnp.arange(8)