Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
[Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.

Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
parent 75e22f2ccd
commit 824ccd7183
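Before the diff, a minimal sketch (not part of the change) of what "inlining meshes" means here. It mirrors the updated ShardyShardingTest hunk further down: `NamedSharding._to_sdy_sharding` now returns an `SdyArraySharding` carrying the mesh's `(name, size)` tuple, and `build()` emits an inlined `mesh<...>` attribute instead of referencing a global `@mesh` symbol; the `sdy-lift-inlined-meshes` pass added in `lower_jaxpr_to_module` later lifts these. Internal APIs; assumes at least 8 devices and a jaxlib with the Shardy (`sdy`) dialect.

```python
# Illustrative sketch only; mirrors the ShardyShardingTest hunk in this diff.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.lib.mlir import dialects, ir

devices = np.array(jax.devices()[:8]).reshape(2, 2, 2)  # assumes >= 8 devices
mesh = Mesh(devices, ('sequence', 'data', 'model'))
s = NamedSharding(mesh, P(('sequence', 'data'), 'model'))
sdy_sharding = s._to_sdy_sharding(3)  # now carries mesh.shape_tuple

with ir.Context() as ctx:
  dialects.sdy.register_dialect(ctx)
  print(sdy_sharding.build())
  # Before this change: #sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>
  # After:              #sdy.sharding<mesh<["sequence"=2, "data"=2, "model"=2]>,
  #                       [{"sequence", "data"}, {"model"}, {}]>
```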
@@ -493,9 +493,10 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
     if devices is None:
       raise AssertionError(
           'Please file a bug at https://github.com/jax-ml/jax/issues')
-    if axis_context.mesh_shape is not None:
-      ma, ms = list(zip(*axis_context.mesh_shape))
-      mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
+    am = axis_context.abstract_mesh
+    if am is not None:
+      mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes),
+                           am.axis_names)
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = axis_context.mesh._flat_devices_tuple
   else:
@@ -41,7 +41,6 @@ from jax._src import effects as effects_lib
 from jax._src import linear_util as lu
 from jax._src import path
 from jax._src import pickle_util
-from jax._src import sharding
 from jax._src import sharding_impls
 from jax._src import source_info_util
 from jax._src import util
@@ -50,12 +49,11 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
+from jax._src.sharding_impls import AUTO
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
-from jax._src.lib.mlir import dialects
-from jax._src.lib.mlir import ir
-from jax._src.lib.mlir.dialects import func as func_dialect
-from jax._src.lib.mlir.dialects import hlo
+from jax._src.lib.mlir import dialects, ir, passmanager
+from jax._src.lib.mlir.dialects import func as func_dialect, hlo
 from jax._src.lib.mlir import register_jax_dialects
 from jax._src.state.types import AbstractRef
 
@@ -900,10 +898,12 @@ def unflatten_ir_values_like_types(xs: Iterable[ir.Value],
 _module_name_regex = re.compile(r"[^\w.-]")
 
 def sharded_aval(aval: core.AbstractValue,
-                 sharding: JSharding | None) -> core.AbstractValue:
+                 sharding: JSharding | AUTO | None) -> core.AbstractValue:
   """Returns the new aval sharded based on sharding proto."""
   if sharding is None:
     return aval
+  if isinstance(sharding, AUTO):
+    return aval
   if isinstance(aval, core.AbstractToken):
     return aval
   if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
@@ -991,10 +991,14 @@ def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):
 
 def _to_physical_op_sharding(
     ctx: ModuleContext,
-    aval: core.AbstractValue, sharding: JSharding | None,
-) -> xc.OpSharding | sharding.SdyArraySharding | None:
+    aval: core.AbstractValue, sharding: JSharding | AUTO | None,
+) -> xc.OpSharding | sharding_impls.SdyArraySharding | None:
   if sharding is None:
     return None
+  if isinstance(sharding, AUTO):
+    if config.use_shardy_partitioner.value:
+      return sharding._to_sdy_sharding(aval.ndim)  # type: ignore
+    return None
   assert isinstance(sharding, JSharding)
   if isinstance(aval, AbstractRef):
     return _to_physical_op_sharding(ctx, aval.inner_aval, sharding)
@@ -1022,9 +1026,11 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
   return str(layout._to_xla_layout(aval.dtype))  # type: ignore
 
 
-def _get_mem_kind(s: JSharding | None) -> str | None:
+def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
   if s is None:
     return None
+  if isinstance(s, AUTO):
+    return None
   assert isinstance(s, JSharding)
   return s.memory_kind
 
@@ -1040,8 +1046,8 @@ def lower_jaxpr_to_module(
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[JSharding | None] | None = None,
-    result_shardings: Sequence[JSharding | None] | None = None,
+    arg_shardings: Sequence[JSharding | AUTO | None] | None = None,
+    result_shardings: Sequence[JSharding | AUTO | None] | None = None,
     in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
     arg_names: Sequence[str | None] | None = None,
@@ -1084,8 +1090,9 @@ def lower_jaxpr_to_module(
         "In multi-platform lowering either all or no lowering platforms "
         f"should support donation. Lowering for {platforms} of which "
         f"only {platforms_with_donation} support donation")
-  if num_partitions > 1 and (
-      result_shardings is None or all(s is None for s in result_shardings)):
+  if (num_partitions > 1 and
+      (result_shardings is None or
+       all(s is None or isinstance(s, AUTO) for s in result_shardings))):
     xla_donated_args = donated_args
     donated_args = [False] * len(donated_args)
   if xla_donated_args is None:
@@ -1135,16 +1142,6 @@ def lower_jaxpr_to_module(
   # Remove module name characters that XLA would alter. This ensures that
   # XLA computation preserves the module name.
   attrs = ctx.module.operation.attributes
-  if config.use_shardy_partitioner.value:
-    if (isinstance(axis_context, sharding_impls.ShardingContext) and
-        axis_context.mesh_shape is not None):
-      sdy_mesh_attr = dialects.sdy.MeshAttr.get(
-          [dialects.sdy.MeshAxisAttr.get(name, size)
-           for name, size in axis_context.mesh_shape])
-    else:
-      sdy_mesh_attr = dialects.sdy.MeshAttr.get([])
-
-    ctx.module.body.append(dialects.sdy.MeshOp("mesh", sdy_mesh_attr))
   module_name = _module_name_regex.sub("_", module_name)
   attrs["sym_name"] = ir.StringAttr.get(module_name)
   attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
@@ -1165,6 +1162,10 @@ def lower_jaxpr_to_module(
         arg_layouts=in_layouts,
         result_layouts=out_layouts,
         propagated_out_mem_kinds=propagated_out_mem_kinds)
+    if config.use_shardy_partitioner.value:
+      pipeline = passmanager.PassManager.parse(
+          'builtin.module(sdy-lift-inlined-meshes)')
+      pipeline.run(ctx.module.operation)
 
   try:
     if not ctx.module.operation.verify():
@@ -1314,8 +1315,8 @@ def lower_jaxpr_to_fun(
     *,
     public: bool = False,
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[JSharding | None] | None = None,
-    result_shardings: Sequence[JSharding | None] | None = None,
+    arg_shardings: Sequence[JSharding | AUTO | None] | None = None,
+    result_shardings: Sequence[JSharding | AUTO | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,
@@ -1680,10 +1681,12 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
   # The below custom call achieves the sharding like above example.
   if config.use_shardy_partitioner.value:
     physical_ndim = core.physical_aval(aval).ndim
-    s = sharding.SdyArraySharding(
-        mesh_name='mesh',
-        dimension_shardings=[sharding.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
-                             for i in range(physical_ndim)])
+    s = sharding_impls.SdyArraySharding(
+        mesh_shape=None,
+        dimension_shardings=[
+            sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
+            for i in range(physical_ndim)
+        ])
     return wrap_with_sharding_op(ctx, val, aval, s)
   else:
     return wrap_with_sharding_op(
@@ -2410,7 +2413,7 @@ def _wrap_with_spmd_op(name: str,
                        ctx: LoweringRuleContext,
                        x: ir.Value,
                        aval_out: core.AbstractValue,
-                       sharding: xc.OpSharding | sharding.SdyArraySharding,
+                       sharding: xc.OpSharding | sharding_impls.SdyArraySharding,
                        unspecified_dims: set[int] | None = None,
                        has_side_effect: bool = False,
                        allow_shardy_lowering: bool = False):
@@ -2447,7 +2450,7 @@ wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
 wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
 
 
-def set_sharding(op, sharding: xc.OpSharding | sharding.SdyArraySharding):
+def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
   if config.use_shardy_partitioner.value:
     op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
   else:
@@ -2455,7 +2458,7 @@ def set_sharding(op, sharding: xc.OpSharding | sharding.SdyArraySharding):
 
 
 def get_sharding_attr(
-    sharding: xc.OpSharding | sharding.SdyArraySharding
+    sharding: xc.OpSharding | sharding_impls.SdyArraySharding
 ) -> ir.Attribute:
   if config.use_shardy_partitioner.value:
     return sharding.build()  # type: ignore
@@ -1892,7 +1892,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            mesh_shape_tuple: tuple[tuple[str, int], ...] | None):
+                            abstract_mesh: mesh_lib.AbstractMesh | None):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1914,8 +1914,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   nreps = dispatch.jaxpr_replicas(jaxpr)
   _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)
 
-  in_mlir_shardings: list[JSharding | None] | None
-  out_mlir_shardings: list[JSharding | None] | None
+  in_mlir_shardings: list[JSharding | AUTO | None] | None
+  out_mlir_shardings: list[JSharding | AUTO | None] | None
   axis_ctx: mlir.AxisContext
 
   if nreps == 1:
@@ -1923,7 +1923,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
     axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
-                                              mesh_shape_tuple)
+                                              abstract_mesh)
     num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.
@@ -2216,18 +2216,18 @@ def lower_sharding_computation(
   # 2. Build up the HLO
   prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
 
-  mesh_shape_tuple = None
-  if config.use_shardy_partitioner.value or prim_requires_devices:
+  abstract_mesh = None
+  if prim_requires_devices:
     for sharding in it.chain(unique_in_shardings, unique_out_shardings,
                              [js for js, _ in unique_intermediate_shardings]):
-      if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
-        if (mesh_shape_tuple is not None and
-            mesh_shape_tuple != sharding.mesh.shape_tuple):
+      if isinstance(sharding, sharding_impls.NamedSharding):
+        if (abstract_mesh is not None and
+            abstract_mesh != sharding.mesh.abstract_mesh):
           raise ValueError(
               "mesh should be the same across the entire program. Got mesh"
-              f" shape for one sharding {mesh_shape_tuple} and"
-              f" {sharding.mesh.shape_tuple} for another")
-        mesh_shape_tuple = sharding.mesh.shape_tuple
+              f" shape for one sharding {abstract_mesh} and"
+              f" {sharding.mesh.abstract_mesh} for another")
+        abstract_mesh = sharding.mesh.abstract_mesh  # type: ignore
 
   semantic_in_shardings = SemanticallyEqualShardings(
       in_shardings, global_in_avals)  # type: ignore
@@ -2242,7 +2242,7 @@ def lower_sharding_computation(
        name_stack, all_default_mem_kind, inout_aliases,
        propagated_out_mem_kinds, platforms,
        lowering_parameters=lowering_parameters,
-       mesh_shape_tuple=mesh_shape_tuple)
+       abstract_mesh=abstract_mesh)
 
   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2285,9 +2285,11 @@ def lower_sharding_computation(
 
 def _to_logical_sharding(
     aval: core.AbstractValue, sharding: MaybeSharding | AUTO
-) -> JSharding | None:
-  if is_unspecified(sharding) or is_auto(sharding):
+) -> JSharding | AUTO | None:
+  if isinstance(sharding, UnspecifiedValue):
     return None
+  if isinstance(sharding, AUTO):
+    return sharding
   elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)):
     assert isinstance(sharding, JSharding)
     return sharding
@@ -247,6 +247,10 @@ class Mesh(contextlib.ContextDecorator):
         (name, size)
         for name, size in util.safe_zip(self.axis_names, self.devices.shape))
 
+  @property
+  def axis_sizes(self) -> tuple[int, ...]:
+    return self.devices.shape
+
   @property
   def size(self):
     return math.prod(self.shape.values()) if self.devices.ndim else 0
@@ -361,6 +365,10 @@ class AbstractMesh:
   def axis_names(self):
     return self._axis_names
 
+  @property
+  def axis_sizes(self) -> tuple[int, ...]:
+    return self._axis_sizes
+
   @functools.cached_property
   def size(self):
     return math.prod(self._axis_sizes) if self._axis_sizes else 0
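A quick illustration of the two `axis_sizes` properties added above (a sketch against internal APIs, not part of the diff): both `Mesh` and `AbstractMesh` now report per-axis sizes, which is what the `_custom_partitioning_lowering_rule` hunk at the top uses to rebuild a concrete `Mesh` from a flat device assignment.

```python
# Sketch, assuming at least one device; internal APIs may change.
import numpy as np
import jax
from jax._src import mesh as mesh_lib

devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = mesh_lib.Mesh(devices, ('x', 'y'))
am = mesh.abstract_mesh  # AbstractMesh with the same axis names and sizes

assert mesh.axis_sizes == am.axis_sizes == (1, 1)
# Rebuild a concrete Mesh from a flat device list, as the custom_partitioning
# lowering now does with axis_context.abstract_mesh:
rebuilt = mesh_lib.Mesh(
    np.array(devices.flat).reshape(am.axis_sizes), am.axis_names)
assert rebuilt.shape_tuple == mesh.shape_tuple
```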
@@ -15,13 +15,11 @@
 from __future__ import annotations
 
 from collections.abc import Mapping, Sequence
-import dataclasses
 import functools
 
 from jax._src.util import safe_zip, use_cpp_class, cache
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax._src.lib.mlir.dialects import sdy
 from jax._src.op_shardings import (
     are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated,
     op_sharding_to_indices)
@@ -78,38 +76,6 @@ def _common_shard_shape(self, global_shape: Shape) -> Shape:
   return tuple(out)
 
 
-@dataclasses.dataclass
-class SdyDimSharding:
-  axes: Sequence[str]
-  is_closed: bool
-  priority: int | None = None
-
-  def build(self) -> sdy.DimensionShardingAttr:
-    """Builds the attribute.
-
-    NOTE: An MLIR context is required as a context manager.
-    """
-    return sdy.DimensionShardingAttr.get(
-        [sdy.AxisRefAttr.get(axis) for axis in self.axes],
-        is_closed=self.is_closed,
-        priority=self.priority)
-
-
-@dataclasses.dataclass
-class SdyArraySharding:
-  mesh_name: str
-  dimension_shardings: Sequence[SdyDimSharding]
-
-  def build(self) -> sdy.TensorShardingAttr:
-    """Builds the attribute.
-
-    NOTE: An MLIR context is required as a context manager.
-    """
-    return sdy.TensorShardingAttr.get(
-        self.mesh_name,
-        [dim_sharding.build() for dim_sharding in self.dimension_shardings])
-
-
 @use_cpp_class(xc.Sharding)
 class Sharding:
   """Describes how a :class:`jax.Array` is laid out across devices.
@@ -165,7 +131,7 @@ class Sharding:
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     raise NotImplementedError('Subclasses should implement this method.')
 
-  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
+  def _to_sdy_sharding(self, num_dimensions: int):
     raise NotImplementedError('Subclasses should implement this method.')
 
 #############################################################################
@@ -32,6 +32,7 @@ from jax._src import util
 from jax._src import xla_bridge
 from jax._src import mesh_utils
 from jax._src.lib import xla_client as xc
+from jax._src.lib.mlir.dialects import sdy
 from jax._src.op_shardings import (
     are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
 from jax._src.partition_spec import PartitionSpec
@@ -93,6 +94,37 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
   return out
 
 
+@dataclasses.dataclass
+class SdyDimSharding:
+  axes: Sequence[str]
+  is_closed: bool
+  priority: int | None = None
+
+  # NOTE: An MLIR context is required as a context manager.
+  def build(self) -> sdy.DimensionShardingAttr:
+    return sdy.DimensionShardingAttr.get(
+        [sdy.AxisRefAttr.get(axis) for axis in self.axes],
+        is_closed=self.is_closed,
+        priority=self.priority)
+
+
+@dataclasses.dataclass
+class SdyArraySharding:
+  mesh_shape: tuple[tuple[str, int], ...] | None
+  dimension_shardings: Sequence[SdyDimSharding]
+
+  # NOTE: An MLIR context is required as a context manager.
+  def build(self) -> sdy.TensorShardingAttr:
+    if self.mesh_shape is None:
+      mesh_attr = sdy.MeshAttr.get([])
+    else:
+      mesh_attr = sdy.MeshAttr.get([sdy.MeshAxisAttr.get(name, size)
+                                    for name, size in self.mesh_shape])
+    return sdy.TensorShardingAttr.get(
+        mesh_attr,
+        [dim_sharding.build() for dim_sharding in self.dimension_shardings])
+
+
 @util.cache(max_size=4096, trace_context_in_key=False)
 def named_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
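The relocated dataclasses above are the crux of the change: `SdyArraySharding` now records the mesh shape (or `None`) and `build()` constructs an inlined `sdy.MeshAttr` rather than naming a module-level mesh. Below is a hedged sketch (not from the diff) of the `mesh_shape=None` case that `replicate_trailing_dims` in the mlir.py hunk now emits; judging by the updated pjit tests, the `sdy-lift-inlined-meshes` pass later lifts this empty inlined mesh to a symbol such as `@empty_mesh`.

```python
# Sketch only; internal APIs and a Shardy-enabled jaxlib assumed.
from jax._src.lib.mlir import dialects, ir
from jax._src.sharding_impls import SdyArraySharding, SdyDimSharding

s = SdyArraySharding(
    mesh_shape=None,  # no mesh axes: build() uses an empty inlined mesh
    dimension_shardings=[SdyDimSharding(axes=[], is_closed=False),
                         SdyDimSharding(axes=[], is_closed=True)])

with ir.Context() as ctx:
  dialects.sdy.register_dialect(ctx)
  print(s.build())  # expected along the lines of #sdy.sharding<mesh<[]>, [{?}, {}]>
```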
@@ -325,8 +357,8 @@ class NamedSharding(sharding.Sharding):
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
 
-  def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
-    dim_shardings = [sharding.SdyDimSharding(axes=[], is_closed=True)
+  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
+    dim_shardings = [SdyDimSharding(axes=[], is_closed=True)
                      for _ in range(num_dimensions)]
     for i, dim_spec in enumerate(self._parsed_pspec):
       if dim_spec is None:
@@ -336,7 +368,7 @@ class NamedSharding(sharding.Sharding):
         pass
       else:
         dim_shardings[i].axes = dim_spec
-    return sharding.SdyArraySharding('mesh', dim_shardings)
+    return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
 
 
 @util.cache(max_size=128, trace_context_in_key=False)
@@ -410,11 +442,10 @@ class SingleDeviceSharding(sharding.Sharding):
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return get_replicated_hlo_sharding()
 
-  def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
-    return sharding.SdyArraySharding(
-        'mesh',
-        [sharding.SdyDimSharding(axes=[], is_closed=True)
-         for _ in range(num_dimensions)])
+  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
+    sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
+                        for _ in range(num_dimensions)]
+    return SdyArraySharding(None, sdy_dim_sharding)
 
   @property
   def is_fully_replicated(self) -> bool:
@@ -552,7 +583,7 @@ class PmapSharding(sharding.Sharding):
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     raise NotImplementedError("pmap doesn't use OpSharding.")
 
-  def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
+  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     raise NotImplementedError("pmap doesn't use SdyArraySharding.")
 
   @functools.cached_property
@@ -758,7 +789,7 @@ class PositionalSharding(sharding.Sharding):
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions)
 
-  def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
+  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     raise NotImplementedError(
         "PositionalSharding can't be converted to an SdyArraySharding.")
 
@@ -875,7 +906,7 @@ class GSPMDSharding(sharding.Sharding):
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return self._hlo_sharding
 
-  def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
+  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     raise NotImplementedError(
         "GSPMDSharding can't be converted to SdyArraySharding.")
 
@@ -898,6 +929,11 @@ class AUTO:
   def __init__(self, mesh: mesh_lib.Mesh):
     self.mesh = mesh
 
+  def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding:
+    dim_shardings = [SdyDimSharding(axes=[], is_closed=False)
+                     for _ in range(ndim)]
+    return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
+
 
 def is_auto(x):
   return isinstance(x, AUTO)
@@ -1145,7 +1181,7 @@ class ShardingContext:
   """
   num_devices: int
   device_assignment: tuple[xc.Device, ...] | None = None
-  mesh_shape: tuple[tuple[str, int], ...] | None = None
+  abstract_mesh: mesh_lib.AbstractMesh | None = None
 
   def __post_init__(self):
     if self.device_assignment is not None:
@@ -39,6 +39,7 @@ PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) {
   REGISTER_DIALECT(nvvm);
   REGISTER_DIALECT(llvm);
   mlirRegisterTransformsPasses();
   // For Shardy
+  mlirRegisterAllSdyPassesAndPipelines();
   // Transforms used by JAX.
   mlirRegisterTransformsStripDebugInfo();
@@ -32,11 +32,11 @@ from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import dialects, ir
 from jax._src.util import safe_zip
-from jax._src.sharding import common_devices_indices_map, SdyDimSharding, SdyArraySharding
-from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
-                                     pmap_sharding_devices_indices_map,
-                                     NamedSharding, GSPMDSharding,
-                                     PositionalSharding)
+from jax._src.sharding import common_devices_indices_map
+from jax._src.sharding_impls import (
+    _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map,
+    NamedSharding, GSPMDSharding, PositionalSharding, SdyDimSharding,
+    SdyArraySharding)
 from jax.experimental.pjit import pjit
 from jax.experimental import multihost_utils
 from jax.sharding import PartitionSpec as P
@@ -1306,15 +1306,18 @@ class ShardyShardingTest(jtu.JaxTestCase):
     self.assertEqual(
         sdy_sharding,
         SdyArraySharding(
-            'mesh',
-            [SdyDimSharding(('sequence', 'data'), True),
+            mesh.shape_tuple,
+            [SdyDimSharding(
+                ('sequence', 'data'), True),
              SdyDimSharding(('model',), True),
              SdyDimSharding([], True)]))
     with ir.Context() as ctx:
       dialects.sdy.register_dialect(ctx)
       self.assertEqual(
           str(sdy_sharding.build()),
-          '#sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>')
+          '#sdy.sharding<mesh<["sequence"=2, "data"=2, "model"=2]>,'
+          ' [{"sequence", "data"}, {"model"}, {}]>',
+      )
 
   def test_unconstrained(self):
     mesh = jtu.create_mesh((8,), ('x',))
@@ -1323,14 +1326,15 @@ class ShardyShardingTest(jtu.JaxTestCase):
     self.assertEqual(
         sdy_sharding,
         SdyArraySharding(
-            'mesh',
+            mesh.shape_tuple,
             [SdyDimSharding([], True),
              SdyDimSharding([], False),
              SdyDimSharding(('x',), True)]))
     with ir.Context() as ctx:
       dialects.sdy.register_dialect(ctx)
       self.assertEqual(
-          str(sdy_sharding.build()), '#sdy.sharding<@mesh, [{}, {?}, {"x"}]>')
+          str(sdy_sharding.build()),
+          '#sdy.sharding<mesh<["x"=8]>, [{}, {?}, {"x"}]>')
 
 
 class RngShardingTest(jtu.JaxTestCase):
@@ -4021,7 +4021,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     lowered_text = make_keys.lower(seeds).as_text()
     if config.use_shardy_partitioner.value:
-      self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
+      self.assertIn('<@empty_mesh, [{?}, {?}, {}]>', lowered_text)
     else:
       self.assertIn('unspecified_dims=[0,1]', lowered_text)
 
@@ -4050,7 +4050,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
    lowered_text = make_keys.lower(seeds).as_text()
    if config.use_shardy_partitioner.value:
-      self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
+      self.assertIn('<@empty_mesh, [{?}, {?}, {}]>', lowered_text)
    else:
      self.assertIn('unspecified_dims=[0,1]', lowered_text)
 
@@ -4077,7 +4077,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     lowered_text = make_keys.lower(seeds).as_text()
     if config.use_shardy_partitioner.value:
-      self.assertIn('<@mesh, [{?}, {?}, {?}, {}]>', lowered_text)
+      self.assertIn('<@empty_mesh, [{?}, {?}, {?}, {}]>', lowered_text)
     else:
       self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
 
@@ -5476,12 +5476,13 @@ class UtilTest(jtu.JaxTestCase):
 
 
 @jtu.with_config(jax_use_shardy_partitioner=True)
-class SdyIntegrationTest(jtu.JaxTestCase):
+class ShardyTest(jtu.JaxTestCase):
+
   # TODO(bartchr): Once JAX is released with SDY, remove setUp.
   def setUp(self):
     if not dialects.sdy:
       raise unittest.SkipTest('Shardy is not available.')
     super().setUp()
 
   def test_lowering_input_output_sharding(self):
     mesh = jtu.create_mesh((4, 2), ('x', 'y'))