#sdy enable pure callbacks and debug prints in JAX.
Everything passes except one io_callback test, which fails because the lowered `sdy.manual_computation` returns a token; this will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
parent 93ef0f13fe
commit dc53c563bb
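For context, here is a minimal usage sketch of what the change in the diff below enables: a `jax.pure_callback` and a `jax.debug.print` inside a jitted computation lowered with the Shardy partitioner. The host function is made up for illustration, and the `jax_use_shardy_partitioner` flag name is an assumption, not something established by this commit.

    import jax
    import jax.numpy as jnp
    import numpy as np

    # Assumed flag for switching JAX to the Shardy (sdy) partitioner.
    jax.config.update("jax_use_shardy_partitioner", True)

    def host_double(x):
      # Runs in Python on the host and receives a NumPy array.
      return np.asarray(x) * 2

    @jax.jit
    def f(x):
      jax.debug.print("host sees x = {}", x)  # debug print lowered via a callback
      return jax.pure_callback(
          host_double, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

    print(f(jnp.arange(4.0)))  # [0. 2. 4. 6.]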
@@ -21,6 +21,7 @@ import logging
 from typing import Any

 import jax
+from jax._src import config
 from jax._src import core
 from jax._src import deprecations
 from jax._src import dispatch
@@ -225,7 +226,9 @@ batching.primitive_batchers[pure_callback_p] = functools.partial(
 )


-def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
+def _callback_op_sharding(
+    axis_context, sharding: SingleDeviceSharding | None, avals_out
+):
   if isinstance(axis_context, sharding_impls.SPMDAxisContext):
     # If we have fully manual sharding during lowering, that means the JAX
     # program has per-device semantics, so we run the callback on each device.
@@ -239,8 +242,18 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
           "callbacks do not support specifying sharding inside spmd"
           " computations"
       )
-    op_sharding = xc.OpSharding()
-    op_sharding.type = xc.OpSharding.Type.MANUAL
+    if config.use_shardy_partitioner.value:
+      assert len(avals_out) == 1
+      op_sharding = sharding_impls.SdyArrayShardingList([
+          sharding_impls.SdyArraySharding(
+              mesh_shape=(),
+              dimension_shardings=[
+                  sharding_impls.SdyDimSharding(axes=[], is_closed=True)
+              ] * avals_out[0].ndim,
+              logical_device_ids=())])
+    else:
+      op_sharding = xc.OpSharding()  # type: ignore[assignment]
+      op_sharding.type = xc.OpSharding.Type.MANUAL
     return op_sharding

   if isinstance(axis_context, sharding_impls.ShardingContext):
@@ -268,10 +281,17 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
     # If we have fully automatic sharding during lowering, that means the JAX
     # program has bulk array semantics, so we run the callback with a MAXIMAL
     # sharding and hence execute it only once on the full logical value).
-    op_sharding = xc.OpSharding()
-    op_sharding.type = xc.OpSharding.Type.MAXIMAL
-    op_sharding.tile_assignment_dimensions = [1]
-    op_sharding.tile_assignment_devices = [device_index]
+    if config.use_shardy_partitioner.value:
+      op_sharding = sharding_impls.SdyArrayShardingList([
+          sharding_impls.SdyArraySharding(
+              mesh_shape=(),
+              dimension_shardings=[],
+              logical_device_ids=(device_index,))])
+    else:
+      op_sharding = xc.OpSharding()  # type: ignore[assignment]
+      op_sharding.type = xc.OpSharding.Type.MAXIMAL
+      op_sharding.tile_assignment_dimensions = [1]
+      op_sharding.tile_assignment_devices = [device_index]
     return op_sharding

   # When there's no SPMD partitioning going on, don't annotate a sharding.
@@ -291,7 +311,8 @@ def pure_callback_lowering(
       )
   )

-  op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
+  op_sharding = _callback_op_sharding(
+      ctx.module_context.axis_context, sharding, ctx.avals_out)
   result, _, _ = mlir.emit_python_callback(
       ctx,
       _callback,
@@ -561,7 +582,8 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
       )
   )

-  op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
+  op_sharding = _callback_op_sharding(
+      ctx.module_context.axis_context, sharding, ctx.avals_out)
   if ordered:
     token = ctx.tokens_in.get(_OrderedIOEffect)
     result, token, _ = mlir.emit_python_callback(
@@ -576,7 +598,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
     )
     ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token}))
   else:
-    result, token, _ = mlir.emit_python_callback(
+    result, _, _ = mlir.emit_python_callback(
         ctx,
         _callback,
        None,
@@ -15,9 +15,9 @@

 from __future__ import annotations

-import importlib.util
 from collections.abc import Callable, Sequence
 import functools
+import importlib.util
 import logging
 import string
 import sys
@@ -29,12 +29,12 @@ import numpy as np
 import jax
 import jax.numpy as jnp
 from jax import lax

+from jax._src import config
 from jax._src import core
+from jax._src import dispatch
 from jax._src import effects
 from jax._src import mesh as mesh_lib
 from jax._src import sharding_impls
-from jax._src import dispatch
 from jax._src import tree_util
 from jax._src import util
 from jax._src.interpreters import ad
@@ -135,22 +135,37 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
   axis_context = ctx.module_context.axis_context
   if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
       set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
-    # If we have fully manual sharding during lowering, that means the JAX
-    # program has per-device semantics, so we run the callback on each device.
-    sharding = xc.OpSharding()
-    sharding.type = xc.OpSharding.Type.MANUAL
+    if config.use_shardy_partitioner.value:
+      assert len(ctx.avals_out) == 1
+      sharding = sharding_impls.SdyArrayShardingList([
+          sharding_impls.SdyArraySharding(
+              mesh_shape=(),
+              dimension_shardings=[
+                  sharding_impls.SdyDimSharding(axes=[], is_closed=True)
+              ] * ctx.avals_out[0].ndim,
+              logical_device_ids=())])
+    else:
+      # If we have fully manual sharding during lowering, that means the JAX
+      # program has per-device semantics, so we run the callback on each device.
+      sharding = xc.OpSharding()
+      sharding.type = xc.OpSharding.Type.MANUAL
   elif isinstance(
       axis_context,
       (sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext),
   ):
-    # If we have fully automatic sharding during lowering, that means the JAX
-    # program has bulk array semantics, so we run the callback with a MAXIMAL
-    # sharding and hence execute it only once on the full logical value).
-    # If we have partially automatic sharding, we do this too... not sure why!
-    sharding = xc.OpSharding()
-    sharding.type = xc.OpSharding.Type.MAXIMAL
-    sharding.tile_assignment_dimensions = [1]
-    sharding.tile_assignment_devices = [0]
+    if config.use_shardy_partitioner.value:
+      sharding = sharding_impls.SdyArrayShardingList([
+          sharding_impls.SdyArraySharding(
+              mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))])
+    else:
+      # If we have fully automatic sharding during lowering, that means the JAX
+      # program has bulk array semantics, so we run the callback with a MAXIMAL
+      # sharding and hence execute it only once on the full logical value).
+      # If we have partially automatic sharding, we do this too... not sure why!
+      sharding = xc.OpSharding()
+      sharding.type = xc.OpSharding.Type.MAXIMAL
+      sharding.tile_assignment_dimensions = [1]
+      sharding.tile_assignment_devices = [0]
   else:
     # When there's no SPMD partitioning going on, don't annotate a sharding.
     sharding = None
@@ -50,7 +50,8 @@ from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import (AUTO, NamedSharding,
-                                     modify_sdy_sharding_wrt_axis_types)
+                                     modify_sdy_sharding_wrt_axis_types,
+                                     SdyArraySharding, SdyArrayShardingList)
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -1019,7 +1020,7 @@ def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):
 def _to_physical_op_sharding(
     ctx: ModuleContext,
     aval: core.AbstractValue, sharding: JSharding | AUTO | None,
-) -> xc.OpSharding | sharding_impls.SdyArraySharding | None:
+) -> xc.OpSharding | SdyArraySharding | None:
   if sharding is None:
     return None
   if isinstance(sharding, AUTO):
@@ -1749,7 +1750,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
   if config.use_shardy_partitioner.value:
     physical_ndim = core.physical_aval(aval).ndim
-    s = sharding_impls.SdyArraySharding(
+    s = SdyArraySharding(
         mesh_shape=None,
         dimension_shardings=[
             sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
@@ -2554,7 +2555,7 @@ def _wrap_with_spmd_op(name: str,
                        ctx: LoweringRuleContext,
                        x: ir.Value,
                        aval_out: core.AbstractValue,
-                       sharding: xc.OpSharding | sharding_impls.SdyArraySharding,
+                       sharding: xc.OpSharding | SdyArraySharding,
                        unspecified_dims: set[int] | None = None,
                        has_side_effect: bool = False,
                        allow_shardy_lowering: bool = False):
@@ -2615,7 +2616,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


-def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
+def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList):
   if config.use_shardy_partitioner.value:
     op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
   else:
@@ -2623,7 +2624,7 @@ def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):


 def get_sharding_attr(
-    sharding: xc.OpSharding | sharding_impls.SdyArraySharding
+    sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList
 ) -> ir.Attribute:
   if config.use_shardy_partitioner.value:
     return sharding.build()  # type: ignore
@@ -2783,9 +2784,15 @@ RECV_FROM_HOST_TYPE = 3
 def is_empty_shape(s: core.Shape) -> bool:
   return any(d == 0 for d in s)

-def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
-                 aval: core.ShapedArray, name: str, *,
-                 sharding: xc.OpSharding | None = None) -> ir.Value:
+
+def send_to_host(
+    channel: int,
+    token: hlo.TokenType,
+    operand: Any,
+    name: str,
+    *,
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
+) -> ir.Value:
   channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
   send_op = hlo.SendOp([operand], token, channel_handle,
                        is_host_transfer=ir.BoolAttr.get(True))
@@ -2794,13 +2801,27 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
       _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
       _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
   if sharding is not None:
+    if config.use_shardy_partitioner.value:
+      # `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed
+      # in the maximal sharding of the array type. Since a token has no rank,
+      # we need to create an equivalent sharding with no dimensions.
+      assert isinstance(sharding, SdyArrayShardingList)
+      assert len(sharding.shardings) == 1
+      sharding = SdyArrayShardingList([
+          SdyArraySharding(
+              mesh_shape=(), dimension_shardings=[],
+              logical_device_ids=sharding.shardings[0].logical_device_ids)])
     set_sharding(send_op, sharding)
   return send_op.result


-def receive_from_host(channel: int, token: hlo.TokenType,
-                      out_aval: core.ShapedArray, name: str, *,
-                      sharding: xc.OpSharding | None = None,
+def receive_from_host(
+    channel: int,
+    token: hlo.TokenType,
+    out_aval: core.ShapedArray,
+    name: str,
+    *,
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
 ) -> tuple[ir.Value, ir.Value]:
   channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
   recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
@@ -2811,6 +2832,17 @@ def receive_from_host(channel: int, token: hlo.TokenType,
       _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
       _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
   if sharding is not None:
+    if config.use_shardy_partitioner.value:
+      assert isinstance(sharding, SdyArrayShardingList)
+      assert len(sharding.shardings) == 1
+      # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the
+      # number of shardings to match the number of results, but JAX only sees
+      # the array result, we need to add an equivalent sharding for the token.
+      sharding = SdyArrayShardingList([
+          sharding.shardings[0],
+          SdyArraySharding(
+              mesh_shape=(), dimension_shardings=[],
+              logical_device_ids=sharding.shardings[0].logical_device_ids)])
     set_sharding(recv_op, sharding)
   # Token should be at the end of the results
   result, token = recv_op.results
@@ -2828,7 +2860,7 @@ def _emit_tpu_python_callback(
     result_avals: Sequence[core.ShapedArray],
     result_shapes: Sequence[xc.Shape],
     *,
-    sharding: xc.OpSharding | None = None
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
 ) -> tuple[Sequence[ir.Value], Any]:
   token = token or hlo.create_token()
   _wrapped_callback = callback
@@ -2848,14 +2880,14 @@ def _emit_tpu_python_callback(
     dummy_send_val = ir_constant(np.zeros(1, np.float32))
     operand_shapes = [*operand_shapes,
                       xla.aval_to_xla_shapes(dummy_send_aval)[0]]
-    token = send_to_host(send_channel, token, dummy_send_val, dummy_send_aval,
-                         callback.__name__, sharding=sharding)
+    token = send_to_host(send_channel, token, dummy_send_val, callback.__name__,
+                         sharding=sharding)
     send_channels.append(send_channel)
   else:
-    for operand, operand_aval in zip(operands, operand_avals):
+    for operand in operands:
       channel = ctx.module_context.new_channel()
-      token = send_to_host(channel, token, operand, operand_aval,
-                           callback.__name__, sharding=sharding)
+      token = send_to_host(channel, token, operand, callback.__name__,
+                           sharding=sharding)
       send_channels.append(channel)

   recv_channels = []
@@ -2873,6 +2905,7 @@ def _emit_tpu_python_callback(
   ctx.module_context.add_host_callback(ifrt_callback)
   return outputs, token

+
 def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
   if minor_to_major is None:
     # Needed for token layouts
@@ -2896,7 +2929,7 @@ def emit_python_callback(
     result_avals: Sequence[core.ShapedArray],
     *,
     has_side_effect: bool,
-    sharding: xc.OpSharding | None = None,
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
     operand_layouts: Sequence[Sequence[int] | None] | None = None,
     result_layouts: Sequence[Sequence[int] | None] | None = None,
 ) -> tuple[Sequence[IrValues], Any, Any]:
@@ -3024,6 +3057,7 @@ def emit_python_callback(
   token, *results = results
   return results, token, ifrt_callback

+
 def build_mlir_module_helper(
     closed_jaxpr: core.ClosedJaxpr, *, name: str,
     platforms: Sequence[str],
@@ -118,7 +118,6 @@ class SdyDimSharding:
   is_closed: bool
   priority: int | None = None

-  # NOTE: An MLIR context is required as a context manager.
   def build(self) -> sdy.DimensionShardingAttr:
     return sdy.DimensionShardingAttr.get(
         [sdy.AxisRefAttr.get(axis) for axis in self.axes],
@@ -144,7 +143,6 @@ class SdyArraySharding:
   logical_device_ids: tuple[int, ...] | None = None
   replicated_axes: tuple[str, ...] = ()

-  # NOTE: An MLIR context is required as a context manager.
   def build(self) -> sdy.TensorShardingAttr:
     if self.mesh_shape is None:
       mesh_attr = sdy.MeshAttr.get([])
@@ -169,6 +167,15 @@ class SdyArraySharding:
     return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"


+@dataclasses.dataclass
+class SdyArrayShardingList:
+  shardings: Sequence[SdyArraySharding]
+
+  def build(self) -> sdy.TensorShardingPerValueAttr:
+    return sdy.TensorShardingPerValueAttr.get(
+        [sharding.build() for sharding in self.shardings])
+
+
 @util.cache(max_size=4096, trace_context_in_key=False)
 def named_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
@@ -622,9 +622,10 @@ def _rule_missing(prim: core.Primitive, *_, **__):

 # Lowering

+
 def _shardy_shard_map_sharding(
     ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in
-) -> ir.Attribute:
+) -> sharding_impls.SdyArraySharding:
   axes = {name: i for i, ns in names.items() for name in ns}
   ns = _make_scoped_manual_sharding(ctx, mesh, axes)
   if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
@@ -634,7 +635,7 @@ def _shardy_shard_map_sharding(
   if auto:
     for dim_sharding in sdy_sharding.dimension_shardings:
       dim_sharding.is_closed = False
-  return sdy_sharding.build()
+  return sdy_sharding


 def _shard_map_lowering_shardy(
@@ -664,12 +665,12 @@ def _shard_map_lowering_shardy(
         dim_var_values=ctx.dim_var_values)
     return out_nodes

-  in_shardings = sdy.TensorShardingPerValueAttr.get(map(
+  in_shardings = sharding_impls.SdyArrayShardingList(map(
       partial(_shardy_shard_map_sharding, ctx, mesh, auto),
-      in_names, ctx.avals_in))
-  out_shardings = sdy.TensorShardingPerValueAttr.get(map(
+      in_names, ctx.avals_in)).build()
+  out_shardings = sharding_impls.SdyArrayShardingList(map(
       partial(_shardy_shard_map_sharding, ctx, mesh, auto),
-      out_names, ctx.avals_out))
+      out_names, ctx.avals_out)).build()
   output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
   manual_computation_op = sdy.ManualComputationOp(
       output_types, args, in_shardings, out_shardings,
tests/BUILD
@@ -1278,15 +1278,14 @@ jax_multiplatform_test(
 jax_multiplatform_test(
     name = "debugging_primitives_test",
     srcs = ["debugging_primitives_test.py"],
-    disable_configs = [
-        "cpu_shardy",  # TODO(b/364547005): enable once pure callbacks are supported.
-    ],
     enable_configs = [
         "cpu",
         "gpu_h100",
         "tpu_v2_1x1",
         "tpu_v3_2x2",
         "tpu_v4_2x2",
+        "gpu_a100_shardy",
+        "tpu_v3_2x2_shardy",
     ],
 )

@@ -1296,13 +1295,12 @@ jax_multiplatform_test(
     backend_tags = {
         "gpu": ["noasan"],  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
     },
-    disable_configs = [
-        "cpu_shardy",  # TODO(b/364547005): enable once pure callbacks are supported.
-    ],
     enable_configs = [
         "tpu_v2_1x1",
         "tpu_v3_2x2",
         "tpu_v4_2x2",
+        "tpu_v3_2x2_shardy",
+        "gpu_2gpu_shardy",
     ],
     tags = ["multiaccelerator"],
     deps = [
@@ -937,11 +937,19 @@ class PureCallbackTest(jtu.JaxTestCase):
     np.testing.assert_allclose(
         out, np.arange(jax.local_device_count()) * 2
     )
-
-    self.assertIn(
-        f'{{maximal device={callback_device_index}}}',
-        str(f_jit.lower(inp).compiler_ir(dialect='stablehlo')),
-    )
+    stablehlo_ir = f_jit.lower(inp).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertIn(
+          "sdy.sharding ="
+          f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
+          " []>]>",
+          stablehlo_ir)
+      self.assertIn(
+          f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
+          f" device_ids=[{callback_device_index}]>",
+          stablehlo_ir)
+    else:
+      self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)

   def test_can_shard_pure_callback_manually(self):

@@ -1199,10 +1207,19 @@ class IOCallbackTest(jtu.JaxTestCase):
       self.assertIn(v, _collected)

     callback_device_index = in_spec._device_assignment.index(callback_device)
-    self.assertIn(
-        f'{{maximal device={callback_device_index}}}',
-        str(f.lower(x).compiler_ir(dialect='stablehlo')),
-    )
+    stablehlo_ir = f.lower(x).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertIn(
+          "sdy.sharding ="
+          f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
+          " []>]>",
+          stablehlo_ir)
+      self.assertIn(
+          f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
+          f" device_ids=[{callback_device_index}]>",
+          stablehlo_ir)
+    else:
+      self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)

   def test_sequence_pjit_io_callback_ordered(self):
     # A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each
@@ -1257,6 +1274,8 @@ class IOCallbackTest(jtu.JaxTestCase):
     self.assertEqual(_collected, expected)

   def test_can_shard_io_callback_manually(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest("TODO(b/384938613): Failing under shardy.")

     mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
