
#sdy Initial set of changes to allow for lowering to the Shardy dialect.

The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan to have frontends like JAX and PyTorch target this system when they use XLA and want SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, which requires round-tripping between StableHLO and HLO with `mhlo.sharding` attributes. Eventually, though, Shardy will become the first pass in the XLA pipeline, running while the program is still StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD partitioner, but next year the Shardy partitioner will be developed, allowing propagation and partitioning to live entirely in MLIR as the first passes in the pipeline. The pipeline would then be:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will initially be hidden behind the `jax_use_shardy_partitioner` flag; it will be enabled by default in the future.
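For anyone trying this out, a minimal sketch of opting in from user code (the flag name comes from the config change below; everything else here is ordinary JAX usage):

```py
# Sketch: enable the experimental Shardy lowering path globally.
# Requires a jaxlib new enough to ship the sdy MLIR dialect bindings.
import jax

jax.config.update('jax_use_shardy_partitioner', True)

# Lower/compile as usual; lowered modules should now carry sdy.mesh and
# sdy.sharding attributes instead of mhlo.sharding.

jax.config.update('jax_use_shardy_partitioner', False)  # opt back out
```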

PiperOrigin-RevId: 655127611
Bart Chrzaszcz 2024-07-23 05:31:15 -07:00 committed by jax authors
parent 459b83cf4a
commit 864178d3a3
16 changed files with 260 additions and 24 deletions

@ -33,6 +33,7 @@ from jax._src import profiler
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np
@ -113,6 +114,7 @@ def get_compile_options(
num_partitions: int,
device_assignment=None,
use_spmd_partitioning: bool = True,
use_shardy_partitioner: bool = False,
use_auto_spmd_partitioning: bool = False,
auto_spmd_partitioning_mesh_shape: list[int] | None = None,
auto_spmd_partitioning_mesh_ids: list[int] | None = None,
@ -132,6 +134,10 @@ def get_compile_options(
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_shardy_partitioner: boolean indicating whether to use the Shardy
partitioner in XLA. Shardy is a new open sourced propagation framework for
MLIR. Currently Shardy is experimental in JAX. See
www.github.com/openxla/shardy.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
@ -141,8 +147,8 @@ def get_compile_options(
env_options_overrides: dict of additional options parsed by the compiler
fdo_profile: Optional profile for feedback-directed optimization passed to
XLA.
detailed_logging: Is this an "interesting" computation about which XLA
would be wise to log compilation information?
detailed_logging: Is this an "interesting" computation about which XLA would
be wise to log compilation information?
backend: the client, if available.
"""
compile_options = xc.CompileOptions()
@ -194,6 +200,11 @@ def get_compile_options(
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
# TODO(b/352486192): Set this on compile_options after the field is moved to
# the `ExecutableBuildOptions` proto.
if xla_extension_version >= 278:
debug_options.xla_use_shardy = use_shardy_partitioner
# XLA-AutoFDO profile version: precedence order is:
# 1. Whatever --jax_xla_profile_version is set to.
# 2. If --jax_xla_profile_version is not set (i.e., 0), call the function
@ -227,6 +238,7 @@ def get_compile_options(
return compile_options
@profiler.annotate_function
def backend_compile(
backend: xc.Client,
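As a rough illustration of the plumbing above, a hedged sketch of calling the internal helper directly (`jax._src.compiler.get_compile_options` is not a public API, and its signature may change between releases):

```py
# Sketch only: exercises the new `use_shardy_partitioner` keyword.
from jax._src import compiler

opts = compiler.get_compile_options(
    num_replicas=1,
    num_partitions=8,
    use_spmd_partitioning=True,
    use_shardy_partitioner=True,
)
# With xla_extension_version >= 278 this sets
# debug_options.xla_use_shardy = True on the returned CompileOptions.
```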

@ -221,7 +221,8 @@ def trace_context():
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value,
pgle_profiling_runs.value,
enable_pgle.value)
enable_pgle.value,
use_shardy_partitioner.value)
config = Config()
@ -829,6 +830,7 @@ class _GlobalExtraJitContext(NamedTuple):
xla_profile_version: int = 0
pgle_profiling_runs: int = 0
enable_pgle: bool = False
use_shardy_partitioner: bool = False
def _update_global_jit_state(**kw):
@ -1678,3 +1680,20 @@ pmap_no_rank_reduction = bool_state(
"If True, pmap shards have a the same rank as their enclosing array."
)
)
use_shardy_partitioner = bool_state(
name='jax_use_shardy_partitioner',
default=False,
upgrade=True,
help=(
'Whether to lower to Shardy. Shardy is a new open sourced propagation '
'framework for MLIR. Currently Shardy is experimental in JAX. See '
'www.github.com/openxla/shardy'
),
update_global_hook=lambda val: _update_global_jit_state(
use_shardy_partitioner=val
),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
use_shardy_partitioner=val
),
)

@ -506,7 +506,9 @@ def make_ir_context() -> ir.Context:
# we don't do any heavy computation on MLIR modules from Python anyway, so we
# just disable threading.
context.enable_multithreading(False)
# TODO(bartchr): Once JAX is released with SDY, remove the if.
if dialects.sdy:
dialects.sdy.register_dialect(context)
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
dialects.hlo.register_dialect(context)
@ -874,6 +876,8 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
def _to_physical_op_sharding(
aval: core.AbstractValue, sharding: JSharding | None,
) -> xc.OpSharding | None:
# TODO(bartchr): add `dialects.sdy.TensorShardingAttr` to func type once JAX
# is released with SDY.
if sharding is None:
return None
assert isinstance(sharding, JSharding)
@ -883,6 +887,8 @@ def _to_physical_op_sharding(
if dtypes.issubdtype(aval.dtype, dtypes.extended):
sharding = sharding_impls.physical_sharding(aval, sharding)
aval = core.physical_aval(aval)
if config.use_shardy_partitioner.value:
return sharding._to_sdy_sharding(aval.ndim)
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore
@ -927,6 +933,7 @@ def lower_jaxpr_to_module(
input_output_aliases: None | tuple[int | None, ...] = None,
propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
lowering_parameters: LoweringParameters,
mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@ -1012,6 +1019,14 @@ def lower_jaxpr_to_module(
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
if config.use_shardy_partitioner.value:
assert mesh_shape_tuple is not None
ctx.module.body.append(
dialects.sdy.MeshOp(
"mesh",
dialects.sdy.MeshAttr.get(
[dialects.sdy.MeshAxisAttr.get(name, size)
for name, size in mesh_shape_tuple])))
module_name = _module_name_regex.sub("_", module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
@ -1053,6 +1068,7 @@ def lower_jaxpr_to_module(
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
ctx.shape_poly_state)
def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,
arg_memory_kinds, result_memory_kinds):
if input_output_aliases is None:
@ -1330,7 +1346,10 @@ def lower_jaxpr_to_fun(
if use_sharding_annotations and ir_arg_shardings is not None:
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
if sharding is not None:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if config.use_shardy_partitioner.value:
attrs["sdy.sharding"] = get_sharding_attr(sharding)
else:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if ir_arg_memory_kinds is not None:
for attrs, memory_kind in zip(arg_attrs, ir_arg_memory_kinds):
@ -1394,7 +1413,10 @@ def lower_jaxpr_to_fun(
if use_sharding_annotations and ir_result_shardings is not None:
for attrs, sharding in zip(result_attrs, ir_result_shardings):
if sharding is not None:
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
if config.use_shardy_partitioner.value:
attrs["sdy.sharding"] = get_sharding_attr(sharding)
else:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if ir_result_memory_kinds is not None:
for attrs, mem_kind in zip(result_attrs, ir_result_memory_kinds):
@ -2247,18 +2269,31 @@ wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
def set_sharding(op, sharding_proto: xc.OpSharding):
op.attributes["mhlo.sharding"] = get_sharding_attr(sharding_proto)
def get_sharding_attr(sharding_proto: xc.OpSharding):
# If there are very large numbers of devices, use the proto representation.
# The MHLO to HLO conversion supports both, and the proto representation is
# more compact.
if len(sharding_proto.tile_assignment_devices) > 100:
return ir.StringAttr.get(sharding_proto.SerializeToString()) # type: ignore[arg-type]
def set_sharding(op, sharding: xc.OpSharding):
# TODO(bartchr): add `dialects.sdy.TensorShardingAttr` to sharding type once
# JAX is released with SDY.
if config.use_shardy_partitioner.value:
op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
else:
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
op.attributes["mhlo.sharding"] = get_sharding_attr(sharding)
def get_sharding_attr(
sharding: xc.OpSharding,
) -> ir.Attribute:
# TODO(bartchr): add `dialects.sdy.TensorShardingAttr` to sharding type once
# JAX is released with SDY.
if config.use_shardy_partitioner.value:
return sharding # type: ignore[return-value]
else:
# If there are very large numbers of devices, use the proto representation.
# The MHLO to HLO conversion supports both, and the proto representation is
# more compact.
if len(sharding.tile_assignment_devices) > 100:
return ir.StringAttr.get(sharding.SerializeToString()) # type: ignore[arg-type]
else:
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding)))
def wrap_with_layout_op(ctx: LoweringRuleContext,

@ -1943,10 +1943,11 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
inout_aliases: None | tuple[None | int, ...],
propagated_out_mem_kinds: tuple[None | str, ...],
platforms: tuple[str, ...],
lowering_parameters: mlir.LoweringParameters):
lowering_parameters: mlir.LoweringParameters,
mesh_shape_tuple: tuple[tuple[str, int], ...]):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings._gspmd_shardings
out_shardings = semantic_out_shardings._gspmd_shardings
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
global_in_avals = closed_jaxpr.in_avals
global_out_avals = closed_jaxpr.out_avals
@ -2019,7 +2020,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
all_default_mem_kind=all_default_mem_kind,
input_output_aliases=inout_aliases,
propagated_out_mem_kinds=propagated_out_mem_kinds,
lowering_parameters=lowering_parameters)
lowering_parameters=lowering_parameters,
mesh_shape_tuple=mesh_shape_tuple)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@ -2262,6 +2264,14 @@ def lower_sharding_computation(
semantic_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
mesh_shape_tuple = None
if config.use_shardy_partitioner.value:
for sharding in it.chain(
in_shardings, out_shardings,
[js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, sharding_impls.NamedSharding):
mesh_shape_tuple = sharding.mesh.shape_tuple
break
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
@ -2270,7 +2280,8 @@ def lower_sharding_computation(
tuple(da_object) if prim_requires_devices else None, donated_invars,
name_stack, all_default_mem_kind, inout_aliases,
propagated_out_mem_kinds, platforms,
lowering_parameters=lowering_parameters)
lowering_parameters=lowering_parameters,
mesh_shape_tuple=mesh_shape_tuple)
# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
@ -2789,6 +2800,7 @@ def create_compile_options(
num_partitions=num_partitions,
device_assignment=xla_device_assignment,
use_spmd_partitioning=spmd_lowering,
use_shardy_partitioner=config.use_shardy_partitioner.value,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
fdo_profile=fdo_profile,

@ -55,6 +55,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:sdy_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",

@ -13,14 +13,21 @@
# limitations under the License.
# ruff: noqa: F401
from typing import Any
import jaxlib.mlir.dialects.arith as arith
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.math as math
import jaxlib.mlir.dialects.memref as memref
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.scf as scf
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
try:
import jaxlib.mlir.dialects.sdy as sdy
except ImportError:
sdy: Any = None # type: ignore[no-redef]
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
import jaxlib.mlir.dialects.vector as vector
try:

@ -258,6 +258,12 @@ class Mesh(contextlib.ContextDecorator):
(name, size)
for name, size in util.safe_zip(self.axis_names, self.devices.shape))
@functools.cached_property
def shape_tuple(self):
return tuple(
(name, size)
for name, size in util.safe_zip(self.axis_names, self.devices.shape))
@property
def size(self):
return math.prod(self.shape.values())
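A quick sketch of what the new `shape_tuple` property yields (the device counts are illustrative and assume at least 8 devices are available):

```py
# Sketch: shape_tuple is a hashable tuple of (axis_name, size) pairs, in axis
# order; lower_jaxpr_to_module uses it to emit the sdy.MeshOp shown above.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.asarray(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=('x', 'y'))
assert mesh.shape_tuple == (('x', 4), ('y', 2))
```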

@ -126,6 +126,8 @@ class Sharding:
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError('Subclasses should implement this method.')
def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError('Subclasses should implement this method.')
#############################################################################
# Default implementations below that all subclasses will inherit.

@ -24,14 +24,15 @@ import itertools
import math
from typing import Any, NamedTuple, Union, cast
from jax._src import core
from jax._src import mesh as mesh_lib
from jax._src import sharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src import core
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
from jax._src.partition_spec import PartitionSpec
@ -295,6 +296,22 @@ class NamedSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
def _to_sdy_sharding(self, num_dimensions: int):
dim_shardings = [
sdy.DimensionShardingAttr.get([], is_closed=True)
] * num_dimensions
for i, dim_spec in enumerate(self._parsed_pspec):
if dim_spec is None:
dim_shardings[i] = sdy.DimensionShardingAttr.get([], is_closed=False)
elif not dim_spec:
# Already empty and closed sharding.
pass
else:
dim_shardings[i] = sdy.DimensionShardingAttr.get(
[sdy.AxisRefAttr.get(axis) for axis in dim_spec],
is_closed=True)
return sdy.TensorShardingAttr.get('mesh', dim_shardings)
@util.cache(max_size=128, trace_context_in_key=False)
def get_replicated_hlo_sharding():
@ -363,6 +380,11 @@ class SingleDeviceSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return get_replicated_hlo_sharding()
def _to_sdy_sharding(self, num_dimensions: int):
return sdy.TensorShardingAttr.get(
'mesh',
[sdy.DimensionShardingAttr.get([], is_closed=True)] * num_dimensions)
@property
def is_fully_replicated(self) -> bool:
return True
@ -495,6 +517,9 @@ class PmapSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")
def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError("pmap doesn't use sdy.TensorShardingAttr.")
@functools.cached_property
def is_fully_replicated(self) -> bool:
for s in self.sharding_spec.sharding:
@ -698,6 +723,10 @@ class PositionalSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions)
def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError(
"PositionalSharding can't be converted to sdy.TensorShardingAttr.")
@functools.cached_property
def is_fully_addressable(self) -> bool:
return self._internal_device_list.is_fully_addressable
@ -807,6 +836,10 @@ class GSPMDSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return self._hlo_sharding
def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError(
"GSPMDSharding can't be converted to sdy.TensorShardingAttr.")
@functools.cached_property
def is_fully_replicated(self) -> bool:
return is_op_sharding_replicated(self._hlo_sharding)
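A short sketch of the resulting PartitionSpec to `#sdy.sharding` mapping, modeled on the new tests at the bottom of this change (requires a jaxlib that ships the sdy dialect and at least 8 devices for the 4x2 mesh):

```py
# Sketch: NamedSharding._to_sdy_sharding maps each PartitionSpec entry to a
# DimensionShardingAttr: named axes -> closed {"axis"}, None -> closed {},
# P.UNCONSTRAINED -> open {?}.
import jax
from jax._src import test_util as jtu
from jax._src.lib.mlir import dialects, ir
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', None))

with ir.Context() as ctx:
  dialects.sdy.register_dialect(ctx)
  print(str(s._to_sdy_sharding(2)))  # #sdy.sharding<@mesh, [{"x"}, {}]>
```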

@ -77,6 +77,12 @@ pytype_strict_library(
deps = if_building_jaxlib(["//jaxlib/mlir:sparse_tensor_dialect"]),
)
pytype_strict_library(
name = "sdy_dialect",
srcs = ["sdy.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:sdy_dialect"]),
)
pytype_strict_library(
name = "stablehlo_dialect",
srcs = ["stablehlo.py"],

@ -0,0 +1,21 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
try:
from jaxlib.mlir.dialects.sdy import *
except ImportError:
pass

@ -74,6 +74,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:nvvm_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:sdy_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",

@ -210,6 +210,20 @@ symlink_inputs(
],
)
symlink_inputs(
name = "sdy_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {
"dialects": ["@shardy//shardy/integrations/python/ir:sdy_ops_py_files"],
}},
deps = [
":core",
":ir",
":mlir",
"//jaxlib/mlir/_mlir_libs:_sdy",
],
)
symlink_inputs(
name = "stablehlo_dialect",
rule = py_library,
@ -301,4 +315,3 @@ symlink_inputs(
"//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM",
],
)

@ -281,6 +281,28 @@ py_extension(
],
)
##---------------------------------------------------------------------------##
# Shardy Extensions
##---------------------------------------------------------------------------##
py_extension(
name = "_sdy",
srcs = [
"@shardy//shardy/integrations/python/ir:sdy_module.cc",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@local_config_python//:headers",
"@pybind11",
"@shardy//shardy/integrations/c:sdy_capi_headers",
],
)
##---------------------------------------------------------------------------##
# Stablehlo Extensions
##---------------------------------------------------------------------------##
@ -357,6 +379,7 @@ cc_library(
"@llvm-project//mlir:CAPITransformsObjects",
"@llvm-project//mlir:CAPIVectorObjects",
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
"@shardy//shardy/integrations/c:sdy_capi_objects",
"@stablehlo//:chlo_capi_objects",
"@stablehlo//:stablehlo_capi_objects",
"@xla//xla/mlir_hlo:CAPIObjects",
@ -394,6 +417,7 @@ windows_cc_shared_mlir_library(
exported_symbol_prefixes = [
"mlir",
"chlo",
"sdy",
"stablehlo",
],
deps = [":jaxlib_mlir_capi_objects"],

@ -282,6 +282,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
"__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_ods_common.py",
"__main__/jaxlib/mlir/dialects/_scf_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_sdy_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py",
@ -303,6 +304,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
"__main__/jaxlib/mlir/dialects/memref.py",
"__main__/jaxlib/mlir/dialects/mhlo.py",
"__main__/jaxlib/mlir/dialects/scf.py",
"__main__/jaxlib/mlir/dialects/sdy.py",
"__main__/jaxlib/mlir/dialects/sparse_tensor.py",
"__main__/jaxlib/mlir/dialects/stablehlo.py",
"__main__/jaxlib/mlir/dialects/vector.py",

@ -56,6 +56,8 @@ from jax._src.pjit import pjit, pjit_p
from jax._src import mesh as mesh_lib
from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
@ -5065,5 +5067,45 @@ class UtilTest(jtu.JaxTestCase):
self.assertTupleEqual(mesh.axis_names, ('dp',))
@jtu.with_config(jax_use_shardy_partitioner=True)
class SdyIntegrationTest(jtu.JaxTestCase):
# TODO(bartchr): Once JAX is released with SDY, remove setUp.
def setUp(self):
if not dialects.sdy:
raise unittest.SkipTest('Shardy is not available.')
def test_lowering_input_output_sharding(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())
def test_long_axis_names(self):
mesh = jtu.create_global_mesh((2, 2, 2), ('sequence', 'data', 'model'))
s = jax.sharding.NamedSharding(mesh, P(('sequence', 'data'), 'model'))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(s._to_sdy_sharding(3)),
'#sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>',
)
def test_unconstrained(self):
mesh = jtu.create_global_mesh((8,), ('x',))
s = jax.sharding.NamedSharding(mesh, P(None, P.UNCONSTRAINED, 'x'))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(s._to_sdy_sharding(3)), '#sdy.sharding<@mesh, [{}, {?}, {"x"}]>'
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())