#sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD partitioner, but next year the Shardy partitioner will be developed, allowing propagation and partitioning to live completely in MLIR as the first pass in the pipeline. So then we'd have:

1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
This commit is contained in:
parent 459b83cf4a
commit 864178d3a3
@@ -33,6 +33,7 @@ from jax._src import profiler
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np

@@ -113,6 +114,7 @@ def get_compile_options(
num_partitions: int,
device_assignment=None,
use_spmd_partitioning: bool = True,
use_shardy_partitioner: bool = False,
use_auto_spmd_partitioning: bool = False,
auto_spmd_partitioning_mesh_shape: list[int] | None = None,
auto_spmd_partitioning_mesh_ids: list[int] | None = None,

@@ -132,6 +134,10 @@ get_compile_options(
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_shardy_partitioner: boolean indicating whether to use the Shardy
partitioner in XLA. Shardy is a new open sourced propagation framework for
MLIR. Currently Shardy is experimental in JAX. See
www.github.com/openxla/shardy.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create

@@ -141,8 +147,8 @@ get_compile_options(
env_options_overrides: dict of additional options parsed by the compiler
fdo_profile: Optional profile for feedback-directed optimization passed to
XLA.
detailed_logging: Is this an "interesting" computation about which XLA
would be wise to log compilation information?
detailed_logging: Is this an "interesting" computation about which XLA would
be wise to log compilation information?
backend: the client, if available.
"""
compile_options = xc.CompileOptions()

@@ -194,6 +200,11 @@ get_compile_options(
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False

# TODO(b/352486192): Set this on compile_options after the field is moved to
# the `ExecutableBuildOptions` proto.
if xla_extension_version >= 278:
debug_options.xla_use_shardy = use_shardy_partitioner

# XLA-AutoFDO profile version: precedence order is:
# 1. Whatever --jax_xla_profile_version is set to.
# 2. If --jax_xla_profile_version is not set (i.e., 0), call the function

@@ -227,6 +238,7 @@ get_compile_options(

return compile_options


@profiler.annotate_function
def backend_compile(
backend: xc.Client,

@@ -221,7 +221,8 @@ def trace_context():
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value,
pgle_profiling_runs.value,
enable_pgle.value)
enable_pgle.value,
use_shardy_partitioner.value)

config = Config()

@@ -829,6 +830,7 @@ class _GlobalExtraJitContext(NamedTuple):
xla_profile_version: int = 0
pgle_profiling_runs: int = 0
enable_pgle: bool = False
use_shardy_partitioner: bool = False


def _update_global_jit_state(**kw):

@@ -1678,3 +1680,20 @@ pmap_no_rank_reduction = bool_state(
"If True, pmap shards have a the same rank as their enclosing array."
)
)

use_shardy_partitioner = bool_state(
name='jax_use_shardy_partitioner',
default=False,
upgrade=True,
help=(
'Whether to lower to Shardy. Shardy is a new open sourced propagation '
'framework for MLIR. Currently Shardy is experimental in JAX. See '
'www.github.com/openxla/shardy'
),
update_global_hook=lambda val: _update_global_jit_state(
use_shardy_partitioner=val
),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
use_shardy_partitioner=val
),
)

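To connect the hunk above to the commit message: a minimal usage sketch (not part of this diff) of flipping the new flag on from user code. It assumes only the standard `jax.config.update` mechanism that `bool_state` flags plug into:

```py
import jax

# Hedged sketch: opt in to the experimental Shardy lowering path.
# 'jax_use_shardy_partitioner' is the bool_state flag defined above.
jax.config.update('jax_use_shardy_partitioner', True)

# Subsequent jit lowerings should now attach `sdy.sharding` attributes and an
# `sdy.mesh` op instead of `mhlo.sharding`, as shown in the commit message.
```
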
@@ -506,7 +506,9 @@ def make_ir_context() -> ir.Context:
# we don't do any heavy computation on MLIR modules from Python anyway, so we
# just disable threading.
context.enable_multithreading(False)

# TODO(bartchr): Once JAX is released with SDY, remove the if.
if dialects.sdy:
dialects.sdy.register_dialect(context)
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
dialects.hlo.register_dialect(context)

@@ -874,6 +876,8 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
def _to_physical_op_sharding(
aval: core.AbstractValue, sharding: JSharding | None,
) -> xc.OpSharding | None:
# TODO(bartchr): add `dialects.sdy.TensorShardingAttr` to func type once JAX
# is released with SDY.
if sharding is None:
return None
assert isinstance(sharding, JSharding)

@@ -883,6 +887,8 @@ def _to_physical_op_sharding(
if dtypes.issubdtype(aval.dtype, dtypes.extended):
sharding = sharding_impls.physical_sharding(aval, sharding)
aval = core.physical_aval(aval)
if config.use_shardy_partitioner.value:
return sharding._to_sdy_sharding(aval.ndim)
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore


@@ -927,6 +933,7 @@ def lower_jaxpr_to_module(
input_output_aliases: None | tuple[int | None, ...] = None,
propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
lowering_parameters: LoweringParameters,
mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.

@@ -1012,6 +1019,14 @@ def lower_jaxpr_to_module(
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
if config.use_shardy_partitioner.value:
assert mesh_shape_tuple is not None
ctx.module.body.append(
dialects.sdy.MeshOp(
"mesh",
dialects.sdy.MeshAttr.get(
[dialects.sdy.MeshAxisAttr.get(name, size)
for name, size in mesh_shape_tuple])))
module_name = _module_name_regex.sub("_", module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)

@@ -1053,6 +1068,7 @@ def lower_jaxpr_to_module(
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
ctx.shape_poly_state)


def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,
arg_memory_kinds, result_memory_kinds):
if input_output_aliases is None:

@@ -1330,7 +1346,10 @@ def lower_jaxpr_to_fun(
if use_sharding_annotations and ir_arg_shardings is not None:
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
if sharding is not None:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if config.use_shardy_partitioner.value:
attrs["sdy.sharding"] = get_sharding_attr(sharding)
else:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)

if ir_arg_memory_kinds is not None:
for attrs, memory_kind in zip(arg_attrs, ir_arg_memory_kinds):

@@ -1394,7 +1413,10 @@ def lower_jaxpr_to_fun(
if use_sharding_annotations and ir_result_shardings is not None:
for attrs, sharding in zip(result_attrs, ir_result_shardings):
if sharding is not None:
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
if config.use_shardy_partitioner.value:
attrs["sdy.sharding"] = get_sharding_attr(sharding)
else:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)

if ir_result_memory_kinds is not None:
for attrs, mem_kind in zip(result_attrs, ir_result_memory_kinds):

@@ -2247,18 +2269,31 @@ wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")

def set_sharding(op, sharding_proto: xc.OpSharding):
op.attributes["mhlo.sharding"] = get_sharding_attr(sharding_proto)


def get_sharding_attr(sharding_proto: xc.OpSharding):
# If there are very large numbers of devices, use the proto representation.
# The MHLO to HLO conversion supports both, and the proto representation is
# more compact.
if len(sharding_proto.tile_assignment_devices) > 100:
return ir.StringAttr.get(sharding_proto.SerializeToString()) # type: ignore[arg-type]
def set_sharding(op, sharding: xc.OpSharding):
# TODO(bartchr): add `dialects.sdy.TensorShardingAttr` to sharding type once
# JAX is released with SDY.
if config.use_shardy_partitioner.value:
op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
else:
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
op.attributes["mhlo.sharding"] = get_sharding_attr(sharding)


def get_sharding_attr(
sharding: xc.OpSharding,
) -> ir.Attribute:
# TODO(bartchr): add `dialects.sdy.TensorShardingAttr` to sharding type once
# JAX is released with SDY.
if config.use_shardy_partitioner.value:
return sharding # type: ignore[return-value]
else:
# If there are very large numbers of devices, use the proto representation.
# The MHLO to HLO conversion supports both, and the proto representation is
# more compact.
if len(sharding.tile_assignment_devices) > 100:
return ir.StringAttr.get(sharding.SerializeToString()) # type: ignore[arg-type]
else:
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding)))


def wrap_with_layout_op(ctx: LoweringRuleContext,

@@ -1943,10 +1943,11 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
inout_aliases: None | tuple[None | int, ...],
propagated_out_mem_kinds: tuple[None | str, ...],
platforms: tuple[str, ...],
lowering_parameters: mlir.LoweringParameters):
lowering_parameters: mlir.LoweringParameters,
mesh_shape_tuple: tuple[tuple[str, int], ...]):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings._gspmd_shardings
out_shardings = semantic_out_shardings._gspmd_shardings
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
global_in_avals = closed_jaxpr.in_avals
global_out_avals = closed_jaxpr.out_avals

@@ -2019,7 +2020,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
all_default_mem_kind=all_default_mem_kind,
input_output_aliases=inout_aliases,
propagated_out_mem_kinds=propagated_out_mem_kinds,
lowering_parameters=lowering_parameters)
lowering_parameters=lowering_parameters,
mesh_shape_tuple=mesh_shape_tuple)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))

@@ -2262,6 +2264,14 @@ def lower_sharding_computation(
semantic_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
mesh_shape_tuple = None
if config.use_shardy_partitioner.value:
for sharding in it.chain(
in_shardings, out_shardings,
[js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, sharding_impls.NamedSharding):
mesh_shape_tuple = sharding.mesh.shape_tuple
break

(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(

@@ -2270,7 +2280,8 @@ def lower_sharding_computation(
tuple(da_object) if prim_requires_devices else None, donated_invars,
name_stack, all_default_mem_kind, inout_aliases,
propagated_out_mem_kinds, platforms,
lowering_parameters=lowering_parameters)
lowering_parameters=lowering_parameters,
mesh_shape_tuple=mesh_shape_tuple)

# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way

@@ -2789,6 +2800,7 @@ def create_compile_options(
num_partitions=num_partitions,
device_assignment=xla_device_assignment,
use_spmd_partitioning=spmd_lowering,
use_shardy_partitioner=config.use_shardy_partitioner.value,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
fdo_profile=fdo_profile,

@@ -55,6 +55,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:sdy_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",

@@ -13,14 +13,21 @@
# limitations under the License.

# ruff: noqa: F401
from typing import Any

import jaxlib.mlir.dialects.arith as arith
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.math as math
import jaxlib.mlir.dialects.memref as memref
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.scf as scf
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
try:
import jaxlib.mlir.dialects.sdy as sdy
except ImportError:
sdy: Any = None # type: ignore[no-redef]
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
import jaxlib.mlir.dialects.vector as vector
try:

@@ -258,6 +258,12 @@ class Mesh(contextlib.ContextDecorator):
(name, size)
for name, size in util.safe_zip(self.axis_names, self.devices.shape))

@functools.cached_property
def shape_tuple(self):
return tuple(
(name, size)
for name, size in util.safe_zip(self.axis_names, self.devices.shape))

@property
def size(self):
return math.prod(self.shape.values())

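For illustration only (not part of this diff): a brief sketch of what the new `Mesh.shape_tuple` property yields and how it relates to the `sdy.MeshOp` built in `lower_jaxpr_to_module` above. It assumes at least eight addressable devices (for example a CPU run with `--xla_force_host_platform_device_count=8`):

```py
import numpy as np
import jax
from jax.sharding import Mesh

# Build a 4x2 device mesh with named axes 'x' and 'y'.
devices = np.asarray(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, ('x', 'y'))

# The new cached property returns the axis (name, size) pairs in order...
print(mesh.shape_tuple)  # (('x', 4), ('y', 2))

# ...which the Shardy lowering turns into `sdy.mesh @mesh = <"x"=4, "y"=2>`.
```
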
@@ -126,6 +126,8 @@ class Sharding:
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError('Subclasses should implement this method.')

def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError('Subclasses should implement this method.')

#############################################################################
# Default implementations below that all subclasses will inherit.

@@ -24,14 +24,15 @@ import itertools
import math
from typing import Any, NamedTuple, Union, cast

from jax._src import core
from jax._src import mesh as mesh_lib
from jax._src import sharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src import core
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
from jax._src.partition_spec import PartitionSpec

@@ -295,6 +296,22 @@ class NamedSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

def _to_sdy_sharding(self, num_dimensions: int):
dim_shardings = [
sdy.DimensionShardingAttr.get([], is_closed=True)
] * num_dimensions
for i, dim_spec in enumerate(self._parsed_pspec):
if dim_spec is None:
dim_shardings[i] = sdy.DimensionShardingAttr.get([], is_closed=False)
elif not dim_spec:
# Already empty and closed sharding.
pass
else:
dim_shardings[i] = sdy.DimensionShardingAttr.get(
[sdy.AxisRefAttr.get(axis) for axis in dim_spec],
is_closed=True)
return sdy.TensorShardingAttr.get('mesh', dim_shardings)


@util.cache(max_size=128, trace_context_in_key=False)
def get_replicated_hlo_sharding():

@@ -363,6 +380,11 @@ class SingleDeviceSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return get_replicated_hlo_sharding()

def _to_sdy_sharding(self, num_dimensions: int):
return sdy.TensorShardingAttr.get(
'mesh',
[sdy.DimensionShardingAttr.get([], is_closed=True)] * num_dimensions)

@property
def is_fully_replicated(self) -> bool:
return True

@@ -495,6 +517,9 @@ class PmapSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")

def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError("pmap doesn't use sdy.TensorShardingAttr.")

@functools.cached_property
def is_fully_replicated(self) -> bool:
for s in self.sharding_spec.sharding:

@@ -698,6 +723,10 @@ class PositionalSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions)

def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError(
"PositionalSharding can't be converted to sdy.TensorShardingAttr.")

@functools.cached_property
def is_fully_addressable(self) -> bool:
return self._internal_device_list.is_fully_addressable

@@ -807,6 +836,10 @@ class GSPMDSharding(sharding.Sharding):
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return self._hlo_sharding

def _to_sdy_sharding(self, num_dimensions: int):
raise NotImplementedError(
"GSPMDSharding can't be converted to sdy.TensorShardingAttr.")

@functools.cached_property
def is_fully_replicated(self) -> bool:
return is_op_sharding_replicated(self._hlo_sharding)

@@ -77,6 +77,12 @@ pytype_strict_library(
deps = if_building_jaxlib(["//jaxlib/mlir:sparse_tensor_dialect"]),
)

pytype_strict_library(
name = "sdy_dialect",
srcs = ["sdy.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:sdy_dialect"]),
)

pytype_strict_library(
name = "stablehlo_dialect",
srcs = ["stablehlo.py"],

jax/extend/mlir/dialects/sdy.py (new file, 21 lines)
@@ -0,0 +1,21 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: F403

# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
try:
from jaxlib.mlir.dialects.sdy import *
except ImportError:
pass

@@ -74,6 +74,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:nvvm_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:sdy_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",

@@ -210,6 +210,20 @@ symlink_inputs(
],
)

symlink_inputs(
name = "sdy_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {
"dialects": ["@shardy//shardy/integrations/python/ir:sdy_ops_py_files"],
}},
deps = [
":core",
":ir",
":mlir",
"//jaxlib/mlir/_mlir_libs:_sdy",
],
)

symlink_inputs(
name = "stablehlo_dialect",
rule = py_library,

@@ -301,4 +315,3 @@ symlink_inputs(
"//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM",
],
)

@@ -281,6 +281,28 @@ py_extension(
],
)

##---------------------------------------------------------------------------##
# Shardy Extensions
##---------------------------------------------------------------------------##

py_extension(
name = "_sdy",
srcs = [
"@shardy//shardy/integrations/python/ir:sdy_module.cc",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@local_config_python//:headers",
"@pybind11",
"@shardy//shardy/integrations/c:sdy_capi_headers",
],
)

##---------------------------------------------------------------------------##
# Stablehlo Extensions
##---------------------------------------------------------------------------##

@@ -357,6 +379,7 @@ cc_library(
"@llvm-project//mlir:CAPITransformsObjects",
"@llvm-project//mlir:CAPIVectorObjects",
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
"@shardy//shardy/integrations/c:sdy_capi_objects",
"@stablehlo//:chlo_capi_objects",
"@stablehlo//:stablehlo_capi_objects",
"@xla//xla/mlir_hlo:CAPIObjects",

@@ -394,6 +417,7 @@ windows_cc_shared_mlir_library(
exported_symbol_prefixes = [
"mlir",
"chlo",
"sdy",
"stablehlo",
],
deps = [":jaxlib_mlir_capi_objects"],

@@ -282,6 +282,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
"__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_ods_common.py",
"__main__/jaxlib/mlir/dialects/_scf_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_sdy_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py",

@@ -303,6 +304,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
"__main__/jaxlib/mlir/dialects/memref.py",
"__main__/jaxlib/mlir/dialects/mhlo.py",
"__main__/jaxlib/mlir/dialects/scf.py",
"__main__/jaxlib/mlir/dialects/sdy.py",
"__main__/jaxlib/mlir/dialects/sparse_tensor.py",
"__main__/jaxlib/mlir/dialects/stablehlo.py",
"__main__/jaxlib/mlir/dialects/vector.py",

@@ -56,6 +56,8 @@ from jax._src.pjit import pjit, pjit_p
from jax._src import mesh as mesh_lib
from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension

@@ -5065,5 +5067,45 @@ class UtilTest(jtu.JaxTestCase):
self.assertTupleEqual(mesh.axis_names, ('dp',))


@jtu.with_config(jax_use_shardy_partitioner=True)
class SdyIntegrationTest(jtu.JaxTestCase):

# TODO(bartchr): Once JAX is released with SDY, remove setUp.
def setUp(self):
if not dialects.sdy:
raise unittest.SkipTest('Shardy is not available.')

def test_lowering_input_output_sharding(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2

self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())

def test_long_axis_names(self):
mesh = jtu.create_global_mesh((2, 2, 2), ('sequence', 'data', 'model'))
s = jax.sharding.NamedSharding(mesh, P(('sequence', 'data'), 'model'))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(s._to_sdy_sharding(3)),
'#sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>',
)

def test_unconstrained(self):
mesh = jtu.create_global_mesh((8,), ('x',))
s = jax.sharding.NamedSharding(mesh, P(None, P.UNCONSTRAINED, 'x'))
with ir.Context() as ctx:
dialects.sdy.register_dialect(ctx)
self.assertEqual(
str(s._to_sdy_sharding(3)), '#sdy.sharding<@mesh, [{}, {?}, {"x"}]>'
)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())