# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import dataclasses
import enum
import functools
import itertools
import math
import operator as op
import sys
from typing import (Any, Mapping, Optional, OrderedDict, NamedTuple, Sequence,
                    Union, cast)

from jax._src import mesh as mesh_lib
from jax._src.op_shardings import (
    is_op_sharding_replicated, are_op_shardings_equal, get_num_ways_dim_sharded,
    op_sharding_to_indices)
from jax._src import sharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.partition_spec import PartitionSpec

import numpy as np

if sys.version_info >= (3, 9):
  OrderedDictType = OrderedDict
else:
  OrderedDictType = dict


Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = tuple[Device, ...]


# Shardings that inherit from XLACompatibleSharding should implement the
# `_device_assignment` property and `_to_xla_hlo_sharding` method.
@use_cpp_class(xc.XLACompatibleSharding)
class XLACompatibleSharding(sharding.Sharding):
  """A `Sharding` that describes shardings expressible to XLA.

  Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work
  with all JAX APIs and transformations that use XLA.
  """

  # Abstract methods below that subclasses should implement.

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    raise NotImplementedError('Subclasses should implement this method.')

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    hlo_sharding = self._to_xla_hlo_sharding(len(global_shape))
    gspmd_sharding = GSPMDSharding(self._device_assignment, hlo_sharding)
    return gspmd_sharding.devices_indices_map(global_shape)

  @functools.cached_property
  def _addressable_device_assignment(self) -> XLADeviceAssignment:
    return tuple(d for d in self._device_assignment
                 if d.process_index == d.client.process_index())

  @functools.lru_cache(maxsize=4096)
  def shard_shape(self, global_shape: Shape) -> Shape:
    hlo_sharding = self._to_xla_hlo_sharding(len(global_shape))
    if is_op_sharding_replicated(hlo_sharding):
      return global_shape
    partitions, _ = get_num_ways_dim_sharded(hlo_sharding)
    assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
    out = []
    for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
      try:
        quotient, remainder = divmod(s, p)
      except TypeError:
        # TODO Figure out how to partition dynamic shapes
        raise NotImplementedError
      if remainder != 0:
        raise ValueError(
            f"Sharding {self} implies that array axis {dim} is partitioned "
            f"{p} times, but the dimension size is {s} "
            f"(full shape: {global_shape}, "
            f"per-dimension tiling factors: {partitions} should evenly divide "
            "the shape)")
      out.append(quotient)
    return tuple(out)
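
  # Illustrative note (editor's addition, not in the original source): for a
  # NamedSharding over a 2x4 ('x', 'y') mesh with spec PartitionSpec('x', 'y'),
  # shard_shape((8, 8)) returns (4, 2), since axis 0 is split 2 ways and axis 1
  # is split 4 ways; a shape such as (8, 9) does not divide evenly and raises
  # the ValueError above.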

  def is_equivalent_to(self: XLACompatibleSharding,  # type: ignore
                       other: XLACompatibleSharding, ndim: int) -> bool:
    try:
      if xla_extension_version >= 168:
        return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
                                       other._to_xla_hlo_sharding(ndim))
                and self._device_assignment == other._device_assignment and
                self.memory_kind == other.memory_kind)
      else:
        return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
                                       other._to_xla_hlo_sharding(ndim))
                and self._device_assignment == other._device_assignment)
    # NotImplementedError is raised by PmapSharding because it can't lower
    # to OpSharding. So if `other` is a PmapSharding, default to a strict
    # equality check.
    except NotImplementedError:
      return self == other


@functools.lru_cache()
def _check_mesh_resource_axis(mesh, parsed_pspec):
  try:
    [mesh.shape[r] for p in parsed_pspec if p is not None
     for r in p]
  except KeyError as e:
    raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is "
                     "undefined.") from None


def hashed_index(x) -> int:
  # This works for both `pjit`/`xmap` indices and `pmap` indices (which might
  # have an integer instead of a slice).
  assert all(v.step is None for v in x if isinstance(v, slice))
  return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))


@functools.lru_cache(maxsize=4096)
def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
  try:
    device_indices_map_fn = sharding.devices_indices_map
  except AttributeError:
    raise ValueError(
        f'Cannot calculate replica ids from sharding: {sharding}. Please '
        'create a device to index mapping for your sharding from which replica '
        'ids will be calculated.') from None

  index_to_replica: dict[int, int] = collections.Counter()
  out = {}
  for device, index in device_indices_map_fn(global_shape).items():
    h_index = hashed_index(index)
    replica_id = index_to_replica[h_index]
    index_to_replica[h_index] += 1
    out[device] = replica_id
  return out
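
# Illustrative sketch (editor's addition, assumes `jax` is imported and at
# least two devices are available): with a fully replicated sharding every
# device holds the same full-size shard, so replica ids are assigned 0, 1, ...
# in device-assignment order:
#
#   s = GSPMDSharding.get_replicated(jax.devices()[:2])
#   device_replica_id_map(s, (4, 4))   # -> {device_0: 0, device_1: 1}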


@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
  r"""NamedSharding is a way to express ``Sharding``\s using named axes.

  ``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.

  ``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid,
  where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
  ``Mesh`` is "logical mesh".

  ``PartitionSpec`` is a tuple whose elements can be ``None``, a mesh axis, or
  a tuple of mesh axes. Each element describes how an input dimension is
  partitioned across zero or more mesh dimensions. For example,
  PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data
  is sharded across the ``x`` axis of the mesh, and the second dimension is
  sharded across the ``y`` axis of the mesh.

  The Distributed arrays and automatic parallelization tutorial
  (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
  goes into more detail and has diagrams that help explain the concepts of
  ``Mesh`` and ``PartitionSpec``.

  Args:
    mesh: A ``jax.sharding.Mesh`` object.
    spec: A ``jax.sharding.PartitionSpec`` object.

  Example:

    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> spec = P('x', 'y')
    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
  """

  mesh: mesh_lib.Mesh
  spec: PartitionSpec
  memory_kind: Optional[str]
  _parsed_pspec: ParsedPartitionSpec

  @use_cpp_method()
  def __init__(
      self, mesh: mesh_lib.Mesh, spec: PartitionSpec, *,
      memory_kind: Optional[str] = None, _parsed_pspec = None):

    self.mesh = mesh
    self.spec = spec
    self.memory_kind = memory_kind
    self._parsed_pspec = _parsed_pspec
    self._preprocess()

  def __reduce__(self):
    if xla_extension_version >= 168:
      return (
          type(self),
          (self.mesh, self.spec),
          {'memory_kind': self.memory_kind},
      )
    else:
      return type(self), (self.mesh, self.spec)

  def _preprocess(self):
    if xla_extension_version >= 170 and self.memory_kind is not None:
      # Will error if memory_kind does not exist on the device.
      self.mesh.devices.flat[0].memory(self.memory_kind)

    # This split exists because you can pass a `_parsed_pspec` that has been
    # modified from the original. For example: adding an extra dimension to
    # axis_resources for vmap handlers. In such cases you need to preserve the
    # `sync` attribute of parsed pspecs.
    # PartitionSpec is inferred from the parsed pspec in this case.
    # TODO(yashkatariya): Remove this and replace this with a normalized
    # representation of Parsed Pspec
    if self._parsed_pspec is None:
      self._parsed_pspec, _, _ = prepare_axis_resources(
          PartitionSpec() if self.spec is None else self.spec,
          "NamedSharding spec", allow_unconstrained_dims=True)

    _check_mesh_resource_axis(self.mesh, self._parsed_pspec)

  def __repr__(self):
    if xla_extension_version >= 168:
      mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
      return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec}{mem})'
    else:
      return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec})'

  def __hash__(self):
    if not hasattr(self, '_hash'):
      if xla_extension_version >= 168:
        self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
      else:
        self._hash = hash((self.mesh, self._parsed_pspec))
    return self._hash

  def __eq__(self, other):
    if not isinstance(other, NamedSharding):
      return False
    if id(self) == id(other):
      return True
    parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
    if xla_extension_version >= 168:
      if (id(self.mesh) == id(other.mesh) and
          self.memory_kind == other.memory_kind and parsed_pspec_equal):
        return True
      return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
              and parsed_pspec_equal)
    else:
      if id(self.mesh) == id(other.mesh) and parsed_pspec_equal:
        return True
      return self.mesh == other.mesh and parsed_pspec_equal

  def is_compatible_aval(self, aval_shape: Shape):
    assert self._parsed_pspec is not None
    if len(aval_shape) < len(self._parsed_pspec):
      extra_msg = (' For scalars the PartitionSpec should be P()'
                   if len(aval_shape) == 0 else '')
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(self._parsed_pspec)}, but was applied to a value of rank "
          f"{len(aval_shape)}.{extra_msg}")

  @classmethod
  def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
    if xla_extension_version >= 168:
      return cls(mesh, parsed_pspec.get_partition_spec(),
                 memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
    else:
      return cls(mesh, parsed_pspec.get_partition_spec(),
                 _parsed_pspec=parsed_pspec)

  @property
  def device_set(self) -> set[Device]:
    return self.mesh._flat_devices_set

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return self.mesh._flat_devices_tuple

  @property
  def is_fully_addressable(self) -> bool:
    # Speed up `is_fully_addressable` since there is a high chance that the
    # mesh across multiple NamedSharding objects will be the same.
    return not self.mesh.is_multi_process

  @property
  def addressable_devices(self) -> set[Device]:
    # Override addressable devices because there is a high chance that the mesh
    # across multiple NamedSharding objects will be the same.
    return self.mesh._local_devices_set

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    if self.mesh.size == 1:
      return True
    array_mapping = cast(ParsedPartitionSpec, get_array_mapping(self._parsed_pspec))
    mesh_shape = self.mesh.shape
    num_partitions = 1
    for name in array_mapping:
      num_partitions *= mesh_shape[name]
    return num_partitions == 1

  def _get_sharding_spec(self, num_dimensions, axis_ctx):
    assert self._parsed_pspec is not None
    array_mapping = get_array_mapping(self._parsed_pspec)
    # TODO(yashkatariya): Move away from sharding spec in NamedSharding
    # since we don't really need sharding spec.
    sharding_spec = sharding_specs.new_mesh_sharding_specs(
        self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
    # Used in `with_sharding_constraint`.
    special_axes = {}
    # Manual axes are only used with xmap.
    if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext):
      axis_names = self.mesh.axis_names
      # Ignore type because mypy doesn't recognize the `isinstance` check above.
      for manual_axis in axis_ctx.manual_axes:  # type: ignore
        special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
    return sharding_spec, special_axes

  @functools.lru_cache(maxsize=4096)
  def _to_xla_hlo_sharding(
      self, num_dimensions: int,
      axis_ctx: Optional[Union[SPMDAxisContext, ShardingContext]] = None
  ) -> xc.HloSharding:
    sharding_spec, special_axes = self._get_sharding_spec(
        num_dimensions, axis_ctx)
    return sharding_spec.sharding_proto(special_axes=special_axes)
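
# Illustrative sketch (editor's addition, assumes `jax` is imported and 8
# devices are available): a NamedSharding built from a 2x4 mesh maps each
# device to a distinct slice of the global array:
#
#   mesh = mesh_lib.Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
#   s = NamedSharding(mesh, PartitionSpec('x', 'y'))
#   s.shard_shape((8, 8))                 # -> (4, 2)
#   len(s.devices_indices_map((8, 8)))    # -> 8, one index per device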


@functools.lru_cache()
def get_replicated_hlo_sharding():
  return xc.HloSharding.replicate()


@use_cpp_class(xc.SingleDeviceSharding)
class SingleDeviceSharding(XLACompatibleSharding):
  """A subclass of ``XLACompatibleSharding`` that places its data on a single device.

  Args:
    device: A single :py:class:`Device`.

  Example:

    >>> single_device_sharding = jax.sharding.SingleDeviceSharding(
    ...     jax.devices()[0])
  """

  _device: Device
  _memory_kind: Optional[str]

  @use_cpp_method()
  def __init__(self, device: Device, *, memory_kind: Optional[str] = None):
    self._device = device
    self._memory_kind = memory_kind

  def __reduce__(self):
    if xla_extension_version >= 168:
      return type(self), (self._device,), {'memory_kind': self._memory_kind}
    else:
      return type(self), (self._device,)

  def __repr__(self):
    if xla_extension_version >= 168:
      mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
      return f"SingleDeviceSharding(device={repr(self._device)}{mem})"
    else:
      return f"SingleDeviceSharding(device={repr(self._device)})"

  def __hash__(self):
    if not hasattr(self, '_hash'):
      if xla_extension_version >= 168:
        self._hash = hash((self._device, self._memory_kind))
      else:
        self._hash = hash(self._device)
    return self._hash

  def __eq__(self, other):
    if not isinstance(other, SingleDeviceSharding):
      return False
    if id(self) == id(other):
      return True
    if xla_extension_version >= 168:
      return (self._device == other._device and
              self._memory_kind == other._memory_kind)
    else:
      return self._device == other._device

  @property
  def device_set(self) -> set[Device]:
    return {self._device}

  @property
  def memory_kind(self) -> Optional[str]:
    if xla_extension_version >= 168:
      return self._memory_kind
    else:
      return None

  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:  # type: ignore
    return {self._device: (slice(None),) * len(global_shape)}

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return (self._device,)

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    return get_replicated_hlo_sharding()

  @property
  def is_fully_replicated(self) -> bool:
    return True


@use_cpp_class(xc.PmapSharding)
class PmapSharding(XLACompatibleSharding):
  devices: np.ndarray
  sharding_spec: sharding_specs.ShardingSpec

  @use_cpp_method()
  def __init__(self, devices: Union[Sequence[Device], np.ndarray],
               sharding_spec: sharding_specs.ShardingSpec):
    self.devices = np.asarray(devices)
    # The sharding spec should be pmap's sharding spec.
    self.sharding_spec = sharding_spec

  def __reduce__(self):
    return type(self), (self.devices, self.sharding_spec)

  def __eq__(self, other):
    if not isinstance(other, PmapSharding):
      return False
    if id(self) == id(other):
      return True
    return (self.sharding_spec == other.sharding_spec and
            np.array_equal(self.devices, other.devices))

  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash((tuple(self.devices.flat), self.sharding_spec))
    return self._hash

  def __str__(self):
    device_ids = [d.id for d in self.devices.flat]
    return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
            f'{device_ids=}, '
            f'device_platform={self.devices.flat[0].platform.upper()}, '
            f'device_shape={self.devices.shape})')

  def __repr__(self):
    return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
            f'devices={self.devices})')

  def is_equivalent_to(self: PmapSharding, other: PmapSharding,  # type: ignore
                       ndim: int) -> bool:
    return self == other

  # TODO(yashkatariya): Expose `sharded_dim_size` in the API if required.
  @classmethod
  def default(cls, shape: Shape, sharded_dim: int = 0,
              devices: Optional[Sequence[xc.Device]] = None) -> PmapSharding:
    """Creates a `PmapSharding` which matches the implicit device order used by
    `pmap` if `devices` is None. If `devices` is specified, those devices are
    used instead.

    Args:
      shape: The shape of the input array.
      sharded_dim: Dimension the input array is sharded on. Defaults to 0.
      devices: Optional sequence of devices used to create the PmapSharding. If
        not specified, the implicit device order used by pmap is used, which is
        the order of jax.local_devices().
    """
    # The dtype doesn't matter here. It's only used for creating the
    # sharding_spec.
    sharding_spec = sharding_specs.create_pmap_sharding_spec(
        tuple(shape), sharded_dim)

    num_ways_sharded = None
    for s in sharding_spec.sharding:
      if isinstance(s, sharding_specs.Unstacked):
        num_ways_sharded = s.size
    if num_ways_sharded is None:
      raise NotImplementedError(
          '`None` to sharded_dim is not supported. Please file a jax '
          'issue if you need this feature.')

    if devices is None:
      pmap_devices: np.ndarray = np.array(
          xla_bridge.local_devices()[:num_ways_sharded])
    else:
      pmap_devices = np.array(devices)
    return cls(pmap_devices, sharding_spec)
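
  # Illustrative note (editor's addition, assumes 8 local devices): the default
  # construction shards axis 0 across jax.local_devices(), matching what
  # jax.pmap would do for an input of this shape:
  #
  #   s = PmapSharding.default((8, 2))
  #   s.shard_shape((8, 2))   # -> (2,); axis 0 is consumed by the device axis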

  @functools.cached_property
  def device_set(self) -> set[Device]:
    return set(self.devices.flat)

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    self.shard_shape(global_shape)  # raises a good error message
    indices = sharding_specs.spec_to_indices(global_shape, self.sharding_spec)
    return dict(safe_zip(self.devices.flat, indices))  # type: ignore[arg-type]

  @functools.cached_property
  def _device_assignment(self) -> XLADeviceAssignment:
    return tuple(self.devices.flat)

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    raise NotImplementedError("pmap doesn't use OpSharding.")

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    for s in self.sharding_spec.sharding:
      if isinstance(s, sharding_specs.Unstacked):
        return False
    return True

  @functools.lru_cache(maxsize=4096)
  def shard_shape(self, global_shape: Shape) -> Shape:
    sharded_dim = None
    sharded_dim_size = None
    for i, s in enumerate(self.sharding_spec.sharding):
      if isinstance(s, sharding_specs.Unstacked):
        sharded_dim = i
        sharded_dim_size = s.size
        break
    if sharded_dim is None:
      return global_shape
    if global_shape[sharded_dim] != sharded_dim_size:
      raise ValueError(
          f'The sharded dimension must be equal to the number of '
          f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
          f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
          f'the number of devices={len(self._device_assignment)}')
    return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]


def _op_sharding_to_pos_sharding(
    op_sharding: Union[xc.OpSharding, xc.HloSharding],
    device_assignment: Sequence[xc.Device]) -> PositionalSharding:
  if isinstance(op_sharding, xc.HloSharding):
    op_sharding = op_sharding.to_proto()  # type: ignore

  if op_sharding.type == xc.OpSharding.Type.REPLICATED:
    return PositionalSharding(device_assignment).replicate()

  if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
    replicate_on_last_tile_dim = True
  else:
    replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
    if op_sharding.last_tile_dims:
      raise NotImplementedError(
          "Unhandled OpSharding type. Please open a bug report!")

  name = device_assignment[0].platform.upper()
  ids = np.array([DeviceIdSet(name, i)
                  for i in op_sharding.tile_assignment_devices])
  p = PositionalSharding.remake(tuple(device_assignment), ids)
  p = p.reshape(op_sharding.tile_assignment_dimensions)
  if replicate_on_last_tile_dim:
    p = p.replicate(-1, keepdims=False)
  return p


class PositionalSharding(XLACompatibleSharding):
  _devices: tuple[xc.Device, ...]
  _memory_kind: Optional[str]
  _ids: np.ndarray  # dtype DeviceIdSet

  def __init__(self, devices: Union[Sequence[xc.Device], np.ndarray],
               *, memory_kind: Optional[str] = None):
    if not isinstance(devices, np.ndarray):
      devices = np.array(devices, dtype='object')
    if not devices.size:
      raise ValueError(f"{self.__class__.__name__}.__init__ requires at least "
                       f"one device, got {devices}")
    self._devices = tuple(devices.flat)
    self._memory_kind = memory_kind
    name = self._devices[0].platform.upper()
    self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
                         dtype='object').reshape(devices.shape)
    if self._memory_kind is not None:
      # Will error if memory_kind does not exist on the device.
      self._devices[0].memory(self._memory_kind)

  shape = property(op.attrgetter('_ids.shape'))
  ndim = property(op.attrgetter('_ids.ndim'))

  def __repr__(self) -> str:
    cls_name = self.__class__.__name__
    ids = self._ids.copy()
    platform_name = self._devices[0].platform.upper()
    for idx, x in np.ndenumerate(ids):
      ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))
    body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
                           max_line_width=100)
    mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
    return f'{cls_name}({body}{mem})'

  def reshape(self, *shape):
    return self.remake(self._devices, self._ids.reshape(*shape))

  def transpose(self, *axes):
    return self.remake(self._devices, self._ids.transpose(*axes))
  T = property(transpose)

  def replicate(self, axis=None, keepdims=True):
    new_ids = self._ids.sum(axis=axis, keepdims=keepdims)  # union
    return self.remake(self._devices, new_ids)

  @classmethod
  def remake(
      cls, devices: tuple[xc.Device, ...], ids: np.ndarray,
      *, memory_kind: Optional[str] = None) -> PositionalSharding:
    self = cls.__new__(cls)
    self._devices = devices
    self._ids = ids
    self._memory_kind = memory_kind
    return self

  # Hashable

  def __hash__(self) -> int:
    if not hasattr(self, '_hash'):
      self._hash = hash((self._devices, self._memory_kind))
    return self._hash

  def __eq__(self, other) -> bool:
    if not isinstance(other, PositionalSharding):
      return False
    if id(self) == id(other):
      return True
    all_ids_equal = np.array_equal(self._ids, other._ids)
    if (id(self._devices) == id(other._devices) and
        self._memory_kind == other._memory_kind and all_ids_equal):
      return True
    return (self._devices == other._devices and
            self._memory_kind == other._memory_kind and all_ids_equal)

  # Sharding interface

  @functools.cached_property
  def device_set(self) -> set[xc.Device]:
    return set(self._devices)

  @property
  def memory_kind(self) -> Optional[str]:
    if xla_extension_version >= 168:
      return self._memory_kind
    else:
      return None

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    return self.shape == (1,) * self.ndim

  # XLACompatibleSharding interface

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return self._devices

  @functools.lru_cache(maxsize=4096)
  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    if self.shape == (1,) * self.ndim:
      return get_replicated_hlo_sharding()

    pbuf = xc.OpSharding()
    shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val
    set_size, = {len(device_set) for device_set in self._ids.flat}
    pbuf.type = xc.OpSharding.Type.OTHER
    if set_size > 1:
      pbuf.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
      pbuf.tile_assignment_dimensions = (*shape, set_size)
    else:
      pbuf.tile_assignment_dimensions = shape
    pbuf.tile_assignment_devices = [i for ids in self._ids.flat for i in ids]
    product_of_dims = math.prod(pbuf.tile_assignment_dimensions)
    num_devices = len(pbuf.tile_assignment_devices)
    assert product_of_dims == num_devices, (product_of_dims, num_devices)
    return xc.HloSharding.from_proto(pbuf)
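
# Illustrative sketch (editor's addition, assumes `jax` is imported and 8
# devices are available): a PositionalSharding is reshaped and replicated
# purely positionally, with no named mesh axes:
#
#   s = PositionalSharding(jax.devices()).reshape(4, 2)
#   s.shape                     # -> (4, 2)
#   s.replicate(axis=1).shape   # -> (4, 1); each pair of devices along the
#                               #    second axis now holds replicated data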


class DeviceIdSet:
  _name: str
  _ids: frozenset[int]
  def __init__(self, name, *ids):
    self._name = name
    self._ids = frozenset(ids)

  def __iter__(self):
    return iter(sorted(self._ids))

  def __add__(self, other) -> DeviceIdSet:
    assert isinstance(other, DeviceIdSet)
    return DeviceIdSet(self._name, *(self._ids | other._ids))

  def __len__(self) -> int:
    return len(self._ids)

  def __repr__(self) -> str:
    ids = ', '.join(safe_map(str, sorted(self._ids)))
    return f'{{{self._name} {ids}}}'

  def __hash__(self) -> int:
    return hash((self._name, self._ids))

  def __eq__(self, other) -> bool:
    return (isinstance(other, DeviceIdSet) and self._name == other._name and
            self._ids == other._ids)


@use_cpp_class(xc.GSPMDSharding)
class GSPMDSharding(XLACompatibleSharding):
  _devices: tuple[Device, ...]
  _hlo_sharding: xc.HloSharding
  _memory_kind: Optional[str]

  @use_cpp_method()
  def __init__(self, devices: Sequence[Device],
               op_sharding: Union[xc.OpSharding, xc.HloSharding],
               *, memory_kind: Optional[str] = None):
    self._devices = tuple(devices)
    if isinstance(op_sharding, xc.OpSharding):
      self._hlo_sharding = xc.HloSharding.from_proto(op_sharding)
    else:
      self._hlo_sharding = op_sharding
    self._memory_kind = memory_kind

  if xla_extension_version < 159:
    @property
    def _hlo_sharding(self):  # type: ignore
      if isinstance(self._op_sharding, xc.OpSharding):  # type: ignore
        return xc.HloSharding.from_proto(self._op_sharding)  # type: ignore
      return self._op_sharding  # type: ignore

  def __reduce__(self):
    if xla_extension_version >= 168:
      return (
          type(self),
          (self._devices, self._hlo_sharding.to_proto()),
          {'memory_kind': self._memory_kind},
      )
    else:
      return type(self), (self._devices, self._hlo_sharding.to_proto())

  @functools.cached_property
  def _hlo_sharding_hash(self):
    return hash(self._hlo_sharding)

  def __eq__(self, other):
    if not isinstance(other, GSPMDSharding):
      return False
    if id(self) == id(other):
      return True
    if xla_extension_version >= 168:
      return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
              and self._devices == other._devices and
              self._memory_kind == other._memory_kind)
    else:
      return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
              and self._devices == other._devices)

  def __hash__(self):
    if not hasattr(self, '_hash'):
      if xla_extension_version >= 168:
        self._hash = hash((self._devices, self._hlo_sharding_hash,
                           self._memory_kind))
      else:
        self._hash = hash((self._devices, self._hlo_sharding_hash))
    return self._hash

  def __repr__(self):
    if xla_extension_version >= 168:
      mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
      return f'GSPMDSharding({repr(self._hlo_sharding)}{mem})'
    else:
      return f'GSPMDSharding({repr(self._hlo_sharding)})'

  def is_compatible_aval(self, aval_shape: Shape):
    num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding)
    if len(aval_shape) < len(num_ways_dim_sharded):
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
          f"{len(aval_shape)}")

  @functools.cached_property
  def device_set(self) -> set[Device]:
    return set(self._devices)

  @property
  def memory_kind(self) -> Optional[str]:
    if xla_extension_version >= 168:
      return self._memory_kind
    else:
      return None

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    self.shard_shape(global_shape)  # raises a good error message
    indices = op_sharding_to_indices(self._hlo_sharding, global_shape,
                                     len(self._devices))
    return dict(safe_zip(self._devices, indices))

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return self._devices

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    return self._hlo_sharding

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    return is_op_sharding_replicated(self._hlo_sharding)

  @classmethod
  def get_replicated(cls, device_assignment, *, memory_kind: Optional[str] = None):
    if xla_extension_version >= 168:
      return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
                 memory_kind=memory_kind)
    else:
      return cls(tuple(device_assignment), get_replicated_hlo_sharding())
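
# Illustrative sketch (editor's addition, assumes `jax` is imported):
# GSPMDSharding is the lowest-level XLA-facing sharding; a fully replicated
# one can be built directly from a device list:
#
#   s = GSPMDSharding.get_replicated(jax.devices())
#   s.is_fully_replicated    # -> True
#   s.shard_shape((4, 4))    # -> (4, 4); every device holds the full array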


class AUTO:

  def __init__(self, mesh: mesh_lib.Mesh):
    self.mesh = mesh


def is_auto(x):
  return isinstance(x, AUTO)


class UnspecifiedValue:
  def __repr__(self):
    return "UnspecifiedValue"
UNSPECIFIED = UnspecifiedValue()

def is_unspecified(x):
  return isinstance(x, UnspecifiedValue)

def is_unspecified_or_auto(x):
  return is_auto(x) or is_unspecified(x)


MeshAxisName = Any

"""
ArrayMapping specifies how an ndarray should map to mesh axes.

Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.

For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
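
# Illustrative note (editor's addition): the mapping {'x': 1, 'y': 1} from the
# example above corresponds to PartitionSpec(None, ('x', 'y')): value dimension
# 1 is split over 'x' (major) and then 'y' (minor). Using
# array_mapping_to_axis_resources (defined below):
#
#   array_mapping_to_axis_resources(OrderedDict([('x', 1), ('y', 1)]))
#   # -> PartitionSpec(None, ('x', 'y'))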

def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
  if not array_mapping:
    return PartitionSpec()
  max_index = -1
  reverse_map = collections.defaultdict(list)
  for axis, index in array_mapping.items():
    reverse_map[index].append(axis)
    if index > max_index:
      max_index = index
  partitions = []
  for i in range(max_index + 1):
    axis = reverse_map[i]
    if axis:
      partitions.append(axis[0] if len(axis) == 1 else tuple(axis))
    else:
      partitions.append(None)
  return PartitionSpec(*partitions)

def get_array_mapping(
    axis_resources: Union[ParsedPartitionSpec, AUTO, UnspecifiedValue]
) -> ArrayMappingOrAutoOrUnspecified:
  # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
  # Don't use `is_auto` here to satisfy pytype and mypy.
  if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
    return axis_resources
  return OrderedDict((axis, i)
                     for i, axes in enumerate(axis_resources)
                     if axes is not None for axis in axes)


get_single_pspec = lambda p: array_mapping_to_axis_resources(
    cast(ArrayMapping, get_array_mapping(p)))
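
# Illustrative note (editor's addition): get_single_pspec composes the two
# functions above to turn a ParsedPartitionSpec (defined later in this file)
# back into a user-facing PartitionSpec. For the flattened mapping alone:
#
#   array_mapping_to_axis_resources(OrderedDict([('x', 0), ('y', 2), ('z', 2)]))
#   # -> PartitionSpec('x', None, ('y', 'z'))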


class SpecSync(enum.IntEnum):
  """Encodes how much out of sync the real value of partitions is compared to the user-specified one.

  We use this to make sure we don't show garbage modified values while claiming
  that the users have specified them like that.
  """
  OUT_OF_SYNC = 0  # Arbitrary changes, including new axes inserted
  DIM_PERMUTE = 1  # Dimensions permuted, but no new sharding axes
  IN_SYNC = 2  # Entirely in sync

class ParsedPartitionSpec:
  __slots__ = ('unsafe_user_spec', 'partitions', 'sync')

  def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
    self.unsafe_user_spec = user_spec
    # None in partitions represents an unconstrained dim.
    # TODO(yashkatariya): May use a sentinel value.
    self.partitions = tuple(partitions)
    self.sync = sync

  @property
  def user_spec(self):
    return self.unsynced_user_spec(SpecSync.IN_SYNC)

  def get_partition_spec(self) -> PartitionSpec:
    if self.sync < SpecSync.IN_SYNC:
      return get_single_pspec(self)
    else:
      if isinstance(self.unsafe_user_spec, PartitionSpec):
        return self.unsafe_user_spec
      else:
        return get_single_pspec(self)

  def unsynced_user_spec(self, min_sync):
    if self.sync < min_sync:
      raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
    return self.unsafe_user_spec

  def insert_axis_partitions(self, dim, val):
    parts = self.partitions
    too_short = dim - len(parts)
    if too_short > 0:
      parts += ((),) * too_short
    new_partitions = util.tuple_insert(parts, dim, val)
    new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
    return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)

  @classmethod
  def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
    if entry is None:
      return cls(entry, ())
    if not isinstance(entry, PartitionSpec):
      raise TypeError(f"{arg_name} are expected to be "
                      f"PartitionSpec instances or None, but got {entry}")
    axis_specs = []
    for axis_spec in entry:
      if axis_spec is None:
        axis_spec = ()
      elif isinstance(axis_spec, (list, tuple)):
        axis_spec = tuple(axis_spec)
      elif axis_spec == PartitionSpec.UNCONSTRAINED:
        if not allow_unconstrained_dims:
          raise ValueError(f"Unconstrained dims are not allowed: {entry}")
        axis_spec = None
      else:
        axis_spec = (axis_spec,)
      axis_specs.append(axis_spec)
    return cls(entry, axis_specs)

  def __hash__(self):
    return hash((self.partitions, self.sync))

  def __eq__(self, other):
    return (self.partitions == other.partitions and
            self.sync == other.sync)

  def __len__(self):
    return len(self.partitions)

  def __getitem__(self, i):
    return self.partitions[i]

  def __iter__(self):
    return iter(self.partitions)

  def __repr__(self):
    return (f"ParsedPartitionSpec(partitions={self.partitions}, "
            f"unsafe_user_spec={self.unsafe_user_spec}, "
            f"sync={self.sync})")

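# Illustrative sketch (editor's addition): from_user_input normalizes every
# PartitionSpec entry to a tuple of mesh axes (None becomes the empty tuple,
# a bare axis name becomes a singleton tuple), which is what get_array_mapping
# above consumes:
#
#   parsed = ParsedPartitionSpec.from_user_input(
#       PartitionSpec('x', None, ('y', 'z')), "in_axis_resources")
#   parsed.partitions           # -> (('x',), (), ('y', 'z'))
#   get_array_mapping(parsed)   # -> OrderedDict([('x', 0), ('y', 2), ('z', 2)])
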
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
  """ParsedPartitionSpecs that are canonicalized.

  ParsedPartitionSpecs may contain trailing empty tuples that make them
  semantically different in general, yet in some situations we prefer
  to regard them as equivalent. For example, partitions of () and ((),)
  cannot always be considered equivalent, since the first one is a valid
  spec for a scalar value, while the second is not! However, when either of
  those is applied to a 2D array, they both mean that the array is fully
  replicated.

  So CanonicalizedParsedPartitionSpec removes the trailing empty tuples from
  partitions.
  """

  def __init__(self, parsed_pspec: ParsedPartitionSpec):
    partitions = list(parsed_pspec.partitions)
    while partitions and partitions[-1] == ():
      partitions.pop()

    super().__init__(parsed_pspec.unsafe_user_spec, partitions,
                     parsed_pspec.sync)

  def __repr__(self):
    return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
            f"unsafe_user_spec={self.unsafe_user_spec}, "
            f"sync={self.sync})")


def check_all_or_none_unspecified(axis_resources, name):
  if not axis_resources:
    return False
  unspecified_count = 0
  unspecified = is_unspecified(axis_resources[0])
  for resource in axis_resources:
    current_is_unspecified = is_unspecified(resource)
    if current_is_unspecified:
      unspecified_count += 1
      assert unspecified_count == 1
    if current_is_unspecified != unspecified:
      raise ValueError(f'`pjit.UNSPECIFIED` exists in {name}. '
                       f'Make sure that every entry in {name} is '
                       '`pjit.UNSPECIFIED`.')
  return unspecified


def prepare_axis_resources(axis_resources,
                           arg_name,
                           allow_unconstrained_dims=False):
  # PyTrees don't treat None values as leaves, so we use an is_leaf function.
  entries, treedef = tree_util.tree_flatten(
      axis_resources, is_leaf=lambda x: x is None)
  what = f"{arg_name} leaf specifications"
  # All entries should be specified; if an entry is unspecified then there
  # should only be one such entry, since UNSPECIFIED is a private API.
  check_all_or_none_unspecified(entries, arg_name)

  new_entries = []
  for entry in entries:
    if is_unspecified_or_auto(entry) or entry is None:
      new_entries.append(entry)
    elif isinstance(entry, sharding.Sharding):
      if isinstance(entry, PmapSharding):
        raise ValueError(f'One of {what} got sharding {entry} which is not '
                         'allowed.')
      if not isinstance(entry, XLACompatibleSharding):
        raise ValueError(f'One of {what} got sharding {entry} which is not a '
                         'subclass of XLACompatibleSharding.')
      new_entries.append(entry)
    else:
      new_entries.append(ParsedPartitionSpec.from_user_input(
          entry, what, allow_unconstrained_dims=allow_unconstrained_dims))

  _check_unique_resources(new_entries, arg_name)
  return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef


def _check_unique_resources(axis_resources, arg_name):
  for arg_axis_resources in axis_resources:
    if not arg_axis_resources: continue
    if (is_unspecified_or_auto(arg_axis_resources) or
        isinstance(arg_axis_resources, XLACompatibleSharding)):
      continue
    constrained_dims = [d for d in arg_axis_resources if d is not None]
    resource_counts = collections.Counter(
        itertools.chain.from_iterable(constrained_dims))
    if not resource_counts: continue
    if resource_counts.most_common(1)[0][1] > 1:
      multiple_uses = [r for r, c in resource_counts.items() if c > 1]
      if multiple_uses:
        raise ValueError(f"A single {arg_name} specification can map every mesh axis "
                         f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
                         f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")

# Axis environments

class AxisEnv(NamedTuple):
  """Represents a pmap mesh (only along the replica axes)."""
  nreps: int
  names: tuple[Any, ...]
  sizes: tuple[int, ...]


@dataclasses.dataclass(frozen=True)
class SPMDAxisContext:
  """A hardware axis context for parallel computations that use the GSPMD partitioner.

  This includes the mesh that will later be used to execute this computation,
  as well as a set of mesh axes that are currently (e.g. because the current lowering
  is invoked inside an xmap) lowered in the MANUAL sharding mode.
  """
  mesh: mesh_lib.Mesh
  manual_axes: frozenset[MeshAxisName] = frozenset()

  @property
  def axis_env(self):
    # All collectives that touch axis_env should remember to set use_global_device_ids
    # when this context is enabled!
    if self.manual_axes != frozenset(self.mesh.axis_names):
      raise NotImplementedError(
          "Collectives in manually partitioned computations are only supported "
          "when all mesh axes are partitioned manually (no partial automatic sharding). "
          "Make sure that you mention all mesh axes in axis_resources!")
    return self.unsafe_axis_env

  @property
  def unsafe_axis_env(self):
    return AxisEnv(
        nreps=self.mesh.size,
        names=self.mesh.axis_names,
        sizes=tuple(self.mesh.shape.values()))

  def extend_manual(self, axes: frozenset[MeshAxisName]) -> SPMDAxisContext:
    return SPMDAxisContext(self.mesh, self.manual_axes | axes)


@dataclasses.dataclass(frozen=True)
class ReplicaAxisContext:
  """A hardware axis context for parallel computations that are partitioned by JAX.

  Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to
  explicit collectives.
  """
  axis_env: AxisEnv


@dataclasses.dataclass(frozen=True)
class ShardingContext:
  """A hardware axis context for parallel computations that use the sharding
  interface.

  This context also uses the GSPMD partitioner.
  """
  device_assignment: Sequence[xc.Device]

  # Similar to SPMDAxisContext, as ShardingContext also uses the GSPMD partitioner.
  @property
  def axis_env(self):
    return AxisEnv(nreps=1, names=(), sizes=())


# -------------------- XLA OpSharding to PartitionSpec --------------------
# Note that OpSharding is more expressive than PartitionSpecs, so it's not
# always possible to convert them, but the code below should at least
# handle all cases when this is possible.

def strides_for_sizes(sizes):
  """Returns an array of strides for major-to-minor sizes."""
  return np.cumprod(sizes[::-1])[::-1] // np.asarray(sizes)
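
# Illustrative note (editor's addition): for major-to-minor sizes (2, 3, 4) the
# corresponding row-major strides are (12, 4, 1):
#
#   strides_for_sizes(np.array([2, 3, 4]))   # -> array([12,  4,  1])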

def unflatten_array(named_sizes, assignment):
  """Recovers the ordering of axis names based on a device assignment.

  The device assignments that this function can convert into axis orders
  are of the form::

    np.arange(np.prod(named_sizes.values())).transpose(...).flatten()

  for some transposition ``...``. This is satisfied by all OpSharding assignments
  generated from partition specs.

  Arguments:
    named_sizes: A dictionary mapping axis names to their sizes.
    assignment: A permutation of integers between 0 and the product of all
      named sizes.

  Returns:
    A major-to-minor list of axis names that corresponds to the given assignment.
  """
  named_sizes = {name: size for name, size in named_sizes.items() if size != 1}
  sizes = np.fromiter(named_sizes.values(), dtype=np.int64)
  strides = strides_for_sizes(sizes)
  dims = explode_superdims(sizes, unflatten_superdims(assignment))
  dim_to_name = {(size, stride): name for size, stride, name in zip(sizes, strides, named_sizes)}
  return [dim_to_name[d] for d in dims]

def unflatten_superdims(assignment):
  """Unflatten a list of dimension sizes and their strides that generates assignment.

  If this function succeeds for a given ``assignment``, then the following property
  should be satisfied::

    dims_with_strides = unflatten_superdims(assignment)
    base_array = np.arange(map(fst, sorted(dims_with_strides, key=snd, reverse=True)))
    assignment == base_array.transpose(argsort(dims_with_strides, key=snd, reverse=True)).flatten()

  That is, the returned dimensions list all sizes of the base array (with strides
  indicating their initial order). The order of dimensions in the list corresponds
  to the permutation that, applied to the base array, generates the assignment.
  """
  def check(cond):
    if cond: return
    raise NotImplementedError("Failed to convert OpSharding into a ShardingSpec. "
                              "Please open a bug report!")
  flat_assignment = np.asarray(assignment, dtype=np.int64)
  check(flat_assignment[0] == 0)
  dims = []
  while flat_assignment.size > 1:
    stride = flat_assignment[1]
    for i in range(len(flat_assignment)):
      if flat_assignment[i] != i * stride: break
    else:
      # After this loop i should point to an "element after the sequence", so
      # we have to increment it if the whole array is a strided sequence.
      i += 1
    size = i
    dims.append((size, stride))
    assert size > 1  # Ensure progress
    flat_assignment = flat_assignment[::size]
  return dims
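
# Illustrative note (editor's addition): the assignment obtained by transposing
# a 2x3 base array, [0, 3, 1, 4, 2, 5], unflattens into a size-2 dim of stride
# 3 followed by a size-3 dim of stride 1:
#
#   unflatten_superdims([0, 3, 1, 4, 2, 5])   # -> [(2, 3), (3, 1)]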

def explode_superdims(sizes, dims):
  """Explode superdims to fit a known shape.

  The unflattening process might mistakenly generate too few, too-large dimensions.
  For example, ``unflatten_superdims(np.arange(n))`` always returns ``[(n, 1)]``.
  This function takes a list of such contiguous super-dimensions and splits them
  into smaller dimensions such that::

    set(map(fst, explode_superdims(sizes, dims))) == set(sizes)
  """
  strides_to_sizes = {stride: size for size, stride in zip(sizes, strides_for_sizes(sizes))}
  dims = list(reversed(dims))
  final_dims = []
  for size, stride in dims:
    target_size = strides_to_sizes[stride]
    new_dims = []
    while size > target_size:
      assert target_size > 1  # Ensure progress
      assert size % target_size == 0
      new_dims.append((target_size, stride))
      size //= target_size
      stride *= target_size
      target_size = strides_to_sizes[stride]
    assert size == target_size
    new_dims.append((size, stride))
    final_dims += reversed(new_dims)
  return final_dims
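
# Illustrative note (editor's addition): an identity assignment over sizes
# (2, 3) unflattens to the single superdim [(6, 1)]; explode_superdims splits
# it back into the known per-axis sizes (with their strides):
#
#   explode_superdims(np.array([2, 3]), [(6, 1)])   # -> [(2, 3), (3, 1)]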

def parse_flatten_op_sharding(op_sharding: Union[xc.OpSharding, xc.HloSharding],
                              mesh: mesh_lib.Mesh) -> Sequence[ParsedPartitionSpec]:
  if isinstance(op_sharding, xc.HloSharding):
    op_sharding = op_sharding.to_proto()  # type: ignore
  if op_sharding.type == xc.OpSharding.Type.TUPLE:
    out: list[ParsedPartitionSpec] = []
    for s in op_sharding.tuple_shardings:
      out.extend(parse_flatten_op_sharding(s, mesh))
    return out
  elif op_sharding.type == xc.OpSharding.Type.REPLICATED:
    return [CanonicalizedParsedPartitionSpec(
        ParsedPartitionSpec(PartitionSpec(), ()))]
  elif op_sharding.type == xc.OpSharding.Type.OTHER:
    mesh_shape = mesh.shape
    mesh_axis_order = unflatten_array(mesh.shape, op_sharding.tile_assignment_devices)
    mesh_axis = iter(mesh_axis_order)
    shape = op_sharding.tile_assignment_dimensions
    partitions = []
    for dim_size in shape:
      dim_partitions = []
      while dim_size > 1:
        axis = next(mesh_axis)
        axis_size = mesh_shape[axis]
        assert dim_size % axis_size == 0
        dim_size //= axis_size
        dim_partitions.append(axis)
      partitions.append(tuple(dim_partitions))
    if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
      replicate_on_last_tile_dim = True
    else:
      replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
      if op_sharding.last_tile_dims:
        raise NotImplementedError("Unhandled OpSharding type. Please open a bug report!")
    if replicate_on_last_tile_dim:
      partitions = partitions[:-1]
    return [CanonicalizedParsedPartitionSpec(
        ParsedPartitionSpec('<internally generated spec>', partitions))]
  else:
    raise AssertionError("Unhandled OpSharding type. Please open a bug report!")