# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import abc
import functools
from collections import Counter
import operator as op
from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
                    FrozenSet, Union, cast)

from jax._src.util import safe_map, safe_zip
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.interpreters import pxla, mlir

import numpy as np


Shape = Tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]


@pxla.use_cpp_class(xc.Sharding if xc._version >= 94 else None)
class Sharding(metaclass=abc.ABCMeta):

  # Abstract methods below that subclasses should implement.

  @abc.abstractproperty
  def device_set(self) -> Set[Device]:
    """The set of devices that this sharding spans.

    In multi-controller JAX, this set may include devices that are not
    addressable by the current process.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def shard_shape(self, global_shape: Shape) -> Shape:
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @pxla.maybe_cached_property
  def addressable_devices(self) -> Set[Device]:
    """The subset of `device_set` that is addressable by the current process."""
    process_index = xb.process_index()
    return {d for d in self.device_set if d.process_index == process_index}

  @pxla.maybe_cached_property
  def is_fully_addressable(self) -> bool:
    # The pytype disable is because pytype can't recognize a cached property.
    return len(self.device_set) == len(self.addressable_devices)  # type: ignore

  def device_indices(self, device: Device,
                     global_shape: Shape) -> Optional[Index]:
    return self.devices_indices_map(global_shape)[device]

  @functools.lru_cache(maxsize=4096)
  def addressable_devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    process_index = xb.process_index()
    return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
            if d.process_index == process_index}
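

# Illustrative sketch (comment only, not executed at import time): for a
# concrete `Sharding` instance, the default helpers above compose roughly like
# this. The single-process setup and the use of `SingleDeviceSharding`
# (defined later in this file) are assumptions for the example, not
# requirements of the interface:
#
#   import jax
#   s = SingleDeviceSharding(jax.devices()[0])
#   s.devices_indices_map((4, 2))     # {device: (slice(None), slice(None))}
#   s.addressable_devices             # devices owned by this process
#   s.is_fully_addressable            # True in single-process setups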


@pxla.use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None)
class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):

  # Abstract methods below that subclasses should implement.

  @abc.abstractproperty
  def _device_assignment(self) -> XLADeviceAssignment:
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    op_sharding = self._to_xla_op_sharding(len(global_shape))
    op_sharding_sharding = OpShardingSharding(self._device_assignment,
                                              op_sharding)
    return op_sharding_sharding.devices_indices_map(global_shape)

  @pxla.maybe_cached_property
  def _addressable_device_assignment(self) -> XLADeviceAssignment:
    process_index = xb.process_index()
    return [d for d in self._device_assignment if d.process_index == process_index]

  @functools.lru_cache(maxsize=4096)
  def shard_shape(self, global_shape: Shape) -> Shape:
    op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
    if pxla.is_op_sharding_replicated(op_sharding):
      return global_shape
    partitions, _ = pxla._get_num_ways_dim_sharded(op_sharding)
    assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
    out = []
    for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
      quotient, remainder = divmod(s, p)
      if remainder != 0:
        raise ValueError(
            f"Sharding {self} implies that array axis {dim} is partitioned "
            f"{p} times, but the dimension size is {s} "
            f"(full shape: {global_shape}, "
            f"per-dimension tiling factors: {partitions} should evenly divide "
            "the shape)")
      out.append(quotient)
    return tuple(out)
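

# Worked example for `shard_shape` (comment only, not executed): if a sharding
# tiles a global shape (8, 2) four ways along axis 0 and leaves axis 1
# unpartitioned, the per-dimension tiling factors are (4, 1) and each shard
# has shape (8 // 4, 2 // 1) == (2, 2). A dimension that its tiling factor
# does not evenly divide (e.g. size 9 split 4 ways) raises the ValueError
# above. The (8, 2) shape and tiling factors here are illustrative assumptions.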


@functools.lru_cache()
def _check_mesh_resource_axis(mesh, parsed_pspec):
  try:
    [mesh.shape[r] for p in parsed_pspec if p is not None
     for r in p]
  except KeyError as e:
    raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is "
                     "undefined.") from None


def _hashed_index(x) -> int:
  # This works for both `pjit`/`xmap` indices and `pmap` indices (which might
  # have an integer instead of a slice).
  assert all(v.step is None for v in x if isinstance(v, slice))
  return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))


@functools.lru_cache(maxsize=4096)
def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
  try:
    device_indices_map_fn = sharding.devices_indices_map
  except AttributeError:
    raise ValueError(
        f'Cannot calculate replica ids from sharding: {sharding}. Please '
        'create a device to index mapping for your sharding from which replica '
        'ids will be calculated.') from None

  index_to_replica: Dict[int, int] = Counter()
  out = {}
  for device, index in device_indices_map_fn(global_shape).items():
    h_index = _hashed_index(index)
    replica_id = index_to_replica[h_index]
    index_to_replica[h_index] += 1
    out[device] = replica_id
  return out
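

# Illustrative sketch (comment only, not executed): devices that receive the
# same slice of the global array are replicas of one another, and
# `device_replica_id_map` numbers them in iteration order. For a fully
# replicated sharding over four devices every device holds the whole array, so
# the map would be {d0: 0, d1: 1, d2: 2, d3: 3}; for a sharding that gives
# every device a distinct slice, every replica id would be 0. The four-device
# setup is an assumption for the example.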


@pxla.use_cpp_class(xc.MeshPspecSharding if xc._version >= 95 else None)
class MeshPspecSharding(XLACompatibleSharding):

  @pxla.use_cpp_method
  def __init__(
      self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):

    self.mesh = mesh
    self.spec = spec
    self._parsed_pspec = _parsed_pspec
    self._preprocess()

  def _preprocess(self):
    # This split exists because a `_parsed_pspec` that has been modified from
    # the original can be passed in, e.g. when vmap handlers add an extra
    # dimension to axis_resources. In such cases the `sync` attribute of the
    # parsed pspec must be preserved, and the PartitionSpec is inferred from
    # the parsed pspec instead.
    # TODO(yaskatariya): Remove this and replace it with a normalized
    # representation of Parsed Pspec.
    if self._parsed_pspec is None:
      from jax.experimental import pjit
      self._parsed_pspec, _, _, _ = pjit._prepare_axis_resources(
          self.spec, "MeshPspecSharding spec")

    _check_mesh_resource_axis(self.mesh, self._parsed_pspec)

  def __repr__(self):
    return f'MeshPspecSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'

  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash((self.mesh, self._parsed_pspec))
    return self._hash

  def __eq__(self, other):
    if not isinstance(other, MeshPspecSharding):
      return False
    if id(self) == id(other):
      return True
    if id(self.mesh) == id(other.mesh) and self._parsed_pspec == other._parsed_pspec:
      return True
    return self.mesh == other.mesh and self._parsed_pspec == other._parsed_pspec

  def is_compatible_aval(self, aval_shape: Shape):
    if len(aval_shape) < len(self._parsed_pspec):
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(self._parsed_pspec)}, but was applied to a value of rank "
          f"{len(aval_shape)}")

  @classmethod
  def _from_parsed_pspec(cls, mesh, parsed_pspec):
    return cls(mesh, parsed_pspec.get_partition_spec(), parsed_pspec)

  @pxla.maybe_cached_property
  def device_set(self) -> Set[Device]:
    return set(self.mesh.devices.flat)

  @pxla.maybe_cached_property
  def _device_assignment(self) -> XLADeviceAssignment:
    return list(self.mesh.devices.flat)

  @functools.lru_cache(maxsize=4096)
  def _to_xla_op_sharding(
      self,
      num_dimensions: int,
      axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
  ) -> xc.OpSharding:
    from jax.experimental.pjit import get_array_mapping

    array_mapping = get_array_mapping(self._parsed_pspec)
    # TODO(yashkatariya): Move away from sharding spec in MeshPspecSharding
    # since we don't really need sharding spec.
    sharding_spec = pxla.new_mesh_sharding_specs(
        self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
    # Used in `with_sharding_constraint`.
    special_axes = {}
    # Manual axes are only used with xmap.
    if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext):
      axis_names = self.mesh.axis_names
      # Ignore type because mypy doesn't recognize the `isinstance` check above.
      for manual_axis in axis_ctx.manual_axes:  # type: ignore
        special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
    return sharding_spec.sharding_proto(special_axes=special_axes)
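

# Illustrative sketch (comment only, not executed): a MeshPspecSharding pairs
# a device mesh with a PartitionSpec that names which mesh axes shard which
# array dimensions. The eight-device runtime and the 4x2 mesh shape below are
# assumptions for the example:
#
#   import jax
#   from jax.experimental import maps, PartitionSpec
#   devices = np.asarray(jax.devices()).reshape(4, 2)
#   mesh = maps.Mesh(devices, ('x', 'y'))
#   s = MeshPspecSharding(mesh, PartitionSpec('x', 'y'))
#   s.shard_shape((8, 4))   # -> (2, 2): axis 0 split 4 ways, axis 1 split 2 ways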


@functools.lru_cache()
def _get_replicated_op_sharding():
  proto = xc.OpSharding()
  proto.type = xc.OpSharding.Type.REPLICATED
  return proto


@pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
class SingleDeviceSharding(XLACompatibleSharding):

  @pxla.use_cpp_method
  def __init__(self, device: Device):
    self._device = device

  def __repr__(self):
    return f"SingleDeviceSharding(device={repr(self._device)})"

  def __hash__(self):
    return hash(self._device)

  def __eq__(self, other):
    if not isinstance(other, SingleDeviceSharding):
      return False
    if id(self) == id(other):
      return True
    return self._device == other._device

  @property
  def device_set(self) -> Set[Device]:
    return {self._device}

  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:  # type: ignore
    return {self._device: (slice(None),) * len(global_shape)}

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return [self._device]

  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    return _get_replicated_op_sharding()


@pxla.use_cpp_class(xc.PmapSharding if xc._version >= 94 else None)
class PmapSharding(XLACompatibleSharding):

  @pxla.use_cpp_method
  def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
    self.devices = devices
    # The sharding spec should be pmap's sharding spec.
    self.sharding_spec = sharding_spec

  def __eq__(self, other):
    if not isinstance(other, PmapSharding):
      return False
    if id(self) == id(other):
      return True
    return (self.sharding_spec == other.sharding_spec and
            np.array_equal(self.devices, other.devices))

  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash((tuple(self.devices.flat), self.sharding_spec))
    return self._hash

  def __str__(self):
    device_ids = [d.id for d in self.devices.flat]
    return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
            f'device_ids={device_ids}, '
            f'device_platform={self.devices.flat[0].platform.upper()}, '
            f'device_shape={self.devices.shape})')

  def __repr__(self):
    return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
            f'devices={self.devices})')

  @pxla.maybe_cached_property
  def device_set(self) -> Set[Device]:
    return set(self.devices.flat)

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
    return dict(safe_zip(self.devices.flat, indices))  # type: ignore[arg-type]

  @pxla.maybe_cached_property
  def _device_assignment(self) -> XLADeviceAssignment:
    return list(self.devices.flat)

  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    raise NotImplementedError("pmap doesn't use OpSharding.")

  @functools.lru_cache(maxsize=4096)
  def shard_shape(self, global_shape: Shape) -> Shape:
    sharded_dim = None
    for i, s in enumerate(self.sharding_spec.sharding):
      if isinstance(s, pxla.Unstacked):
        sharded_dim = i
        break
    if sharded_dim is None:
      return global_shape
    return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
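

# Illustrative sketch (comment only, not executed): `pmap` unstacks its mapped
# axis, so `shard_shape` above drops that dimension entirely rather than
# dividing it. For a pmap over the leading axis of a (4, 3) global array on
# four devices, each per-device shard has shape (3,). The array shape and
# device count are assumptions for the example.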


class DevicesSharding(XLACompatibleSharding):
  _devices: List[xc.Device]
  _ids: np.ndarray  # dtype DeviceIdSet

  def __init__(self, devices: Union[Sequence[xc.Device], np.ndarray]):
    if not isinstance(devices, np.ndarray):
      devices = np.array(devices, dtype='object')
    if not devices.size:
      raise ValueError(f"{self.__class__.__name__}.__init__ requires at least "
                       f"one device, got {devices}")
    self._devices = list(devices.flat)
    name = self._devices[0].platform.upper()
    self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
                         dtype='object')

  shape = property(op.attrgetter('_ids.shape'))
  ndim = property(op.attrgetter('_ids.ndim'))

  def __repr__(self) -> str:
    cls_name = self.__class__.__name__
    body = np.array2string(self._ids, prefix=cls_name + '(', suffix=')',
                           max_line_width=100)
    return f'{cls_name}({body})'

  def reshape(self, *shape):
    return self.remake(self._devices, self._ids.reshape(*shape))

  def transpose(self, *axes):
    return self.remake(self._devices, self._ids.transpose(*axes))
  T = property(transpose)

  def replicate(self, axis=None, keepdims=True):
    new_ids = self._ids.sum(axis=axis, keepdims=keepdims)  # union
    return self.remake(self._devices, new_ids)

  @classmethod
  def remake(cls, devices: List[xc.Device], ids: np.ndarray) -> DevicesSharding:
    self = cls.__new__(cls)
    self._devices = devices
    self._ids = ids
    return self

  # Hashable

  def __hash__(self) -> int:
    return id(self._devices)

  def __eq__(self, other) -> bool:
    return (isinstance(other, DevicesSharding) and
            id(self._devices) == id(other._devices) and
            bool(np.all(self._ids == other._ids)))

  # Sharding interface

  @pxla.maybe_cached_property
  def device_set(self) -> set[xc.Device]:
    return set(self._devices)

  # XLACompatibleSharding interface

  @functools.lru_cache(maxsize=4096)
  def _to_xla_op_sharding(self, num_dimensions: int, axis_ctx=None):
    assert axis_ctx is None

    pbuf = xc.OpSharding()
    if self.shape == (1,) * self.ndim:
      pbuf.type = xc.OpSharding.Type.REPLICATED
      return pbuf

    shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val
    set_size, = {len(device_set) for device_set in self._ids.flat}
    pbuf.type = xc.OpSharding.Type.OTHER
    if set_size > 1:
      pbuf.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
      pbuf.tile_assignment_dimensions = (*shape, set_size)
    else:
      pbuf.tile_assignment_dimensions = shape
    pbuf.tile_assignment_devices = [i for ids in self._ids.flat for i in ids]
    return pbuf

  @property
  def _device_assignment(self) -> list[xc.Device]:
    return self._devices
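

# Illustrative sketch (comment only, not executed): a DevicesSharding is
# manipulated like an array of device ids; reshape/transpose/replicate
# describe how a value is laid out across those devices. The eight-device
# runtime assumed below is illustrative:
#
#   import jax
#   s = DevicesSharding(jax.devices())   # shape (8,)
#   s2d = s.reshape(4, 2)                # lay the devices out as a 4x2 grid
#   srep = s2d.replicate(axis=1)         # replicate across the second axis
#   srep.shape                           # -> (4, 1)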


class DeviceIdSet:
  _name: str
  _ids: FrozenSet[int]
  def __init__(self, name, *ids):
    self._name = name
    self._ids = frozenset(ids)

  def __iter__(self):
    return iter(sorted(self._ids))

  def __add__(self, other) -> DeviceIdSet:
    assert isinstance(other, DeviceIdSet)
    return DeviceIdSet(self._name, *(self._ids | other._ids))

  def __len__(self) -> int:
    return len(self._ids)

  def __repr__(self) -> str:
    ids = ', '.join(safe_map(str, sorted(self._ids)))
    return f'{{{self._name} {ids}}}'

  def __hash__(self) -> int:
    return hash((self._name, self._ids))

  def __eq__(self, other) -> bool:
    return (isinstance(other, DeviceIdSet) and self._name == other._name and
            self._ids == other._ids)


# TODO(yashkatariya): Remove this when minimum_jaxlib version is 0.3.17
def _hash_op_sharding(op: xc.OpSharding):
  if op.type == xc.OpSharding.Type.TUPLE:
    return hash(tuple(_hash_op_sharding(o) for o in op.tuple_shardings))
  return hash((tuple(op.tile_assignment_devices), tuple(op.tile_assignment_dimensions),
               op.type, op.replicate_on_last_tile_dim, tuple(op.last_tile_dims)))


@pxla.use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
class OpShardingSharding(XLACompatibleSharding):

  @pxla.use_cpp_method
  def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
    self._devices = tuple(devices)
    self._op_sharding = op_sharding

  @pxla.maybe_cached_property
  def _op_sharding_hash(self):
    if xla_extension_version >= 81:
      return hash(xc.HloSharding.from_proto(self._op_sharding))
    else:
      return _hash_op_sharding(self._op_sharding)

  def __eq__(self, other):
    if not isinstance(other, OpShardingSharding):
      return False
    if id(self) == id(other):
      return True
    return (pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding) and
            self._devices == other._devices)

  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash((self._devices, self._op_sharding_hash))
    return self._hash

  def __repr__(self):
    if xla_extension_version >= 96:
      return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
    else:
      if pxla.is_op_sharding_replicated(self._op_sharding):
        return 'OpShardingSharding(REPLICATED)'
      return f'OpShardingSharding({repr(self._op_sharding)})'

  def is_compatible_aval(self, aval_shape: Shape):
    num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
    if len(aval_shape) < len(num_ways_dim_sharded):
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
          f"{len(aval_shape)}")

  @pxla.maybe_cached_property
  def device_set(self) -> Set[Device]:
    return set(self._devices)

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                          len(self._devices))
    return dict(safe_zip(self._devices, indices))

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return list(self._devices)

  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    return self._op_sharding

  @classmethod
  def get_replicated(cls, device_assignment):
    proto = _get_replicated_op_sharding()
    return cls(device_assignment, proto)