# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Mapping, Sequence
import functools

from jax._src.util import safe_zip, use_cpp_class, cache
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.op_shardings import (
    are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated,
    op_sharding_to_indices)

Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]


@cache(max_size=4096, trace_context_in_key=False)
def _addressable_devices_indices_map(
    sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]:
  global_map = sharding.devices_indices_map(global_shape)
  if sharding.is_fully_addressable:
    return global_map
  if hasattr(sharding, '_internal_device_list'):
    return {d: global_map[d]
            for d in sharding._internal_device_list.addressable_device_list}
  return {d: ind for d, ind in global_map.items()
          if d.process_index == d.client.process_index()}


@cache(max_size=4096, trace_context_in_key=False)
def common_devices_indices_map(
    s: Sharding, global_shape: Shape) -> Mapping[Device, Index]:
  s.shard_shape(global_shape)  # raises a good error message
  hlo_sharding = s._to_xla_hlo_sharding(len(global_shape))
  indices = op_sharding_to_indices(hlo_sharding, global_shape,
                                   len(s._device_assignment))
  return dict(safe_zip(s._device_assignment, indices))


@cache(max_size=4096, trace_context_in_key=False)
def _common_shard_shape(self, global_shape: Shape) -> Shape:
  hlo_sharding = self._to_xla_hlo_sharding(len(global_shape))
  if is_op_sharding_replicated(hlo_sharding):
    return global_shape
  partitions, _ = get_num_ways_dim_sharded(hlo_sharding)
  assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
  out = []
  for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
    try:
      quotient, remainder = divmod(s, p)
    except TypeError:
      # TODO: Figure out how to partition dynamic shapes
      raise NotImplementedError
    if remainder != 0:
      raise ValueError(
          f"Sharding {self} implies that array axis {dim} is partitioned "
          f"{p} times, but the dimension size is {s} "
          f"(full shape: {global_shape}, "
          f"per-dimension tiling factors: {partitions} should evenly divide "
          "the shape)")
    out.append(quotient)
  return tuple(out)


@use_cpp_class(xc.Sharding)
class Sharding:
  """Describes how a :class:`jax.Array` is laid out across devices.
  """

  # Abstract methods below that subclasses should implement.

  @property
  def device_set(self) -> set[Device]:
    """The set of devices that this :class:`Sharding` spans.

    In multi-controller JAX, the set of devices is global, i.e., includes
    non-addressable devices from other processes.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def is_fully_replicated(self) -> bool:
    """Is this sharding fully replicated?

    A sharding is fully replicated if each device has a complete copy of the
    entire data.
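
    For example, a small sketch (assuming ``mesh`` is a
    :class:`jax.sharding.Mesh` built elsewhere with a nontrivial ``'x'`` axis)::

      from jax.sharding import NamedSharding, PartitionSpec as P

      NamedSharding(mesh, P()).is_fully_replicated     # True: no axis is sharded
      NamedSharding(mesh, P('x')).is_fully_replicated  # False: axis 0 is sharded over 'x'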
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def is_fully_addressable(self) -> bool:
    """Is this sharding fully addressable?

    A sharding is fully addressable if the current process can address all of
    the devices named in the :class:`Sharding`. ``is_fully_addressable`` is
    equivalent to "is_local" in multi-process JAX.
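
    For example, a sketch for single-process JAX, where every device is local::

      import numpy as np
      import jax
      from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

      mesh = Mesh(np.array(jax.local_devices()), ('x',))
      NamedSharding(mesh, P('x')).is_fully_addressable  # True: only local devices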
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def num_devices(self) -> int:
    """Number of devices that the sharding contains."""
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def memory_kind(self) -> str | None:
    """Returns the memory kind of the sharding."""
    raise NotImplementedError('Subclasses should implement this method.')

  def with_memory_kind(self, kind: str) -> Sharding:
    """Returns a new Sharding instance with the specified memory kind."""
    raise NotImplementedError('Subclasses should implement this method')

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    raise NotImplementedError('Subclasses should implement this method.')

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    raise NotImplementedError('Subclasses should implement this method.')

  def _to_sdy_sharding(self, num_dimensions: int):
    # Lowers this sharding to the Shardy (sdy) dialect: OpenXLA's MLIR,
    # named-axis based sharding propagation (and, in the future, SPMD
    # partitioning) system, which is dialect agnostic. See
    # https://github.com/openxla/shardy. For now, Shardy lowering is gated
    # behind the `jax_use_shardy_partitioner` flag.
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @property
  def _is_concrete(self) -> bool:
    return True

  @functools.cached_property
  def addressable_devices(self) -> set[Device]:
    """The set of devices in the :class:`Sharding` that are addressable by the
    current process.
    """
    # Add a fast path for single controller runtimes.
    if xb.process_count() == 1:
      return self.device_set
    return {d for d in self.device_set
            if d.process_index == d.client.process_index()}

  def addressable_devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index | None]:
    """A mapping from addressable devices to the slice of array data each contains.

    ``addressable_devices_indices_map`` contains that part of
    ``device_indices_map`` that applies to the addressable devices.
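
    For example, a sketch (in single-process JAX every device is addressable,
    so the two maps coincide; the 2x1 mesh is hypothetical)::

      import numpy as np
      import jax
      from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

      mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
      s = NamedSharding(mesh, P('x', 'y'))
      s.addressable_devices_indices_map((8, 2)) == s.devices_indices_map((8, 2))
      # True in single-process JAX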
    """
    return _addressable_devices_indices_map(self, global_shape)

  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    """Returns a mapping from devices to the array slices each contains.

    The mapping includes all global devices, i.e., including
    non-addressable devices from other processes.
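
    For illustration, a small sketch (assuming a host with at least two
    devices; the 2x1 mesh is hypothetical)::

      import numpy as np
      import jax
      from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

      mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
      s = NamedSharding(mesh, P('x', 'y'))
      s.devices_indices_map((8, 2))
      # Maps the first device to rows 0:4 and the second device to rows 4:8,
      # each covering all columns of the (8, 2) array.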
    """
    return common_devices_indices_map(self, global_shape)

  @functools.cached_property
  def _addressable_device_assignment(self) -> XLADeviceAssignment:
    if self.is_fully_addressable:
      return self._device_assignment
    if hasattr(self, '_internal_device_list'):
      return tuple(self._internal_device_list.addressable_device_list)
    return tuple(d for d in self._device_assignment
                 if d.process_index == d.client.process_index())

  def shard_shape(self, global_shape: Shape) -> Shape:
    """Returns the shape of the data on each device.

    The shard shape returned by this function is calculated from
    ``global_shape`` and the properties of the sharding.
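
    For example, a small sketch (assuming a host with at least two devices;
    the 2x1 mesh is hypothetical)::

      import numpy as np
      import jax
      from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

      mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
      s = NamedSharding(mesh, P('x', 'y'))
      s.shard_shape((8, 2))  # -> (4, 2): axis 0 split 2 ways, axis 1 left whole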
    """
    return _common_shard_shape(self, global_shape)

  def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool:
    """Returns ``True`` if two shardings are equivalent.

    Two shardings are equivalent if they place the same logical array shards on
    the same devices.

    For example, a :class:`NamedSharding` may be equivalent
    to a :class:`PositionalSharding` if both place the same shards of the array
    on the same devices.
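
    A minimal sketch (assuming ``mesh`` is a :class:`jax.sharding.Mesh` built
    elsewhere): shardings written differently can still be equivalent when they
    lower to the same placement::

      from jax.sharding import NamedSharding, PartitionSpec as P

      s1 = NamedSharding(mesh, P())      # fully replicated
      s2 = NamedSharding(mesh, P(None))  # spelled differently, also replicated
      s1.is_equivalent_to(s2, ndim=2)    # True: same shards on the same devices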
    """
    try:
      return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
                                     other._to_xla_hlo_sharding(ndim))
              and self._internal_device_list == other._internal_device_list and  # type: ignore
              self.memory_kind == other.memory_kind)
    # NotImplementedError is raised by PmapSharding because it can't lower
    # to OpSharding. So if `other` is a PmapSharding, default to a strict
    # equality check.
    except NotImplementedError:
      return self == other