Add _manual_axes support to NamedSharding. This is needed because

custom_partitioning may produce manually sharded axes.

PiperOrigin-RevId: 559288864
This commit is contained in:
Parker Schuh 2023-08-22 19:23:53 -07:00 committed by jax authors
parent 517e0a93ca
commit e58ddb7258
2 changed files with 55 additions and 31 deletions

View File

@@ -54,6 +54,7 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
@@ -1887,14 +1888,17 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
out_aval, = ctx.avals_out
axis_ctx = ctx.module_context.axis_context
# axis_ctx and manual_axes is *only used with xmap* and xmap only works with
# NamedSharding. So convert the GSPMDSharding to NamedSharding
# and then convert it back with the added special axes.
# NamedSharding. So update the NamedSharding to have the manual axes.
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
mesh = resource_env.physical_mesh
parsed_pspec = parse_flatten_op_sharding(sharding._hlo_sharding, mesh)[0]
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
sharding = GSPMDSharding(
mps._device_assignment, mps._to_xla_hlo_sharding(aval.ndim, axis_ctx=axis_ctx))
if xla_extension_version >= 188:
sharding = NamedSharding._from_parsed_pspec(
mesh, parsed_pspec, _manual_axes=axis_ctx.manual_axes)
else:
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
sharding = GSPMDSharding(
mps._device_assignment, mps._to_xla_hlo_sharding(aval.ndim, axis_ctx=axis_ctx))
return [
mlir.wrap_with_sharding_op(ctx,
x_node, out_aval,

View File

@@ -206,16 +206,20 @@ class NamedSharding(XLACompatibleSharding):
spec: PartitionSpec
_memory_kind: str | None
_parsed_pspec: ParsedPartitionSpec
_manual_axes: frozenset[MeshAxisName]
if xla_extension_version < 188:
_manual_axes = frozenset()
@use_cpp_method()
def __init__(
self, mesh: mesh_lib.Mesh, spec: PartitionSpec, *,
memory_kind: str | None = None, _parsed_pspec = None):
memory_kind: str | None = None, _parsed_pspec=None,
_manual_axes=frozenset()):
self.mesh = mesh
self.spec = spec
self._memory_kind = memory_kind
self._parsed_pspec = _parsed_pspec
self._manual_axes = _manual_axes
self._preprocess()
def _preprocess(self):
@@ -240,7 +244,8 @@ class NamedSharding(XLACompatibleSharding):
def __reduce__(self):
return (type(self), (self.mesh, self.spec),
{'memory_kind': self.memory_kind})
{'memory_kind': self.memory_kind,
'_manual_axes': self._manual_axes})
if xla_extension_version >= 178:
@property
@@ -249,7 +254,8 @@ class NamedSharding(XLACompatibleSharding):
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
self._hash = hash(
(self.mesh, self.memory_kind, self._parsed_pspec, self._manual_axes))
return self._hash
def __eq__(self, other):
@@ -257,12 +263,11 @@ class NamedSharding(XLACompatibleSharding):
return False
if id(self) == id(other):
return True
parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
mem_kind_equal = self.memory_kind == other.memory_kind
if (id(self.mesh) == id(other.mesh) and mem_kind_equal and
parsed_pspec_equal):
return True
return self.mesh == other.mesh and mem_kind_equal and parsed_pspec_equal
if (self._parsed_pspec != other._parsed_pspec
or self.memory_kind != other.memory_kind
or self._manual_axes != other._manual_axes):
return False
return id(self.mesh) == id(other.mesh) or self.mesh == other.mesh
def is_compatible_aval(self, aval_shape: Shape):
assert self._parsed_pspec is not None
@@ -275,9 +280,16 @@ class NamedSharding(XLACompatibleSharding):
f"{len(aval_shape)}.{extra_msg}")
@classmethod
def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
return cls(mesh, parsed_pspec.get_partition_spec(),
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
def _from_parsed_pspec(
cls, mesh, parsed_pspec, *, memory_kind=None, _manual_axes=frozenset()
):
if xla_extension_version >= 188:
return cls(mesh, parsed_pspec.get_partition_spec(),
memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
_manual_axes=_manual_axes)
else:
return cls(mesh, parsed_pspec.get_partition_spec(),
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
@property
def device_set(self) -> set[Device]:
@@ -313,7 +325,7 @@ class NamedSharding(XLACompatibleSharding):
def with_memory_kind(self, kind: str) -> NamedSharding:
return NamedSharding(self.mesh, self.spec, memory_kind=kind)
def _get_sharding_spec(self, num_dimensions, axis_ctx):
def _get_sharding_spec(self, num_dimensions, manual_axes):
assert self._parsed_pspec is not None
array_mapping = get_array_mapping(self._parsed_pspec)
# TODO(yashkatariya): Move away from sharding spec in NamedSharding
@@ -322,22 +334,30 @@ class NamedSharding(XLACompatibleSharding):
self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
# Used in `with_sharding_constraint`.
special_axes = {}
# Manual axes is only used with xmap.
if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext):
if manual_axes:
axis_names = self.mesh.axis_names
# Ignore type because mypy doesn't recognize the `hasattr` check above.
for manual_axis in axis_ctx.manual_axes: # type: ignore
for manual_axis in manual_axes:
special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
return sharding_spec, special_axes
@functools.lru_cache(maxsize=4096)
def _to_xla_hlo_sharding(
self, num_dimensions: int,
axis_ctx: SPMDAxisContext | ShardingContext | None = None
) -> xc.HloSharding:
sharding_spec, special_axes = self._get_sharding_spec(
num_dimensions, axis_ctx)
return sharding_spec.sharding_proto(special_axes=special_axes)
if xla_extension_version >= 188:
@functools.lru_cache(maxsize=4096)
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
sharding_spec, special_axes = self._get_sharding_spec(
num_dimensions, self._manual_axes)
return sharding_spec.sharding_proto(special_axes=special_axes)
else:
@functools.lru_cache(maxsize=4096)
def _to_xla_hlo_sharding(
self, num_dimensions: int,
axis_ctx: SPMDAxisContext | ShardingContext | None = None
) -> xc.HloSharding:
manual_axes = None
if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext):
manual_axes = axis_ctx.manual_axes # type: ignore
sharding_spec, special_axes = self._get_sharding_spec(
num_dimensions, manual_axes)
return sharding_spec.sharding_proto(special_axes=special_axes)
@functools.lru_cache