Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 21:06:06 +00:00
Add _manual_axes support to NamedSharding. This is needed because
custom_partitioning may produce manually sharded axes.

PiperOrigin-RevId: 559288864
parent 517e0a93ca
commit e58ddb7258
jax/_src/pjit.py
@@ -54,6 +54,7 @@ from jax._src.interpreters import pxla
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.sharding_impls import (
     NamedSharding, XLACompatibleSharding, GSPMDSharding,
     XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
@@ -1887,14 +1888,17 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
   out_aval, = ctx.avals_out
   axis_ctx = ctx.module_context.axis_context
   # axis_ctx and manual_axes is *only used with xmap* and xmap only works with
-  # NamedSharding. So convert the GSPMDSharding to NamedSharding
-  # and then convert it back with the added special axes.
+  # NamedSharding. So update the NamedSharding to have the manual axes.
   if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
     mesh = resource_env.physical_mesh
     parsed_pspec = parse_flatten_op_sharding(sharding._hlo_sharding, mesh)[0]
-    mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
-    sharding = GSPMDSharding(
-        mps._device_assignment, mps._to_xla_hlo_sharding(aval.ndim, axis_ctx=axis_ctx))
+    if xla_extension_version >= 188:
+      sharding = NamedSharding._from_parsed_pspec(
+          mesh, parsed_pspec, _manual_axes=axis_ctx.manual_axes)
+    else:
+      mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
+      sharding = GSPMDSharding(
+          mps._device_assignment, mps._to_xla_hlo_sharding(aval.ndim, axis_ctx=axis_ctx))
   return [
       mlir.wrap_with_sharding_op(ctx,
           x_node, out_aval,
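The `xla_extension_version >= 188` branch above is JAX's usual pattern for staying compatible with an older installed jaxlib. Below is a minimal, runnable sketch of the same gated construction at user level; the 1x1 mesh and axis names are made up for illustration, and `_manual_axes` is a private parameter:

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec
    from jax._src.lib import xla_extension_version  # private, used by JAX internals

    # A tiny 1x1 mesh so the sketch runs on a single device.
    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
    spec = PartitionSpec('x')

    # Mirror the gated construction in the lowering rule: only new enough
    # jaxlibs understand the _manual_axes keyword.
    if xla_extension_version >= 188:
        s = NamedSharding(mesh, spec, _manual_axes=frozenset({'y'}))
    else:
        s = NamedSharding(mesh, spec)  # older jaxlib: no _manual_axes parameter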
jax/_src/sharding_impls.py
@@ -206,16 +206,20 @@ class NamedSharding(XLACompatibleSharding):
   spec: PartitionSpec
   _memory_kind: str | None
   _parsed_pspec: ParsedPartitionSpec
+  _manual_axes: frozenset[MeshAxisName]
+  if xla_extension_version < 188:
+    _manual_axes = frozenset()
 
   @use_cpp_method()
   def __init__(
       self, mesh: mesh_lib.Mesh, spec: PartitionSpec, *,
-      memory_kind: str | None = None, _parsed_pspec = None):
-
+      memory_kind: str | None = None, _parsed_pspec=None,
+      _manual_axes=frozenset()):
     self.mesh = mesh
     self.spec = spec
     self._memory_kind = memory_kind
     self._parsed_pspec = _parsed_pspec
+    self._manual_axes = _manual_axes
     self._preprocess()
 
   def _preprocess(self):
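A minimal sketch of the widened constructor, assuming a jaxlib with `xla_extension_version >= 188`; note `_manual_axes` is underscore-prefixed, i.e. not public API:

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))

    # Default: no manually sharded axes.
    s1 = NamedSharding(mesh, PartitionSpec('x'))
    assert s1._manual_axes == frozenset()

    # Explicit: mark mesh axis 'y' as manual, as custom_partitioning may require.
    s2 = NamedSharding(mesh, PartitionSpec('x'), _manual_axes=frozenset({'y'}))
    assert s2._manual_axes == frozenset({'y'})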
@@ -240,7 +244,8 @@ class NamedSharding(XLACompatibleSharding):
 
   def __reduce__(self):
     return (type(self), (self.mesh, self.spec),
-            {'memory_kind': self.memory_kind})
+            {'memory_kind': self.memory_kind,
+             '_manual_axes': self._manual_axes})
 
   if xla_extension_version >= 178:
     @property
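Because `__reduce__` now carries `_manual_axes` in the state dict, pickling should round-trip it. A sketch under the same version assumption:

    import pickle
    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
    s = NamedSharding(mesh, PartitionSpec('x'), _manual_axes=frozenset({'y'}))

    restored = pickle.loads(pickle.dumps(s))
    assert restored._manual_axes == frozenset({'y'})
    assert restored == s  # equality also compares _manual_axes (see the __eq__ hunk below)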
@@ -249,7 +254,8 @@ class NamedSharding(XLACompatibleSharding):
 
   def __hash__(self):
     if not hasattr(self, '_hash'):
-      self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
+      self._hash = hash(
+          (self.mesh, self.memory_kind, self._parsed_pspec, self._manual_axes))
     return self._hash
 
   def __eq__(self, other):
@@ -257,12 +263,11 @@ class NamedSharding(XLACompatibleSharding):
       return False
     if id(self) == id(other):
       return True
-    parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
-    mem_kind_equal = self.memory_kind == other.memory_kind
-    if (id(self.mesh) == id(other.mesh) and mem_kind_equal and
-        parsed_pspec_equal):
-      return True
-    return self.mesh == other.mesh and mem_kind_equal and parsed_pspec_equal
+    if (self._parsed_pspec != other._parsed_pspec
+        or self.memory_kind != other.memory_kind
+        or self._manual_axes != other._manual_axes):
+      return False
+    return id(self.mesh) == id(other.mesh) or self.mesh == other.mesh
 
   def is_compatible_aval(self, aval_shape: Shape):
     assert self._parsed_pspec is not None
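The rewritten `__eq__` checks the cheap field comparisons first and returns early, and it now treats `_manual_axes` as part of the sharding's identity; `__hash__` was updated in lockstep so equal shardings keep equal hashes. A sketch, assuming the Python definitions above are in effect:

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))

    plain = NamedSharding(mesh, PartitionSpec('x'))
    manual = NamedSharding(mesh, PartitionSpec('x'), _manual_axes=frozenset({'y'}))
    also_plain = NamedSharding(mesh, PartitionSpec('x'))

    assert plain != manual                  # differ only in _manual_axes
    assert plain == also_plain
    assert hash(plain) == hash(also_plain)  # hash/eq contract preserved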
@@ -275,9 +280,16 @@ class NamedSharding(XLACompatibleSharding):
           f"{len(aval_shape)}.{extra_msg}")
 
   @classmethod
-  def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
-    return cls(mesh, parsed_pspec.get_partition_spec(),
-               memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
+  def _from_parsed_pspec(
+      cls, mesh, parsed_pspec, *, memory_kind=None, _manual_axes=frozenset()
+  ):
+    if xla_extension_version >= 188:
+      return cls(mesh, parsed_pspec.get_partition_spec(),
+                 memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
+                 _manual_axes=_manual_axes)
+    else:
+      return cls(mesh, parsed_pspec.get_partition_spec(),
+                 memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
 
   @property
   def device_set(self) -> set[Device]:
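One small design note on `_from_parsed_pspec` (and `__init__`): the default is `frozenset()` rather than `set()`. Sharing a default across calls is only safe because `frozenset` is immutable, and it has to be hashable anyway since `_manual_axes` feeds `__hash__` and the `lru_cache` further down. A two-line illustration:

    hash(frozenset())  # fine: immutable, hashable, safe as a shared default
    # hash(set())      # TypeError: unhashable type: 'set'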
@@ -313,7 +325,7 @@ class NamedSharding(XLACompatibleSharding):
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
-  def _get_sharding_spec(self, num_dimensions, axis_ctx):
+  def _get_sharding_spec(self, num_dimensions, manual_axes):
     assert self._parsed_pspec is not None
     array_mapping = get_array_mapping(self._parsed_pspec)
     # TODO(yashkatariya): Move away from sharding spec in NamedSharding
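`_get_sharding_spec` now takes the manual axis names directly instead of an axis context, so both `_to_xla_hlo_sharding` variants below can share it. A sketch of the `special_axes` half of its return value (private API; assumes two devices, e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=2` on CPU):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec
    from jax._src.lib import xla_client as xc

    mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
    s = NamedSharding(mesh, PartitionSpec('x'), _manual_axes=frozenset({'y'}))

    _, special_axes = s._get_sharding_spec(2, frozenset({'y'}))
    assert special_axes == {1: xc.OpSharding.Type.MANUAL}  # mesh axis 1 is 'y'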
@@ -322,22 +334,30 @@ class NamedSharding(XLACompatibleSharding):
         self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
     # Used in `with_sharding_constraint`.
     special_axes = {}
-    # Manual axes is only used with xmap.
-    if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext):
+    if manual_axes:
       axis_names = self.mesh.axis_names
-      # Ignore type because mypy doesn't recognize the `hasattr` check above.
-      for manual_axis in axis_ctx.manual_axes:  # type: ignore
+      for manual_axis in manual_axes:
         special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
     return sharding_spec, special_axes
 
-  @functools.lru_cache(maxsize=4096)
-  def _to_xla_hlo_sharding(
-      self, num_dimensions: int,
-      axis_ctx: SPMDAxisContext | ShardingContext | None = None
-  ) -> xc.HloSharding:
-    sharding_spec, special_axes = self._get_sharding_spec(
-        num_dimensions, axis_ctx)
-    return sharding_spec.sharding_proto(special_axes=special_axes)
+  if xla_extension_version >= 188:
+    @functools.lru_cache(maxsize=4096)
+    def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
+      sharding_spec, special_axes = self._get_sharding_spec(
+          num_dimensions, self._manual_axes)
+      return sharding_spec.sharding_proto(special_axes=special_axes)
+  else:
+    @functools.lru_cache(maxsize=4096)
+    def _to_xla_hlo_sharding(
+        self, num_dimensions: int,
+        axis_ctx: SPMDAxisContext | ShardingContext | None = None
+    ) -> xc.HloSharding:
+      manual_axes = None
+      if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext):
+        manual_axes = axis_ctx.manual_axes  # type: ignore
+      sharding_spec, special_axes = self._get_sharding_spec(
+          num_dimensions, manual_axes)
+      return sharding_spec.sharding_proto(special_axes=special_axes)
 
 
 @functools.lru_cache
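Putting it together: on a new enough jaxlib, `_to_xla_hlo_sharding` no longer needs an `axis_ctx` because the manual axes travel with the sharding itself, which is what lets `custom_partitioning` produce manually sharded axes. A hedged sketch (private methods; same two-device assumption as above):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
    s = NamedSharding(mesh, PartitionSpec('x'), _manual_axes=frozenset({'y'}))

    # New signature: manual axes come from self._manual_axes, not an axis_ctx.
    hlo = s._to_xla_hlo_sharding(2)
    print(hlo)  # an xc.HloSharding whose proto marks mesh axis 'y' as MANUAL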