rocm_jax/jax/_src/sharding.py

565 lines
20 KiB
Python
Raw Normal View History

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import functools
from collections import Counter
import operator as op
from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)
import jax
from jax._src.util import safe_map, safe_zip
from jax._src.lib import xla_client as xc
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
from jax.interpreters import pxla, mlir
import numpy as np
Shape = Tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]
@pxla.use_cpp_class(xc.Sharding if xc._version >= 94 else None)
class Sharding(metaclass=abc.ABCMeta):
# Abstract methods below that subclasses should implement.
@abc.abstractproperty
def device_set(self) -> Set[Device]:
"""A unique set of devices that this sharding represents.
Devices can be non-addressable too.
"""
raise NotImplementedError('Subclasses should implement this method.')
@abc.abstractmethod
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
raise NotImplementedError('Subclasses should implement this method.')
@abc.abstractmethod
def shard_shape(self, global_shape: Shape) -> Shape:
raise NotImplementedError('Subclasses should implement this method.')
#############################################################################
# Default implementations below that all subclasses will inherit.
@pxla.maybe_cached_property
def addressable_devices(self) -> Set[Device]:
"""A set of addressable devices by the current process"""
return {d for d in self.device_set
if d.process_index == d.client.process_index()}
@pxla.maybe_cached_property
def is_fully_addressable(self) -> bool:
# The pytype disable is because pytype can't recognize a cached property.
return len(self.device_set) == len(self.addressable_devices) # type: ignore
def device_indices(self, device: Device,
global_shape: Shape) -> Optional[Index]:
return self.devices_indices_map(global_shape)[device]
@functools.lru_cache(maxsize=4096)
def addressable_devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
if d.process_index == d.client.process_index()}
@pxla.use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None)
class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
# Abstract methods below that subclasses should implement.
@abc.abstractproperty
def _device_assignment(self) -> XLADeviceAssignment:
raise NotImplementedError('Subclasses should implement this method.')
@abc.abstractmethod
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
raise NotImplementedError('Subclasses should implement this method.')
#############################################################################
# Default implementations below that all subclasses will inherit.
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
op_sharding = self._to_xla_op_sharding(len(global_shape))
op_sharding_sharding = OpShardingSharding(self._device_assignment,
op_sharding)
return op_sharding_sharding.devices_indices_map(global_shape)
@pxla.maybe_cached_property
def _addressable_device_assignment(self) -> XLADeviceAssignment:
return [d for d in self._device_assignment
if d.process_index == d.client.process_index()]
@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
if pxla.is_op_sharding_replicated(op_sharding):
return global_shape
partitions, _ = pxla._get_num_ways_dim_sharded(op_sharding)
assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
out = []
for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
quotient, remainder = divmod(s, p)
if remainder != 0:
raise ValueError(
f"Sharding {self} implies that array axis {dim} is partitioned "
f"{p} times, but the dimension size is {s} "
f"(full shape: {global_shape}, "
f"per-dimension tiling factors: {partitions} should evenly divide "
"the shape)")
out.append(quotient)
return tuple(out)
@functools.lru_cache()
def _check_mesh_resource_axis(mesh, parsed_pspec):
try:
[mesh.shape[r] for p in parsed_pspec if p is not None
for r in p]
except KeyError as e:
raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is "
"undefined.") from None
def _hashed_index(x) -> int:
# This works for both `pjit`/`xmap` indices and `pmap` indices (which might
# have an integer instead of a slice).
assert all(v.step is None for v in x if isinstance(v, slice))
return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))
@functools.lru_cache(maxsize=4096)
def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
try:
device_indices_map_fn = sharding.devices_indices_map
except AttributeError:
raise ValueError(
f'Cannot calculate replica ids from sharding: {sharding}. Please '
'create a device to index mapping for your sharding from which replica '
'ids will be calculated.') from None
index_to_replica: Dict[int, int] = Counter()
out = {}
for device, index in device_indices_map_fn(global_shape).items():
h_index = _hashed_index(index)
replica_id = index_to_replica[h_index]
index_to_replica[h_index] += 1
out[device] = replica_id
return out
@pxla.use_cpp_class(xc.MeshPspecSharding if xc._version >= 95 else None)
class MeshPspecSharding(XLACompatibleSharding):
@pxla.use_cpp_method
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def __init__(
self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
self.mesh = mesh
self.spec = spec
self._parsed_pspec = _parsed_pspec
self._preprocess()
def _preprocess(self):
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
# This split exists because you can pass `_parsed_pspec` that has been
# modified from the original. For example: Adding extra dimension to
# axis_resources for vmap handlers. In such cases you need to preserve the
# `sync` attribute of parsed pspecs.
# PartitionSpec is inferred from the parsed pspec in this case.
# TODO(yaskatariya): Remove this and replace this with a normalized
# representation of Parsed Pspec
if self._parsed_pspec is None:
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
from jax.experimental import pjit
self._parsed_pspec, _, _, _ = pjit._prepare_axis_resources(
self.spec, "MeshPspecSharding spec")
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def __repr__(self):
return f'NamedSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self.mesh, self._parsed_pspec))
return self._hash
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def __eq__(self, other):
if not isinstance(other, MeshPspecSharding):
return False
if id(self) == id(other):
return True
if id(self.mesh) == id(other.mesh) and self._parsed_pspec == other._parsed_pspec:
return True
return self.mesh == other.mesh and self._parsed_pspec == other._parsed_pspec
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def is_compatible_aval(self, aval_shape: Shape):
if len(aval_shape) < len(self._parsed_pspec):
raise ValueError(
f"Sharding {self} is only valid for values of rank at least "
f"{len(self._parsed_pspec)}, but was applied to a value of rank "
f"{len(aval_shape)}")
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
@classmethod
def _from_parsed_pspec(cls, mesh, parsed_pspec):
return cls(mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
@pxla.maybe_cached_property
def device_set(self) -> Set[Device]:
return set(self.mesh.devices.flat)
@pxla.maybe_cached_property
def _device_assignment(self) -> XLADeviceAssignment:
return list(self.mesh.devices.flat)
@functools.lru_cache(maxsize=4096)
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def _to_xla_op_sharding(
self,
num_dimensions: int,
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
) -> xc.OpSharding:
from jax.experimental.pjit import get_array_mapping
array_mapping = get_array_mapping(self._parsed_pspec)
# TODO(yashkatariya): Move away from sharding spec in MeshPspecSharding
# since we don't really need sharding spec.
sharding_spec = pxla.new_mesh_sharding_specs(
self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
# Used in `with_sharding_constraint`.
special_axes = {}
# Manual axes is only used with xmap.
if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext):
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
axis_names = self.mesh.axis_names
# Ignore type because mypy doesn't recognize the `hasattr` check above.
for manual_axis in axis_ctx.manual_axes: # type: ignore
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
return sharding_spec.sharding_proto(special_axes=special_axes)
# New name of MeshPspecSharding to match with PositionalSharding below.
NamedSharding = MeshPspecSharding
@functools.lru_cache()
def _get_replicated_op_sharding():
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
return proto
@pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
class SingleDeviceSharding(XLACompatibleSharding):
@pxla.use_cpp_method
def __init__(self, device: Device):
self._device = device
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def __repr__(self):
return f"SingleDeviceSharding(device={repr(self._device)})"
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
def __hash__(self):
return hash(self._device)
def __eq__(self, other):
if not isinstance(other, SingleDeviceSharding):
return False
if id(self) == id(other):
return True
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
2022-07-07 10:41:27 -07:00
return self._device == other._device
@property
def device_set(self) -> Set[Device]:
return {self._device}
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
return {self._device: (slice(None),) * len(global_shape)}
@property
def _device_assignment(self) -> XLADeviceAssignment:
return [self._device]
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
return _get_replicated_op_sharding()
@pxla.use_cpp_class(xc.PmapSharding if xc._version >= 94 else None)
class PmapSharding(XLACompatibleSharding):
@pxla.use_cpp_method
def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
self.devices = devices
# The sharding spec should be pmap's sharding spec.
self.sharding_spec = sharding_spec
def __eq__(self, other):
if not isinstance(other, PmapSharding):
return False
if id(self) == id(other):
return True
return (self.sharding_spec == other.sharding_spec and
np.array_equal(self.devices, other.devices))
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((tuple(self.devices.flat), self.sharding_spec))
return self._hash
def __str__(self):
device_ids = [d.id for d in self.devices.flat]
return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
f'device_ids={device_ids}, '
f'device_platform={self.devices.flat[0].platform.upper()}, '
f'device_shape={self.devices.shape})')
def __repr__(self):
return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
f'devices={self.devices})')
# TODO(yashkatariya): Expose `sharded_dim_size` in the API if required.
@classmethod
def default(cls, shape: Shape, sharded_dim: int = 0) -> PmapSharding:
"""Creates a `PmapSharding` which matches the implicit device order used by
`pmap`.
Args:
shape: The shape of the input array.
sharded_dim: Dimension the input array is sharded on. Defaults to 0.
"""
# The dtype doesn't matter here. Its only used for creating the
# sharding_spec.
aval = jax.ShapedArray(shape, np.int32)
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
num_ways_sharded = None
for s in sharding_spec.sharding:
if isinstance(s, pxla.Unstacked):
num_ways_sharded = s.size
if num_ways_sharded is None:
raise NotImplementedError(
'`None` to sharded_dim is not supported. Please file a jax '
'issue if you need this feature.')
pmap_devices = jax.local_devices()[:num_ways_sharded]
return cls(pmap_devices, sharding_spec)
@pxla.maybe_cached_property
def device_set(self) -> Set[Device]:
return set(self.devices.flat)
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
@pxla.maybe_cached_property
def _device_assignment(self) -> XLADeviceAssignment:
return list(self.devices.flat)
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")
@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
sharded_dim = None
for i, s in enumerate(self.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):
sharded_dim = i
break
if sharded_dim is None:
return global_shape
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
class PositionalSharding(XLACompatibleSharding):
_devices: List[xc.Device]
_ids: np.ndarray # dtype DeviceIdSet
def __init__(self, devices: Union[Sequence[xc.Device], np.ndarray]):
if not isinstance(devices, np.ndarray):
devices = np.array(devices, dtype='object')
if not devices.size:
raise ValueError(f"{self.__class__.__name__}.__init__ requires at least "
f"one device, got {devices}")
self._devices = list(devices.flat)
name = self._devices[0].platform.upper()
self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
dtype='object').reshape(devices.shape)
shape = property(op.attrgetter('_ids.shape'))
ndim = property(op.attrgetter('_ids.ndim'))
def __repr__(self) -> str:
cls_name = self.__class__.__name__
ids = self._ids.copy()
platform_name = self._devices[0].platform.upper()
for idx, x in np.ndenumerate(ids):
ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))
body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
max_line_width=100)
return f'{cls_name}({body})'
def reshape(self, *shape):
return self.remake(self._devices, self._ids.reshape(*shape))
def transpose(self, *axes):
return self.remake(self._devices, self._ids.transpose(*axes))
T = property(transpose)
def replicate(self, axis=None, keepdims=True):
new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union
return self.remake(self._devices, new_ids)
@classmethod
def remake(
cls, devices: List[xc.Device], ids: np.ndarray) -> PositionalSharding:
self = cls.__new__(cls)
self._devices = devices
self._ids = ids
return self
# Hashable
def __hash__(self) -> int:
return id(self._devices)
def __eq__(self, other) -> bool:
return (isinstance(other, PositionalSharding) and
id(self._devices) == id(other._devices) and
bool(np.all(self._ids == other._ids)))
# Sharding interface
@pxla.maybe_cached_property
def device_set(self) -> set[xc.Device]:
return set(self._devices)
# XLACompatibleSharding interface
@functools.lru_cache(maxsize=4096)
def _to_xla_op_sharding(self, num_dimensions: int, axis_ctx=None):
assert axis_ctx is None
pbuf = xc.OpSharding()
if self.shape == (1,) * self.ndim:
pbuf.type = xc.OpSharding.Type.REPLICATED
return pbuf
shape = self.shape[self.ndim - num_dimensions:] # 'rank promotion' of val
set_size, = {len(device_set) for device_set in self._ids.flat}
pbuf.type = xc.OpSharding.Type.OTHER
if set_size > 1:
pbuf.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
pbuf.tile_assignment_dimensions = (*shape, set_size)
else:
pbuf.tile_assignment_dimensions = shape
pbuf.tile_assignment_devices = [i for ids in self._ids.flat for i in ids]
return pbuf
@property
def _device_assignment(self) -> list[xc.Device]:
return self._devices
class DeviceIdSet:
_name: str
_ids: FrozenSet[int]
def __init__(self, name, *ids):
self._name = name
self._ids = frozenset(ids)
def __iter__(self):
return iter(sorted(self._ids))
def __add__(self, other) -> DeviceIdSet:
assert isinstance(other, DeviceIdSet)
return DeviceIdSet(self._name, *(self._ids | other._ids))
def __len__(self) -> int:
return len(self._ids)
def __repr__(self) -> str:
ids = ', '.join(safe_map(str, sorted(self._ids)))
return f'{{{self._name} {ids}}}'
def __hash__(self) -> int:
return hash((self._name, self._ids))
def __eq__(self, other) -> bool:
return (isinstance(other, DeviceIdSet) and self._name == other._name and
self._ids == other._ids)
@pxla.use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
class OpShardingSharding(XLACompatibleSharding):
@pxla.use_cpp_method
def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
self._devices = tuple(devices)
self._op_sharding = op_sharding
@pxla.maybe_cached_property
def _op_sharding_hash(self):
return hash(xc.HloSharding.from_proto(self._op_sharding))
def __eq__(self, other):
if not isinstance(other, OpShardingSharding):
return False
if id(self) == id(other):
return True
return (pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding) and
self._devices == other._devices)
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._devices, self._op_sharding_hash))
return self._hash
def __repr__(self):
return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
def is_compatible_aval(self, aval_shape: Shape):
num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
if len(aval_shape) < len(num_ways_dim_sharded):
raise ValueError(
f"Sharding {self} is only valid for values of rank at least "
f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
f"{len(aval_shape)}")
@pxla.maybe_cached_property
def device_set(self) -> Set[Device]:
return set(self._devices)
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
len(self._devices))
return dict(safe_zip(self._devices, indices))
@property
def _device_assignment(self) -> XLADeviceAssignment:
return list(self._devices)
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
return self._op_sharding
@classmethod
def get_replicated(cls, device_assignment):
proto = _get_replicated_op_sharding()
return cls(device_assignment, proto)