From dc53c563bba973904ff87b9647227d007aee1723 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 9 Jan 2025 13:37:00 -0800 Subject: [PATCH] #sdy enable pure callbacks and debug prints in JAX. Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up. PiperOrigin-RevId: 713780181 --- jax/_src/callback.py | 42 +++++++++++++++----- jax/_src/debugging.py | 45 ++++++++++++++-------- jax/_src/interpreters/mlir.py | 72 ++++++++++++++++++++++++++--------- jax/_src/sharding_impls.py | 11 +++++- jax/experimental/shard_map.py | 13 ++++--- tests/BUILD | 10 ++--- tests/python_callback_test.py | 37 +++++++++++++----- 7 files changed, 163 insertions(+), 67 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 013b766b8..098273d27 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -21,6 +21,7 @@ import logging from typing import Any import jax +from jax._src import config from jax._src import core from jax._src import deprecations from jax._src import dispatch @@ -225,7 +226,9 @@ batching.primitive_batchers[pure_callback_p] = functools.partial( ) -def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): +def _callback_op_sharding( + axis_context, sharding: SingleDeviceSharding | None, avals_out +): if isinstance(axis_context, sharding_impls.SPMDAxisContext): # If we have fully manual sharding during lowering, that means the JAX # program has per-device semantics, so we run the callback on each device. @@ -239,8 +242,18 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): "callbacks do not support specifying sharding inside spmd" " computations" ) - op_sharding = xc.OpSharding() - op_sharding.type = xc.OpSharding.Type.MANUAL + if config.use_shardy_partitioner.value: + assert len(avals_out) == 1 + op_sharding = sharding_impls.SdyArrayShardingList([ + sharding_impls.SdyArraySharding( + mesh_shape=(), + dimension_shardings=[ + sharding_impls.SdyDimSharding(axes=[], is_closed=True) + ] * avals_out[0].ndim, + logical_device_ids=())]) + else: + op_sharding = xc.OpSharding() # type: ignore[assignment] + op_sharding.type = xc.OpSharding.Type.MANUAL return op_sharding if isinstance(axis_context, sharding_impls.ShardingContext): @@ -268,10 +281,17 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): # If we have fully automatic sharding during lowering, that means the JAX # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). - op_sharding = xc.OpSharding() - op_sharding.type = xc.OpSharding.Type.MAXIMAL - op_sharding.tile_assignment_dimensions = [1] - op_sharding.tile_assignment_devices = [device_index] + if config.use_shardy_partitioner.value: + op_sharding = sharding_impls.SdyArrayShardingList([ + sharding_impls.SdyArraySharding( + mesh_shape=(), + dimension_shardings=[], + logical_device_ids=(device_index,))]) + else: + op_sharding = xc.OpSharding() # type: ignore[assignment] + op_sharding.type = xc.OpSharding.Type.MAXIMAL + op_sharding.tile_assignment_dimensions = [1] + op_sharding.tile_assignment_devices = [device_index] return op_sharding # When there's no SPMD partitioning going on, don't annotate a sharding. 
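For reference, the two Shardy annotations that `_callback_op_sharding` now builds can be constructed directly from the `jax._src.sharding_impls` dataclasses. The sketch below is not part of the patch and assumes the patch is applied (it uses the new `SdyArrayShardingList`); the rank of 2 and device index 0 are illustrative.

from jax._src import sharding_impls

# Fully-manual case: one fully-closed, axis-free dimension sharding per
# output dimension, so the callback runs once per device on its local shard.
manual = sharding_impls.SdyArrayShardingList([
    sharding_impls.SdyArraySharding(
        mesh_shape=(),
        dimension_shardings=[
            sharding_impls.SdyDimSharding(axes=[], is_closed=True)
        ] * 2,  # illustrative rank-2 callback result
        logical_device_ids=())])

# Maximal case: no dimension shardings and a single logical device id, so
# the callback executes exactly once, on the chosen device.
maximal = sharding_impls.SdyArrayShardingList([
    sharding_impls.SdyArraySharding(
        mesh_shape=(),
        dimension_shardings=[],
        logical_device_ids=(0,))])

print(manual)
print(maximal)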
@@ -291,7 +311,8 @@ def pure_callback_lowering( ) ) - op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding) + op_sharding = _callback_op_sharding( + ctx.module_context.axis_context, sharding, ctx.avals_out) result, _, _ = mlir.emit_python_callback( ctx, _callback, @@ -561,7 +582,8 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params): ) ) - op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding) + op_sharding = _callback_op_sharding( + ctx.module_context.axis_context, sharding, ctx.avals_out) if ordered: token = ctx.tokens_in.get(_OrderedIOEffect) result, token, _ = mlir.emit_python_callback( @@ -576,7 +598,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params): ) ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token})) else: - result, token, _ = mlir.emit_python_callback( + result, _, _ = mlir.emit_python_callback( ctx, _callback, None, diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 9d3d3f85f..7685ac2bf 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -15,9 +15,9 @@ from __future__ import annotations -import importlib.util from collections.abc import Callable, Sequence import functools +import importlib.util import logging import string import sys @@ -29,12 +29,12 @@ import numpy as np import jax import jax.numpy as jnp from jax import lax - +from jax._src import config from jax._src import core +from jax._src import dispatch from jax._src import effects from jax._src import mesh as mesh_lib from jax._src import sharding_impls -from jax._src import dispatch from jax._src import tree_util from jax._src import util from jax._src.interpreters import ad @@ -135,22 +135,37 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): axis_context = ctx.module_context.axis_context if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)): - # If we have fully manual sharding during lowering, that means the JAX - # program has per-device semantics, so we run the callback on each device. - sharding = xc.OpSharding() - sharding.type = xc.OpSharding.Type.MANUAL + if config.use_shardy_partitioner.value: + assert len(ctx.avals_out) == 1 + sharding = sharding_impls.SdyArrayShardingList([ + sharding_impls.SdyArraySharding( + mesh_shape=(), + dimension_shardings=[ + sharding_impls.SdyDimSharding(axes=[], is_closed=True) + ] * ctx.avals_out[0].ndim, + logical_device_ids=())]) + else: + # If we have fully manual sharding during lowering, that means the JAX + # program has per-device semantics, so we run the callback on each device. + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MANUAL elif isinstance( axis_context, (sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext), ): - # If we have fully automatic sharding during lowering, that means the JAX - # program has bulk array semantics, so we run the callback with a MAXIMAL - # sharding and hence execute it only once on the full logical value). - # If we have partially automatic sharding, we do this too... not sure why! 
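At the user level, the fully-manual branch above is reached by callbacks and debug prints inside `shard_map` when the Shardy partitioner is enabled. The following sketch is not part of the patch; it assumes the patch is applied and that the device count divides the array size, and the mesh axis name and callback body are illustrative.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

jax.config.update("jax_use_shardy_partitioner", True)

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def f(x):
  jax.debug.print("shard = {x}", x=x)  # runs once per device
  return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

g = jax.jit(shard_map(f, mesh=mesh, in_specs=P("x"), out_specs=P("x")))
# The lowered module should carry fully-closed sdy.sharding annotations on
# the callback custom_calls instead of the MANUAL OpSharding proto.
print(g.lower(jnp.arange(8.0)).as_text())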
- sharding = xc.OpSharding() - sharding.type = xc.OpSharding.Type.MAXIMAL - sharding.tile_assignment_dimensions = [1] - sharding.tile_assignment_devices = [0] + if config.use_shardy_partitioner.value: + sharding = sharding_impls.SdyArrayShardingList([ + sharding_impls.SdyArraySharding( + mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))]) + else: + # If we have fully automatic sharding during lowering, that means the JAX + # program has bulk array semantics, so we run the callback with a MAXIMAL + # sharding and hence execute it only once on the full logical value). + # If we have partially automatic sharding, we do this too... not sure why! + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MAXIMAL + sharding.tile_assignment_dimensions = [1] + sharding.tile_assignment_devices = [0] else: # When there's no SPMD partitioning going on, don't annotate a sharding. sharding = None diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f27566632..92f3b72e4 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -50,7 +50,8 @@ from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import (AUTO, NamedSharding, - modify_sdy_sharding_wrt_axis_types) + modify_sdy_sharding_wrt_axis_types, + SdyArraySharding, SdyArrayShardingList) from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension from jax._src.lib.mlir import dialects, ir, passmanager @@ -1019,7 +1020,7 @@ def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): def _to_physical_op_sharding( ctx: ModuleContext, aval: core.AbstractValue, sharding: JSharding | AUTO | None, -) -> xc.OpSharding | sharding_impls.SdyArraySharding | None: +) -> xc.OpSharding | SdyArraySharding | None: if sharding is None: return None if isinstance(sharding, AUTO): @@ -1749,7 +1750,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) if config.use_shardy_partitioner.value: physical_ndim = core.physical_aval(aval).ndim - s = sharding_impls.SdyArraySharding( + s = SdyArraySharding( mesh_shape=None, dimension_shardings=[ sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim) @@ -2554,7 +2555,7 @@ def _wrap_with_spmd_op(name: str, ctx: LoweringRuleContext, x: ir.Value, aval_out: core.AbstractValue, - sharding: xc.OpSharding | sharding_impls.SdyArraySharding, + sharding: xc.OpSharding | SdyArraySharding, unspecified_dims: set[int] | None = None, has_side_effect: bool = False, allow_shardy_lowering: bool = False): @@ -2615,7 +2616,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims) -def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): +def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList): if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) else: @@ -2623,7 +2624,7 @@ def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): def get_sharding_attr( - sharding: xc.OpSharding | sharding_impls.SdyArraySharding + sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList ) -> ir.Attribute: if config.use_shardy_partitioner.value: return sharding.build() # type: ignore @@ -2783,9 +2784,15 @@ RECV_FROM_HOST_TYPE = 3 def 
is_empty_shape(s: core.Shape) -> bool: return any(d == 0 for d in s) -def send_to_host(channel: int, token: hlo.TokenType, operand: Any, - aval: core.ShapedArray, name: str, *, - sharding: xc.OpSharding | None = None) -> ir.Value: + +def send_to_host( + channel: int, + token: hlo.TokenType, + operand: Any, + name: str, + *, + sharding: SdyArrayShardingList | xc.OpSharding | None = None, +) -> ir.Value: channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE) send_op = hlo.SendOp([operand], token, channel_handle, is_host_transfer=ir.BoolAttr.get(True)) @@ -2794,13 +2801,27 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any, _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) if sharding is not None: + if config.use_shardy_partitioner.value: + # `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed + # in the maximal sharding of the array type. Since a token has no rank, + # we need to create an equivalent sharding with no dimensions. + assert isinstance(sharding, SdyArrayShardingList) + assert len(sharding.shardings) == 1 + sharding = SdyArrayShardingList([ + SdyArraySharding( + mesh_shape=(), dimension_shardings=[], + logical_device_ids=sharding.shardings[0].logical_device_ids)]) set_sharding(send_op, sharding) return send_op.result -def receive_from_host(channel: int, token: hlo.TokenType, - out_aval: core.ShapedArray, name: str, *, - sharding: xc.OpSharding | None = None, +def receive_from_host( + channel: int, + token: hlo.TokenType, + out_aval: core.ShapedArray, + name: str, + *, + sharding: SdyArrayShardingList | xc.OpSharding | None = None, ) -> tuple[ir.Value, ir.Value]: channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE) recv_op = hlo.RecvOp([aval_to_ir_type(out_aval), @@ -2811,6 +2832,17 @@ def receive_from_host(channel: int, token: hlo.TokenType, _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) if sharding is not None: + if config.use_shardy_partitioner.value: + assert isinstance(sharding, SdyArrayShardingList) + assert len(sharding.shardings) == 1 + # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the + # number of shardings to match the number of results, but JAX only sees + # the array result, we need to add an equivalent sharding for the token. 
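The token handling in `send_to_host` above and in `receive_from_host` (continued just below) boils down to building a rank-0 sharding that keeps the device pinning. A sketch of both cases, not part of the patch and assuming it is applied; the rank of 2 is illustrative.

from jax._src import sharding_impls

# The maximal/manual sharding JAX computed for the callback array.
array = sharding_impls.SdyArraySharding(
    mesh_shape=(),
    dimension_shardings=[
        sharding_impls.SdyDimSharding(axes=[], is_closed=True)] * 2,
    logical_device_ids=())

# A token has no rank, so its sharding keeps the device pinning but drops
# the per-dimension entries.
token = sharding_impls.SdyArraySharding(
    mesh_shape=(),
    dimension_shardings=[],
    logical_device_ids=array.logical_device_ids)

# send_to_host: SendOp's only result is the token.
send_sharding = sharding_impls.SdyArrayShardingList([token])

# receive_from_host: RecvOp returns (array, token), and Shardy wants one
# sharding per result, so the token sharding is appended after the array's.
recv_sharding = sharding_impls.SdyArrayShardingList([array, token])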
+ sharding = SdyArrayShardingList([ + sharding.shardings[0], + SdyArraySharding( + mesh_shape=(), dimension_shardings=[], + logical_device_ids=sharding.shardings[0].logical_device_ids)]) set_sharding(recv_op, sharding) # Token should be at the end of the results result, token = recv_op.results @@ -2828,7 +2860,7 @@ def _emit_tpu_python_callback( result_avals: Sequence[core.ShapedArray], result_shapes: Sequence[xc.Shape], *, - sharding: xc.OpSharding | None = None + sharding: SdyArrayShardingList | xc.OpSharding | None = None, ) -> tuple[Sequence[ir.Value], Any]: token = token or hlo.create_token() _wrapped_callback = callback @@ -2848,14 +2880,14 @@ def _emit_tpu_python_callback( dummy_send_val = ir_constant(np.zeros(1, np.float32)) operand_shapes = [*operand_shapes, xla.aval_to_xla_shapes(dummy_send_aval)[0]] - token = send_to_host(send_channel, token, dummy_send_val, dummy_send_aval, - callback.__name__, sharding=sharding) + token = send_to_host(send_channel, token, dummy_send_val, callback.__name__, + sharding=sharding) send_channels.append(send_channel) else: - for operand, operand_aval in zip(operands, operand_avals): + for operand in operands: channel = ctx.module_context.new_channel() - token = send_to_host(channel, token, operand, operand_aval, - callback.__name__, sharding=sharding) + token = send_to_host(channel, token, operand, callback.__name__, + sharding=sharding) send_channels.append(channel) recv_channels = [] @@ -2873,6 +2905,7 @@ def _emit_tpu_python_callback( ctx.module_context.add_host_callback(ifrt_callback) return outputs, token + def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): if minor_to_major is None: # Needed for token layouts @@ -2896,7 +2929,7 @@ def emit_python_callback( result_avals: Sequence[core.ShapedArray], *, has_side_effect: bool, - sharding: xc.OpSharding | None = None, + sharding: SdyArrayShardingList | xc.OpSharding | None = None, operand_layouts: Sequence[Sequence[int] | None] | None = None, result_layouts: Sequence[Sequence[int] | None] | None = None, ) -> tuple[Sequence[IrValues], Any, Any]: @@ -3024,6 +3057,7 @@ def emit_python_callback( token, *results = results return results, token, ifrt_callback + def build_mlir_module_helper( closed_jaxpr: core.ClosedJaxpr, *, name: str, platforms: Sequence[str], diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 519a2f21c..53a19fd7f 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -118,7 +118,6 @@ class SdyDimSharding: is_closed: bool priority: int | None = None - # NOTE: An MLIR context is required as a context manager. def build(self) -> sdy.DimensionShardingAttr: return sdy.DimensionShardingAttr.get( [sdy.AxisRefAttr.get(axis) for axis in self.axes], @@ -144,7 +143,6 @@ class SdyArraySharding: logical_device_ids: tuple[int, ...] | None = None replicated_axes: tuple[str, ...] = () - # NOTE: An MLIR context is required as a context manager. 
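These dataclasses only become `sdy` MLIR attributes once `build()` is called, which (as `get_sharding_attr` does during lowering) must happen under an active MLIR context. A sketch, not part of the patch; it assumes the patch is applied and that `mlir.make_ir_context()` registers the dialects the Shardy lowering path needs.

from jax._src import sharding_impls
from jax._src.interpreters import mlir

shardings = sharding_impls.SdyArrayShardingList([
    sharding_impls.SdyArraySharding(
        mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))])

with mlir.make_ir_context():
  attr = shardings.build()  # an sdy TensorShardingPerValueAttr
  print(attr)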
def build(self) -> sdy.TensorShardingAttr: if self.mesh_shape is None: mesh_attr = sdy.MeshAttr.get([]) @@ -169,6 +167,15 @@ class SdyArraySharding: return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})" +@dataclasses.dataclass +class SdyArrayShardingList: + shardings: Sequence[SdyArraySharding] + + def build(self) -> sdy.TensorShardingPerValueAttr: + return sdy.TensorShardingPerValueAttr.get( + [sharding.build() for sharding in self.shardings]) + + @util.cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( self, num_dimensions: int) -> xc.HloSharding: diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 910fa4728..05f44f313 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -622,9 +622,10 @@ def _rule_missing(prim: core.Primitive, *_, **__): # Lowering + def _shardy_shard_map_sharding( ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in - ) -> ir.Attribute: +) -> sharding_impls.SdyArraySharding: axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): @@ -634,7 +635,7 @@ def _shardy_shard_map_sharding( if auto: for dim_sharding in sdy_sharding.dimension_shardings: dim_sharding.is_closed = False - return sdy_sharding.build() + return sdy_sharding def _shard_map_lowering_shardy( @@ -664,12 +665,12 @@ def _shard_map_lowering_shardy( dim_var_values=ctx.dim_var_values) return out_nodes - in_shardings = sdy.TensorShardingPerValueAttr.get(map( + in_shardings = sharding_impls.SdyArrayShardingList(map( partial(_shardy_shard_map_sharding, ctx, mesh, auto), - in_names, ctx.avals_in)) - out_shardings = sdy.TensorShardingPerValueAttr.get(map( + in_names, ctx.avals_in)).build() + out_shardings = sharding_impls.SdyArrayShardingList(map( partial(_shardy_shard_map_sharding, ctx, mesh, auto), - out_names, ctx.avals_out)) + out_names, ctx.avals_out)).build() output_types = map(mlir.aval_to_ir_type, ctx.avals_out) manual_computation_op = sdy.ManualComputationOp( output_types, args, in_shardings, out_shardings, diff --git a/tests/BUILD b/tests/BUILD index 126ca3327..a673a5aef 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1278,15 +1278,14 @@ jax_multiplatform_test( jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], - disable_configs = [ - "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. - ], enable_configs = [ "cpu", "gpu_h100", "tpu_v2_1x1", "tpu_v3_2x2", "tpu_v4_2x2", + "gpu_a100_shardy", + "tpu_v3_2x2_shardy", ], ) @@ -1296,13 +1295,12 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, - disable_configs = [ - "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. 
- ], enable_configs = [ "tpu_v2_1x1", "tpu_v3_2x2", "tpu_v4_2x2", + "tpu_v3_2x2_shardy", + "gpu_2gpu_shardy", ], tags = ["multiaccelerator"], deps = [ diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index efa877fd3..5a3b9bab5 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -937,11 +937,19 @@ class PureCallbackTest(jtu.JaxTestCase): np.testing.assert_allclose( out, np.arange(jax.local_device_count()) * 2 ) - - self.assertIn( - f'{{maximal device={callback_device_index}}}', - str(f_jit.lower(inp).compiler_ir(dialect='stablehlo')), - ) + stablehlo_ir = f_jit.lower(inp).as_text() + if config.use_shardy_partitioner.value: + self.assertIn( + "sdy.sharding =" + f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index}," + " []>]>", + stablehlo_ir) + self.assertIn( + f"sdy.mesh @maximal_mesh_{callback_device_index} = <[]," + f" device_ids=[{callback_device_index}]>", + stablehlo_ir) + else: + self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir) def test_can_shard_pure_callback_manually(self): @@ -1199,10 +1207,19 @@ class IOCallbackTest(jtu.JaxTestCase): self.assertIn(v, _collected) callback_device_index = in_spec._device_assignment.index(callback_device) - self.assertIn( - f'{{maximal device={callback_device_index}}}', - str(f.lower(x).compiler_ir(dialect='stablehlo')), - ) + stablehlo_ir = f.lower(x).as_text() + if config.use_shardy_partitioner.value: + self.assertIn( + "sdy.sharding =" + f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index}," + " []>]>", + stablehlo_ir) + self.assertIn( + f"sdy.mesh @maximal_mesh_{callback_device_index} = <[]," + f" device_ids=[{callback_device_index}]>", + stablehlo_ir) + else: + self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir) def test_sequence_pjit_io_callback_ordered(self): # A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each @@ -1257,6 +1274,8 @@ class IOCallbackTest(jtu.JaxTestCase): self.assertEqual(_collected, expected) def test_can_shard_io_callback_manually(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/384938613): Failing under shardy.") mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
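Outside the test harness, the updated assertions can be reproduced with a small standalone script. This is a sketch, not part of the patch: it assumes the patch is applied, pins the callback to device index 0, and only prints whether the expected Shardy strings are present rather than asserting, since the exact lowering depends on the runtime's devices.

import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_use_shardy_partitioner", True)

def cb(x):
  return np.asarray(x) * 2

@jax.jit
def f(x):
  return jax.pure_callback(
      cb, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
      sharding=jax.sharding.SingleDeviceSharding(jax.devices()[0]))

ir_text = f.lower(jnp.arange(4.0)).as_text()
device_index = 0
print(f"sdy.mesh @maximal_mesh_{device_index}" in ir_text)
print(f"#sdy.sharding_per_value<[<@maximal_mesh_{device_index}, []>]>" in ir_text)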