[Shardy] Inline meshes when using Shardy and get rid of global meshes from the MLIR body.

Also do a couple of cleanups.
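As a rough end-to-end sketch of the effect (not part of the commit; the two-device mesh, axis name 'x', and the toy jit function are illustrative, and a jaxlib with the Shardy dialect plus at least two devices is assumed):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_use_shardy_partitioner', True)
mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
s = NamedSharding(mesh, P('x'))
lowered = jax.jit(lambda x: x * 2, in_shardings=s, out_shardings=s).lower(
    np.arange(2, dtype=np.float32))
text = lowered.as_text()
# Previously the lowering emitted a global `sdy.mesh @mesh` op up front and every
# sharding referenced it as `#sdy.sharding<@mesh, ...>`. Shardings are now emitted
# with the mesh inlined, e.g. `#sdy.sharding<mesh<["x"=2]>, [{"x"}]>`, and the
# `sdy-lift-inlined-meshes` pass run at the end of lower_jaxpr_to_module lifts
# them back into module-level meshes.
print('sdy.mesh' in text)  # expected: True (the lifted, module-level mesh symbol)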

PiperOrigin-RevId: 685746298
Yash Katariya 2024-10-14 10:07:08 -07:00 committed by jax authors
parent 75e22f2ccd
commit 824ccd7183
9 changed files with 133 additions and 111 deletions

View File

@@ -493,9 +493,10 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
if devices is None:
raise AssertionError(
'Please file a bug at https://github.com/jax-ml/jax/issues')
if axis_context.mesh_shape is not None:
ma, ms = list(zip(*axis_context.mesh_shape))
mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
am = axis_context.abstract_mesh
if am is not None:
mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes),
am.axis_names)
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = axis_context.mesh._flat_devices_tuple
else:

View File

@@ -41,7 +41,6 @@ from jax._src import effects as effects_lib
from jax._src import linear_util as lu
from jax._src import path
from jax._src import pickle_util
from jax._src import sharding
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
@@ -50,12 +49,11 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import AUTO
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import dialects, ir, passmanager
from jax._src.lib.mlir.dialects import func as func_dialect, hlo
from jax._src.lib.mlir import register_jax_dialects
from jax._src.state.types import AbstractRef
@@ -900,10 +898,12 @@ def unflatten_ir_values_like_types(xs: Iterable[ir.Value],
_module_name_regex = re.compile(r"[^\w.-]")
def sharded_aval(aval: core.AbstractValue,
sharding: JSharding | None) -> core.AbstractValue:
sharding: JSharding | AUTO | None) -> core.AbstractValue:
"""Returns the new aval sharded based on sharding proto."""
if sharding is None:
return aval
if isinstance(sharding, AUTO):
return aval
if isinstance(aval, core.AbstractToken):
return aval
if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
@@ -991,10 +991,14 @@ def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):
def _to_physical_op_sharding(
ctx: ModuleContext,
aval: core.AbstractValue, sharding: JSharding | None,
) -> xc.OpSharding | sharding.SdyArraySharding | None:
aval: core.AbstractValue, sharding: JSharding | AUTO | None,
) -> xc.OpSharding | sharding_impls.SdyArraySharding | None:
if sharding is None:
return None
if isinstance(sharding, AUTO):
if config.use_shardy_partitioner.value:
return sharding._to_sdy_sharding(aval.ndim) # type: ignore
return None
assert isinstance(sharding, JSharding)
if isinstance(aval, AbstractRef):
return _to_physical_op_sharding(ctx, aval.inner_aval, sharding)
@@ -1022,9 +1026,11 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
return str(layout._to_xla_layout(aval.dtype)) # type: ignore
def _get_mem_kind(s: JSharding | None) -> str | None:
def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
if s is None:
return None
if isinstance(s, AUTO):
return None
assert isinstance(s, JSharding)
return s.memory_kind
@@ -1040,8 +1046,8 @@ def lower_jaxpr_to_module(
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[JSharding | None] | None = None,
result_shardings: Sequence[JSharding | None] | None = None,
arg_shardings: Sequence[JSharding | AUTO | None] | None = None,
result_shardings: Sequence[JSharding | AUTO | None] | None = None,
in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
arg_names: Sequence[str | None] | None = None,
@@ -1084,8 +1090,9 @@ def lower_jaxpr_to_module(
"In multi-platform lowering either all or no lowering platforms "
f"should support donation. Lowering for {platforms} of which "
f"only {platforms_with_donation} support donation")
if num_partitions > 1 and (
result_shardings is None or all(s is None for s in result_shardings)):
if (num_partitions > 1 and
(result_shardings is None or
all(s is None or isinstance(s, AUTO) for s in result_shardings))):
xla_donated_args = donated_args
donated_args = [False] * len(donated_args)
if xla_donated_args is None:
@@ -1135,16 +1142,6 @@ def lower_jaxpr_to_module(
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
if config.use_shardy_partitioner.value:
if (isinstance(axis_context, sharding_impls.ShardingContext) and
axis_context.mesh_shape is not None):
sdy_mesh_attr = dialects.sdy.MeshAttr.get(
[dialects.sdy.MeshAxisAttr.get(name, size)
for name, size in axis_context.mesh_shape])
else:
sdy_mesh_attr = dialects.sdy.MeshAttr.get([])
ctx.module.body.append(dialects.sdy.MeshOp("mesh", sdy_mesh_attr))
module_name = _module_name_regex.sub("_", module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
@@ -1165,6 +1162,10 @@ def lower_jaxpr_to_module(
arg_layouts=in_layouts,
result_layouts=out_layouts,
propagated_out_mem_kinds=propagated_out_mem_kinds)
if config.use_shardy_partitioner.value:
pipeline = passmanager.PassManager.parse(
'builtin.module(sdy-lift-inlined-meshes)')
pipeline.run(ctx.module.operation)
try:
if not ctx.module.operation.verify():
@@ -1314,8 +1315,8 @@ def lower_jaxpr_to_fun(
*,
public: bool = False,
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[JSharding | None] | None = None,
result_shardings: Sequence[JSharding | None] | None = None,
arg_shardings: Sequence[JSharding | AUTO | None] | None = None,
result_shardings: Sequence[JSharding | AUTO | None] | None = None,
use_sharding_annotations: bool = True,
input_output_aliases: Sequence[int | None] | None = None,
xla_donated_args: Sequence[bool] | None = None,
@@ -1680,10 +1681,12 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# The below custom call achieves the sharding like above example.
if config.use_shardy_partitioner.value:
physical_ndim = core.physical_aval(aval).ndim
s = sharding.SdyArraySharding(
mesh_name='mesh',
dimension_shardings=[sharding.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
for i in range(physical_ndim)])
s = sharding_impls.SdyArraySharding(
mesh_shape=None,
dimension_shardings=[
sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
for i in range(physical_ndim)
])
return wrap_with_sharding_op(ctx, val, aval, s)
else:
return wrap_with_sharding_op(
@@ -2410,7 +2413,7 @@ def _wrap_with_spmd_op(name: str,
ctx: LoweringRuleContext,
x: ir.Value,
aval_out: core.AbstractValue,
sharding: xc.OpSharding | sharding.SdyArraySharding,
sharding: xc.OpSharding | sharding_impls.SdyArraySharding,
unspecified_dims: set[int] | None = None,
has_side_effect: bool = False,
allow_shardy_lowering: bool = False):
@@ -2447,7 +2450,7 @@ wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
def set_sharding(op, sharding: xc.OpSharding | sharding.SdyArraySharding):
def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
if config.use_shardy_partitioner.value:
op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
else:
@@ -2455,7 +2458,7 @@ def set_sharding(op, sharding: xc.OpSharding | sharding.SdyArraySharding):
def get_sharding_attr(
sharding: xc.OpSharding | sharding.SdyArraySharding
sharding: xc.OpSharding | sharding_impls.SdyArraySharding
) -> ir.Attribute:
if config.use_shardy_partitioner.value:
return sharding.build() # type: ignore

View File

@@ -1892,7 +1892,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
propagated_out_mem_kinds: tuple[None | str, ...],
platforms: tuple[str, ...],
lowering_parameters: mlir.LoweringParameters,
mesh_shape_tuple: tuple[tuple[str, int], ...] | None):
abstract_mesh: mesh_lib.AbstractMesh | None):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
@@ -1914,8 +1914,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
nreps = dispatch.jaxpr_replicas(jaxpr)
_raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)
in_mlir_shardings: list[JSharding | None] | None
out_mlir_shardings: list[JSharding | None] | None
in_mlir_shardings: list[JSharding | AUTO | None] | None
out_mlir_shardings: list[JSharding | AUTO | None] | None
axis_ctx: mlir.AxisContext
if nreps == 1:
@@ -1923,7 +1923,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(global_in_avals)
axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
mesh_shape_tuple)
abstract_mesh)
num_partitions = num_devices
else:
# This path is triggered for `jit(pmap)` cases.
@@ -2216,18 +2216,18 @@ def lower_sharding_computation(
# 2. Build up the HLO
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
mesh_shape_tuple = None
if config.use_shardy_partitioner.value or prim_requires_devices:
abstract_mesh = None
if prim_requires_devices:
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
if (mesh_shape_tuple is not None and
mesh_shape_tuple != sharding.mesh.shape_tuple):
if isinstance(sharding, sharding_impls.NamedSharding):
if (abstract_mesh is not None and
abstract_mesh != sharding.mesh.abstract_mesh):
raise ValueError(
"mesh should be the same across the entire program. Got mesh"
f" shape for one sharding {mesh_shape_tuple} and"
f" {sharding.mesh.shape_tuple} for another")
mesh_shape_tuple = sharding.mesh.shape_tuple
f" shape for one sharding {abstract_mesh} and"
f" {sharding.mesh.abstract_mesh} for another")
abstract_mesh = sharding.mesh.abstract_mesh # type: ignore
semantic_in_shardings = SemanticallyEqualShardings(
in_shardings, global_in_avals) # type: ignore
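For context, a small sketch of the consistency check above (illustrative; assumes at least four devices and axis name 'x'): two concrete meshes over different devices but with identical axis names and sizes share the same abstract mesh, so they count as "the same mesh" here.

import jax
import numpy as np
from jax.sharding import Mesh

devs = jax.devices()
m1 = Mesh(np.array(devs[:2]), ('x',))
m2 = Mesh(np.array(devs[2:4]), ('x',))
# Equality is expected to hold because the abstract mesh only captures axis
# names and sizes, not the concrete devices.
print(m1.abstract_mesh == m2.abstract_mesh)  # expected: True
print(m1.shape_tuple, m2.shape_tuple)        # expected: (('x', 2),) (('x', 2),)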
@@ -2242,7 +2242,7 @@ def lower_sharding_computation(
name_stack, all_default_mem_kind, inout_aliases,
propagated_out_mem_kinds, platforms,
lowering_parameters=lowering_parameters,
mesh_shape_tuple=mesh_shape_tuple)
abstract_mesh=abstract_mesh)
# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2285,9 +2285,11 @@ def lower_sharding_computation(
def _to_logical_sharding(
aval: core.AbstractValue, sharding: MaybeSharding | AUTO
) -> JSharding | None:
if is_unspecified(sharding) or is_auto(sharding):
) -> JSharding | AUTO | None:
if isinstance(sharding, UnspecifiedValue):
return None
if isinstance(sharding, AUTO):
return sharding
elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)):
assert isinstance(sharding, JSharding)
return sharding

View File

@@ -247,6 +247,10 @@ class Mesh(contextlib.ContextDecorator):
(name, size)
for name, size in util.safe_zip(self.axis_names, self.devices.shape))
@property
def axis_sizes(self) -> tuple[int, ...]:
return self.devices.shape
@property
def size(self):
return math.prod(self.shape.values()) if self.devices.ndim else 0
@@ -361,6 +365,10 @@ class AbstractMesh:
def axis_names(self):
return self._axis_names
@property
def axis_sizes(self) -> tuple[int, ...]:
return self._axis_sizes
@functools.cached_property
def size(self):
return math.prod(self._axis_sizes) if self._axis_sizes else 0
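A short usage sketch of the new axis_sizes property (illustrative; assumes at least two devices), following the pattern the custom_partitioning lowering rule above now uses to rebuild a concrete Mesh from an AbstractMesh and a flat device list:

import jax
import numpy as np
from jax.sharding import Mesh

devices = jax.devices()[:2]
mesh = Mesh(np.array(devices), ('x',))
print(mesh.axis_sizes)      # expected: (2,), i.e. mesh.devices.shape
am = mesh.abstract_mesh     # AbstractMesh exposes the same axis_sizes/axis_names
rebuilt = Mesh(np.array(devices).reshape(am.axis_sizes), am.axis_names)
print(rebuilt.shape_tuple)  # expected: (('x', 2),)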

View File

@@ -15,13 +15,11 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
import dataclasses
import functools
from jax._src.util import safe_zip, use_cpp_class, cache
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated,
op_sharding_to_indices)
@@ -78,38 +76,6 @@ def _common_shard_shape(self, global_shape: Shape) -> Shape:
return tuple(out)
@dataclasses.dataclass
class SdyDimSharding:
axes: Sequence[str]
is_closed: bool
priority: int | None = None
def build(self) -> sdy.DimensionShardingAttr:
"""Builds the attribute.
NOTE: An MLIR context is required as a context manager.
"""
return sdy.DimensionShardingAttr.get(
[sdy.AxisRefAttr.get(axis) for axis in self.axes],
is_closed=self.is_closed,
priority=self.priority)
@dataclasses.dataclass
class SdyArraySharding:
mesh_name: str
dimension_shardings: Sequence[SdyDimSharding]
def build(self) -> sdy.TensorShardingAttr:
"""Builds the attribute.
NOTE: An MLIR context is required as a context manager.
"""
return sdy.TensorShardingAttr.get(
self.mesh_name,
[dim_sharding.build() for dim_sharding in self.dimension_shardings])
@use_cpp_class(xc.Sharding)
class Sharding:
"""Describes how a :class:`jax.Array` is laid out across devices.
@@ -165,7 +131,7 @@ class Sharding:
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError('Subclasses should implement this method.')
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError('Subclasses should implement this method.')
#############################################################################

View File

@@ -32,6 +32,7 @@ from jax._src import util
from jax._src import xla_bridge
from jax._src import mesh_utils
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
from jax._src.partition_spec import PartitionSpec
@@ -93,6 +94,37 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
return out
@dataclasses.dataclass
class SdyDimSharding:
axes: Sequence[str]
is_closed: bool
priority: int | None = None
# NOTE: An MLIR context is required as a context manager.
def build(self) -> sdy.DimensionShardingAttr:
return sdy.DimensionShardingAttr.get(
[sdy.AxisRefAttr.get(axis) for axis in self.axes],
is_closed=self.is_closed,
priority=self.priority)
@dataclasses.dataclass
class SdyArraySharding:
mesh_shape: tuple[tuple[str, int], ...] | None
dimension_shardings: Sequence[SdyDimSharding]
# NOTE: An MLIR context is required as a context manager.
def build(self) -> sdy.TensorShardingAttr:
if self.mesh_shape is None:
mesh_attr = sdy.MeshAttr.get([])
else:
mesh_attr = sdy.MeshAttr.get([sdy.MeshAxisAttr.get(name, size)
for name, size in self.mesh_shape])
return sdy.TensorShardingAttr.get(
mesh_attr,
[dim_sharding.build() for dim_sharding in self.dimension_shardings])
@util.cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
self, num_dimensions: int) -> xc.HloSharding:
@@ -325,8 +357,8 @@ class NamedSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
dim_shardings = [sharding.SdyDimSharding(axes=[], is_closed=True)
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
dim_shardings = [SdyDimSharding(axes=[], is_closed=True)
for _ in range(num_dimensions)]
for i, dim_spec in enumerate(self._parsed_pspec):
if dim_spec is None:
@@ -336,7 +368,7 @@ class NamedSharding(sharding.Sharding):
pass
else:
dim_shardings[i].axes = dim_spec
return sharding.SdyArraySharding('mesh', dim_shardings)
return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
@util.cache(max_size=128, trace_context_in_key=False)
@@ -410,11 +442,10 @@ class SingleDeviceSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return get_replicated_hlo_sharding()
def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
return sharding.SdyArraySharding(
'mesh',
[sharding.SdyDimSharding(axes=[], is_closed=True)
for _ in range(num_dimensions)])
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
for _ in range(num_dimensions)]
return SdyArraySharding(None, sdy_dim_sharding)
@property
def is_fully_replicated(self) -> bool:
@@ -552,7 +583,7 @@ class PmapSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")
def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
raise NotImplementedError("pmap doesn't use SdyArraySharding.")
@functools.cached_property
@@ -758,7 +789,7 @@ class PositionalSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions)
def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
raise NotImplementedError(
"PositionalSharding can't be converted to an SdyArraySharding.")
@@ -875,7 +906,7 @@ class GSPMDSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return self._hlo_sharding
def _to_sdy_sharding(self, num_dimensions: int) -> sharding.SdyArraySharding:
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
raise NotImplementedError(
"GSPMDSharding can't be converted to SdyArraySharding.")
@@ -898,6 +929,11 @@ class AUTO:
def __init__(self, mesh: mesh_lib.Mesh):
self.mesh = mesh
def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding:
dim_shardings = [SdyDimSharding(axes=[], is_closed=False)
for _ in range(ndim)]
return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
def is_auto(x):
return isinstance(x, AUTO)
@@ -1145,7 +1181,7 @@ class ShardingContext:
"""
num_devices: int
device_assignment: tuple[xc.Device, ...] | None = None
mesh_shape: tuple[tuple[str, int], ...] | None = None
abstract_mesh: mesh_lib.AbstractMesh | None = None
def __post_init__(self):
if self.device_assignment is not None:
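To illustrate the relocated SdyDimSharding/SdyArraySharding dataclasses and the inlined mesh they now build, a sketch mirroring the updated tests below (the 1-D mesh of size 8 and axis name 'x' are assumptions; `dialects.sdy` must be available in the installed jaxlib):

from jax._src.lib.mlir import dialects, ir
from jax._src.sharding_impls import SdyArraySharding, SdyDimSharding

s = SdyArraySharding(
    mesh_shape=(('x', 8),),
    dimension_shardings=[SdyDimSharding(axes=['x'], is_closed=True),
                         SdyDimSharding(axes=[], is_closed=False)])
with ir.Context() as ctx:
  dialects.sdy.register_dialect(ctx)
  # The mesh is embedded in the attribute instead of a '@mesh' symbol reference;
  # AUTO shardings lower the same way but with all dimensions open (is_closed=False).
  print(str(s.build()))  # expected: #sdy.sharding<mesh<["x"=8]>, [{"x"}, {?}]>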

View File

@@ -39,6 +39,7 @@ PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) {
REGISTER_DIALECT(nvvm);
REGISTER_DIALECT(llvm);
mlirRegisterTransformsPasses();
// For Shardy
mlirRegisterAllSdyPassesAndPipelines();
// Transforms used by JAX.
mlirRegisterTransformsStripDebugInfo();

View File

@@ -32,11 +32,11 @@ from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import dialects, ir
from jax._src.util import safe_zip
from jax._src.sharding import common_devices_indices_map, SdyDimSharding, SdyArraySharding
from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
pmap_sharding_devices_indices_map,
NamedSharding, GSPMDSharding,
PositionalSharding)
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (
_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map,
NamedSharding, GSPMDSharding, PositionalSharding, SdyDimSharding,
SdyArraySharding)
from jax.experimental.pjit import pjit
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
@@ -1306,15 +1306,18 @@ class ShardyShardingTest(jtu.JaxTestCase):
self.assertEqual(
sdy_sharding,
SdyArraySharding(
'mesh',
[SdyDimSharding(('sequence', 'data'), True),
mesh.shape_tuple,
[SdyDimSharding(
('sequence', 'data'), True),
SdyDimSharding(('model',), True),
SdyDimSharding([], True)]))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(sdy_sharding.build()),
'#sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>')
'#sdy.sharding<mesh<["sequence"=2, "data"=2, "model"=2]>,'
' [{"sequence", "data"}, {"model"}, {}]>',
)
def test_unconstrained(self):
mesh = jtu.create_mesh((8,), ('x',))
@@ -1323,14 +1326,15 @@ class ShardyShardingTest(jtu.JaxTestCase):
self.assertEqual(
sdy_sharding,
SdyArraySharding(
'mesh',
mesh.shape_tuple,
[SdyDimSharding([], True),
SdyDimSharding([], False),
SdyDimSharding(('x',), True)]))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(sdy_sharding.build()), '#sdy.sharding<@mesh, [{}, {?}, {"x"}]>')
str(sdy_sharding.build()),
'#sdy.sharding<mesh<["x"=8]>, [{}, {?}, {"x"}]>')
class RngShardingTest(jtu.JaxTestCase):

View File

@@ -4021,7 +4021,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
lowered_text = make_keys.lower(seeds).as_text()
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
self.assertIn('<@empty_mesh, [{?}, {?}, {}]>', lowered_text)
else:
self.assertIn('unspecified_dims=[0,1]', lowered_text)
@@ -4050,7 +4050,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
lowered_text = make_keys.lower(seeds).as_text()
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
self.assertIn('<@empty_mesh, [{?}, {?}, {}]>', lowered_text)
else:
self.assertIn('unspecified_dims=[0,1]', lowered_text)
@@ -4077,7 +4077,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
lowered_text = make_keys.lower(seeds).as_text()
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {?}, {}]>', lowered_text)
self.assertIn('<@empty_mesh, [{?}, {?}, {?}, {}]>', lowered_text)
else:
self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
@@ -5476,12 +5476,13 @@ class UtilTest(jtu.JaxTestCase):
@jtu.with_config(jax_use_shardy_partitioner=True)
class SdyIntegrationTest(jtu.JaxTestCase):
class ShardyTest(jtu.JaxTestCase):
# TODO(bartchr): Once JAX is released with SDY, remove setUp.
def setUp(self):
if not dialects.sdy:
raise unittest.SkipTest('Shardy is not available.')
super().setUp()
def test_lowering_input_output_sharding(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))