#sdy enable pure callbacks and debug prints in JAX.

Everything passes except one io_callback test, which fails because the lowered `sdy.manual_computation` returns a token; this will be fixed in a follow-up.
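
Below is a minimal sketch (not part of this change) of the path being enabled: running `jax.pure_callback` and `jax.debug.print` under `jit` with the Shardy partitioner turned on via the existing `jax_use_shardy_partitioner` config flag.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Turn on the Shardy partitioner; pure callbacks and debug prints now lower under it.
jax.config.update("jax_use_shardy_partitioner", True)

def host_fn(x):
  # Runs in Python on the host; receives a NumPy array.
  return np.asarray(x) * 2

@jax.jit
def f(x):
  jax.debug.print("x = {}", x)  # debug prints lower through the same sharding logic
  return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4.0)))
```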

PiperOrigin-RevId: 713780181
Bart Chrzaszcz 2025-01-09 13:37:00 -08:00 committed by jax authors
parent 93ef0f13fe
commit dc53c563bb
7 changed files with 163 additions and 67 deletions

View File

@ -21,6 +21,7 @@ import logging
from typing import Any
import jax
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
@ -225,7 +226,9 @@ batching.primitive_batchers[pure_callback_p] = functools.partial(
)
def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
def _callback_op_sharding(
axis_context, sharding: SingleDeviceSharding | None, avals_out
):
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
# If we have fully manual sharding during lowering, that means the JAX
# program has per-device semantics, so we run the callback on each device.
@ -239,8 +242,18 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
"callbacks do not support specifying sharding inside spmd"
" computations"
)
op_sharding = xc.OpSharding()
op_sharding.type = xc.OpSharding.Type.MANUAL
if config.use_shardy_partitioner.value:
assert len(avals_out) == 1
op_sharding = sharding_impls.SdyArrayShardingList([
sharding_impls.SdyArraySharding(
mesh_shape=(),
dimension_shardings=[
sharding_impls.SdyDimSharding(axes=[], is_closed=True)
] * avals_out[0].ndim,
logical_device_ids=())])
else:
op_sharding = xc.OpSharding() # type: ignore[assignment]
op_sharding.type = xc.OpSharding.Type.MANUAL
return op_sharding
if isinstance(axis_context, sharding_impls.ShardingContext):
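
For context, a hedged sketch of the fully-manual case this branch handles (assuming `jax.experimental.shard_map` and at least one device): a `pure_callback` inside `shard_map` has per-device semantics, so it runs once per device and its result is annotated with the fully-manual Shardy sharding built above (all dimensions closed, no axes).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def host_fn(x):
  return np.asarray(x) + 1  # sees only the per-device shard

@jax.jit
def f(x):
  def per_device(x):
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
  return shard_map(per_device, mesh=mesh, in_specs=P("x"), out_specs=P("x"))(x)

f(jnp.arange(jax.device_count() * 2.0))
```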
@ -268,10 +281,17 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
# If we have fully automatic sharding during lowering, that means the JAX
# program has bulk array semantics, so we run the callback with a MAXIMAL
# sharding and hence execute it only once on the full logical value).
op_sharding = xc.OpSharding()
op_sharding.type = xc.OpSharding.Type.MAXIMAL
op_sharding.tile_assignment_dimensions = [1]
op_sharding.tile_assignment_devices = [device_index]
if config.use_shardy_partitioner.value:
op_sharding = sharding_impls.SdyArrayShardingList([
sharding_impls.SdyArraySharding(
mesh_shape=(),
dimension_shardings=[],
logical_device_ids=(device_index,))])
else:
op_sharding = xc.OpSharding() # type: ignore[assignment]
op_sharding.type = xc.OpSharding.Type.MAXIMAL
op_sharding.tile_assignment_dimensions = [1]
op_sharding.tile_assignment_devices = [device_index]
return op_sharding
# When there's no SPMD partitioning going on, don't annotate a sharding.
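
And a sketch of the MAXIMAL case: with bulk-array semantics the callback executes once, on a single device. Passing a `SingleDeviceSharding` (the `sharding` argument accepted by `pure_callback`/`io_callback`) selects that device; its position in the device assignment becomes the `device_index` used above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import SingleDeviceSharding

callback_device = jax.devices()[-1]  # run the callback on the last device

def host_fn(x):
  return np.asarray(x) * 2

@jax.jit
def f(x):
  return jax.pure_callback(
      host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
      sharding=SingleDeviceSharding(callback_device))

f(jnp.arange(4.0))
```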
@ -291,7 +311,8 @@ def pure_callback_lowering(
)
)
op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
op_sharding = _callback_op_sharding(
ctx.module_context.axis_context, sharding, ctx.avals_out)
result, _, _ = mlir.emit_python_callback(
ctx,
_callback,
@ -561,7 +582,8 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
)
)
op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
op_sharding = _callback_op_sharding(
ctx.module_context.axis_context, sharding, ctx.avals_out)
if ordered:
token = ctx.tokens_in.get(_OrderedIOEffect)
result, token, _ = mlir.emit_python_callback(
@ -576,7 +598,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
)
ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token}))
else:
result, token, _ = mlir.emit_python_callback(
result, _, _ = mlir.emit_python_callback(
ctx,
_callback,
None,

View File

@ -15,9 +15,9 @@
from __future__ import annotations
import importlib.util
from collections.abc import Callable, Sequence
import functools
import importlib.util
import logging
import string
import sys
@ -29,12 +29,12 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import mesh as mesh_lib
from jax._src import sharding_impls
from jax._src import dispatch
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import ad
@ -135,22 +135,37 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
axis_context = ctx.module_context.axis_context
if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
# If we have fully manual sharding during lowering, that means the JAX
# program has per-device semantics, so we run the callback on each device.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MANUAL
if config.use_shardy_partitioner.value:
assert len(ctx.avals_out) == 1
sharding = sharding_impls.SdyArrayShardingList([
sharding_impls.SdyArraySharding(
mesh_shape=(),
dimension_shardings=[
sharding_impls.SdyDimSharding(axes=[], is_closed=True)
] * ctx.avals_out[0].ndim,
logical_device_ids=())])
else:
# If we have fully manual sharding during lowering, that means the JAX
# program has per-device semantics, so we run the callback on each device.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MANUAL
elif isinstance(
axis_context,
(sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext),
):
# If we have fully automatic sharding during lowering, that means the JAX
# program has bulk array semantics, so we run the callback with a MAXIMAL
# sharding and hence execute it only once on the full logical value).
# If we have partially automatic sharding, we do this too... not sure why!
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MAXIMAL
sharding.tile_assignment_dimensions = [1]
sharding.tile_assignment_devices = [0]
if config.use_shardy_partitioner.value:
sharding = sharding_impls.SdyArrayShardingList([
sharding_impls.SdyArraySharding(
mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))])
else:
# If we have fully automatic sharding during lowering, that means the JAX
# program has bulk array semantics, so we run the callback with a MAXIMAL
# sharding and hence execute it only once on the full logical value).
# If we have partially automatic sharding, we do this too... not sure why!
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MAXIMAL
sharding.tile_assignment_dimensions = [1]
sharding.tile_assignment_devices = [0]
else:
# When there's no SPMD partitioning going on, don't annotate a sharding.
sharding = None
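
A small hedged example of the non-manual case served by this rule (`jax.debug.callback` is its other consumer besides `jax.debug.print`): under plain `jit`, the callback gets the MAXIMAL sharding pinned to logical device 0, i.e. the `logical_device_ids=(0,)` built above.

```python
import jax
import jax.numpy as jnp

def log_max(x):
  # Host-side Python; runs once on the full logical value.
  print("max =", x)

@jax.jit
def h(x):
  jax.debug.callback(log_max, x.max())
  jax.debug.print("sum = {}", x.sum())
  return x + 1

h(jnp.arange(8.0))
```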

View File

@ -50,7 +50,8 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (AUTO, NamedSharding,
modify_sdy_sharding_wrt_axis_types)
modify_sdy_sharding_wrt_axis_types,
SdyArraySharding, SdyArrayShardingList)
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects, ir, passmanager
@ -1019,7 +1020,7 @@ def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):
def _to_physical_op_sharding(
ctx: ModuleContext,
aval: core.AbstractValue, sharding: JSharding | AUTO | None,
) -> xc.OpSharding | sharding_impls.SdyArraySharding | None:
) -> xc.OpSharding | SdyArraySharding | None:
if sharding is None:
return None
if isinstance(sharding, AUTO):
@ -1749,7 +1750,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
if config.use_shardy_partitioner.value:
physical_ndim = core.physical_aval(aval).ndim
s = sharding_impls.SdyArraySharding(
s = SdyArraySharding(
mesh_shape=None,
dimension_shardings=[
sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
@ -2554,7 +2555,7 @@ def _wrap_with_spmd_op(name: str,
ctx: LoweringRuleContext,
x: ir.Value,
aval_out: core.AbstractValue,
sharding: xc.OpSharding | sharding_impls.SdyArraySharding,
sharding: xc.OpSharding | SdyArraySharding,
unspecified_dims: set[int] | None = None,
has_side_effect: bool = False,
allow_shardy_lowering: bool = False):
@ -2615,7 +2616,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList):
if config.use_shardy_partitioner.value:
op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
else:
@ -2623,7 +2624,7 @@ def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
def get_sharding_attr(
sharding: xc.OpSharding | sharding_impls.SdyArraySharding
sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList
) -> ir.Attribute:
if config.use_shardy_partitioner.value:
return sharding.build() # type: ignore
@ -2783,9 +2784,15 @@ RECV_FROM_HOST_TYPE = 3
def is_empty_shape(s: core.Shape) -> bool:
return any(d == 0 for d in s)
def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
aval: core.ShapedArray, name: str, *,
sharding: xc.OpSharding | None = None) -> ir.Value:
def send_to_host(
channel: int,
token: hlo.TokenType,
operand: Any,
name: str,
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> ir.Value:
channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
send_op = hlo.SendOp([operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
@ -2794,13 +2801,27 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
if sharding is not None:
if config.use_shardy_partitioner.value:
# `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed
# in the maximal sharding of the array type. Since a token has no rank,
# we need to create an equivalent sharding with no dimensions.
assert isinstance(sharding, SdyArrayShardingList)
assert len(sharding.shardings) == 1
sharding = SdyArrayShardingList([
SdyArraySharding(
mesh_shape=(), dimension_shardings=[],
logical_device_ids=sharding.shardings[0].logical_device_ids)])
set_sharding(send_op, sharding)
return send_op.result
def receive_from_host(channel: int, token: hlo.TokenType,
out_aval: core.ShapedArray, name: str, *,
sharding: xc.OpSharding | None = None,
def receive_from_host(
channel: int,
token: hlo.TokenType,
out_aval: core.ShapedArray,
name: str,
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> tuple[ir.Value, ir.Value]:
channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
@ -2811,6 +2832,17 @@ def receive_from_host(channel: int, token: hlo.TokenType,
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
if sharding is not None:
if config.use_shardy_partitioner.value:
assert isinstance(sharding, SdyArrayShardingList)
assert len(sharding.shardings) == 1
# `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the
# number of shardings to match the number of results, but JAX only sees
# the array result, we need to add an equivalent sharding for the token.
sharding = SdyArrayShardingList([
sharding.shardings[0],
SdyArraySharding(
mesh_shape=(), dimension_shardings=[],
logical_device_ids=sharding.shardings[0].logical_device_ids)])
set_sharding(recv_op, sharding)
# Token should be at the end of the results
result, token = recv_op.results
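
The intent of the two fixups above, restated as a sketch (the `token_sharding_like` helper below is hypothetical, not part of this change): Shardy requires one sharding per op result, and `SendOp`/`RecvOp` results include a rank-0 token, so the array's maximal sharding is mirrored as a sharding with no dimensions for the token.

```python
from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList

def token_sharding_like(shardings: SdyArrayShardingList) -> SdyArraySharding:
  # Hypothetical helper: build a rank-0 sharding on the same logical device
  # as the (single) array sharding, suitable for a StableHLO token result.
  ref = shardings.shardings[0]
  return SdyArraySharding(
      mesh_shape=(),
      dimension_shardings=[],  # a token has no dimensions to shard
      logical_device_ids=ref.logical_device_ids)
```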
@ -2828,7 +2860,7 @@ def _emit_tpu_python_callback(
result_avals: Sequence[core.ShapedArray],
result_shapes: Sequence[xc.Shape],
*,
sharding: xc.OpSharding | None = None
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> tuple[Sequence[ir.Value], Any]:
token = token or hlo.create_token()
_wrapped_callback = callback
@ -2848,14 +2880,14 @@ def _emit_tpu_python_callback(
dummy_send_val = ir_constant(np.zeros(1, np.float32))
operand_shapes = [*operand_shapes,
xla.aval_to_xla_shapes(dummy_send_aval)[0]]
token = send_to_host(send_channel, token, dummy_send_val, dummy_send_aval,
callback.__name__, sharding=sharding)
token = send_to_host(send_channel, token, dummy_send_val, callback.__name__,
sharding=sharding)
send_channels.append(send_channel)
else:
for operand, operand_aval in zip(operands, operand_avals):
for operand in operands:
channel = ctx.module_context.new_channel()
token = send_to_host(channel, token, operand, operand_aval,
callback.__name__, sharding=sharding)
token = send_to_host(channel, token, operand, callback.__name__,
sharding=sharding)
send_channels.append(channel)
recv_channels = []
@ -2873,6 +2905,7 @@ def _emit_tpu_python_callback(
ctx.module_context.add_host_callback(ifrt_callback)
return outputs, token
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
if minor_to_major is None:
# Needed for token layouts
@ -2896,7 +2929,7 @@ def emit_python_callback(
result_avals: Sequence[core.ShapedArray],
*,
has_side_effect: bool,
sharding: xc.OpSharding | None = None,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
operand_layouts: Sequence[Sequence[int] | None] | None = None,
result_layouts: Sequence[Sequence[int] | None] | None = None,
) -> tuple[Sequence[IrValues], Any, Any]:
@ -3024,6 +3057,7 @@ def emit_python_callback(
token, *results = results
return results, token, ifrt_callback
def build_mlir_module_helper(
closed_jaxpr: core.ClosedJaxpr, *, name: str,
platforms: Sequence[str],

View File

@ -118,7 +118,6 @@ class SdyDimSharding:
is_closed: bool
priority: int | None = None
# NOTE: An MLIR context is required as a context manager.
def build(self) -> sdy.DimensionShardingAttr:
return sdy.DimensionShardingAttr.get(
[sdy.AxisRefAttr.get(axis) for axis in self.axes],
@ -144,7 +143,6 @@ class SdyArraySharding:
logical_device_ids: tuple[int, ...] | None = None
replicated_axes: tuple[str, ...] = ()
# NOTE: An MLIR context is required as a context manager.
def build(self) -> sdy.TensorShardingAttr:
if self.mesh_shape is None:
mesh_attr = sdy.MeshAttr.get([])
@ -169,6 +167,15 @@ class SdyArraySharding:
return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"
@dataclasses.dataclass
class SdyArrayShardingList:
shardings: Sequence[SdyArraySharding]
def build(self) -> sdy.TensorShardingPerValueAttr:
return sdy.TensorShardingPerValueAttr.get(
[sharding.build() for sharding in self.shardings])
@util.cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
self, num_dimensions: int) -> xc.HloSharding:
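
A hedged usage sketch for the new `SdyArrayShardingList` wrapper (internal API, subject to change): it packs one `SdyArraySharding` per result and builds a single `sdy.TensorShardingPerValueAttr`. As with the other `build()` methods, an MLIR context with the Shardy dialect registered must be active; `mlir.make_ir_context()` is assumed to provide one here.

```python
from jax._src.interpreters import mlir
from jax._src.sharding_impls import (
    SdyArraySharding, SdyArrayShardingList, SdyDimSharding)

with mlir.make_ir_context():
  shardings = SdyArrayShardingList([
      SdyArraySharding(
          mesh_shape=(),
          dimension_shardings=[SdyDimSharding(axes=[], is_closed=True)],
          logical_device_ids=()),
  ])
  attr = shardings.build()  # one sdy.TensorShardingPerValueAttr for one result
```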

View File

@ -622,9 +622,10 @@ def _rule_missing(prim: core.Primitive, *_, **__):
# Lowering
def _shardy_shard_map_sharding(
ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in
) -> ir.Attribute:
) -> sharding_impls.SdyArraySharding:
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
@ -634,7 +635,7 @@ def _shardy_shard_map_sharding(
if auto:
for dim_sharding in sdy_sharding.dimension_shardings:
dim_sharding.is_closed = False
return sdy_sharding.build()
return sdy_sharding
def _shard_map_lowering_shardy(
@ -664,12 +665,12 @@ def _shard_map_lowering_shardy(
dim_var_values=ctx.dim_var_values)
return out_nodes
in_shardings = sdy.TensorShardingPerValueAttr.get(map(
in_shardings = sharding_impls.SdyArrayShardingList(map(
partial(_shardy_shard_map_sharding, ctx, mesh, auto),
in_names, ctx.avals_in))
out_shardings = sdy.TensorShardingPerValueAttr.get(map(
in_names, ctx.avals_in)).build()
out_shardings = sharding_impls.SdyArrayShardingList(map(
partial(_shardy_shard_map_sharding, ctx, mesh, auto),
out_names, ctx.avals_out))
out_names, ctx.avals_out)).build()
output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
manual_computation_op = sdy.ManualComputationOp(
output_types, args, in_shardings, out_shardings,
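
A hedged end-to-end example of what this lowering path serves: a debug print inside `shard_map`, which lowers to an `sdy.manual_computation` whose in/out shardings are now assembled via `SdyArrayShardingList(...).build()` rather than `sdy.TensorShardingPerValueAttr.get(...)` directly.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

@jax.jit
def f(x):
  def body(x):
    jax.debug.print("local block = {}", x)  # runs once per device
    return x * 2
  return shard_map(body, mesh=mesh, in_specs=P("x"), out_specs=P("x"))(x)

f(jnp.arange(jax.device_count() * 2.0))
```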

View File

@ -1278,15 +1278,14 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "debugging_primitives_test",
srcs = ["debugging_primitives_test.py"],
disable_configs = [
"cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported.
],
enable_configs = [
"cpu",
"gpu_h100",
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
"gpu_a100_shardy",
"tpu_v3_2x2_shardy",
],
)
@ -1296,13 +1295,12 @@ jax_multiplatform_test(
backend_tags = {
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
disable_configs = [
"cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported.
],
enable_configs = [
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
"tpu_v3_2x2_shardy",
"gpu_2gpu_shardy",
],
tags = ["multiaccelerator"],
deps = [

View File

@ -937,11 +937,19 @@ class PureCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(
out, np.arange(jax.local_device_count()) * 2
)
self.assertIn(
f'{{maximal device={callback_device_index}}}',
str(f_jit.lower(inp).compiler_ir(dialect='stablehlo')),
)
stablehlo_ir = f_jit.lower(inp).as_text()
if config.use_shardy_partitioner.value:
self.assertIn(
"sdy.sharding ="
f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
" []>]>",
stablehlo_ir)
self.assertIn(
f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
f" device_ids=[{callback_device_index}]>",
stablehlo_ir)
else:
self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)
def test_can_shard_pure_callback_manually(self):
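
Outside the test harness, the check above can be reproduced roughly as follows (a sketch, assuming the Shardy partitioner is enabled and a sharding is passed to the callback as in the test; the expected string forms are taken from the assertions above).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import SingleDeviceSharding

jax.config.update("jax_use_shardy_partitioner", True)
dev = jax.devices()[0]

@jax.jit
def f(x):
  return jax.pure_callback(lambda a: np.asarray(a) * 2,
                           jax.ShapeDtypeStruct(x.shape, x.dtype), x,
                           sharding=SingleDeviceSharding(dev))

ir_text = f.lower(jnp.arange(4.0)).as_text()
# Per the assertions above, the IR should contain strings of the form:
#   sdy.mesh @maximal_mesh_<i> = <[], device_ids=[<i>]>
#   sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_<i>, []>]>
print([line for line in ir_text.splitlines() if "maximal_mesh" in line])
```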
@ -1199,10 +1207,19 @@ class IOCallbackTest(jtu.JaxTestCase):
self.assertIn(v, _collected)
callback_device_index = in_spec._device_assignment.index(callback_device)
self.assertIn(
f'{{maximal device={callback_device_index}}}',
str(f.lower(x).compiler_ir(dialect='stablehlo')),
)
stablehlo_ir = f.lower(x).as_text()
if config.use_shardy_partitioner.value:
self.assertIn(
"sdy.sharding ="
f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
" []>]>",
stablehlo_ir)
self.assertIn(
f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
f" device_ids=[{callback_device_index}]>",
stablehlo_ir)
else:
self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)
def test_sequence_pjit_io_callback_ordered(self):
# A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each
@ -1257,6 +1274,8 @@ class IOCallbackTest(jtu.JaxTestCase):
self.assertEqual(_collected, expected)
def test_can_shard_io_callback_manually(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/384938613): Failing under shardy.")
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))