mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #11717 from sharadmv:move-sharding
PiperOrigin-RevId: 465369371
This commit is contained in:
commit
5ec9226845
@ -27,6 +27,7 @@ from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lax import control_flow as lcf
|
||||
from jax._src.lib import xla_client as xc
|
||||
import jax.numpy as jnp
|
||||
|
||||
DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])
|
||||
@ -88,6 +89,16 @@ ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
|
||||
|
||||
def debug_callback_lowering(ctx, *args, effect, callback, **params):
|
||||
|
||||
if isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext)):
|
||||
# Apply maximal sharding so pjit only executes the callback on device 0.
|
||||
sharding = xc.OpSharding()
|
||||
sharding.type = xc.OpSharding.Type.MAXIMAL
|
||||
sharding.tile_assignment_dimensions = [1]
|
||||
sharding.tile_assignment_devices = [0]
|
||||
else:
|
||||
sharding = None
|
||||
|
||||
def _callback(*flat_args):
|
||||
return tuple(
|
||||
debug_callback_p.impl(
|
||||
@ -99,7 +110,8 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
|
||||
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
|
||||
else:
|
||||
result, token, keepalive = mlir.emit_python_callback(
|
||||
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True)
|
||||
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True,
|
||||
sharding=sharding)
|
||||
ctx.module_context.add_keepalive(keepalive)
|
||||
return result
|
||||
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
|
||||
|
@ -1173,10 +1173,19 @@ def _outside_call_lowering(
|
||||
result_arrays = ()
|
||||
return result_arrays
|
||||
|
||||
if isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext)):
|
||||
# Apply maximal sharding so pjit only executes the callback on device 0.
|
||||
sharding = xla_client.OpSharding()
|
||||
sharding.type = xla_client.OpSharding.Type.MAXIMAL
|
||||
sharding.tile_assignment_dimensions = [1]
|
||||
sharding.tile_assignment_devices = [0]
|
||||
else:
|
||||
sharding = None
|
||||
results, next_token, keep_alive = mlir.emit_python_callback(ctx,
|
||||
wrapped_callback, current_token, callback_operands,
|
||||
callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type]
|
||||
has_side_effect=True)
|
||||
has_side_effect=True, sharding=sharding)
|
||||
_callback_handler_data.keep_alives.append(keep_alive)
|
||||
# We must put the two tokens at the end
|
||||
if identity:
|
||||
|
@ -1419,7 +1419,8 @@ def emit_python_callback(
|
||||
ctx: LoweringRuleContext, callback, token: Optional[Any],
|
||||
operands: List[ir.Value], operand_avals: List[core.AbstractValue],
|
||||
result_avals: List[core.AbstractValue],
|
||||
has_side_effect: bool) -> Tuple[List[ir.Value], Any, Any]:
|
||||
has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None
|
||||
) -> Tuple[List[ir.Value], Any, Any]:
|
||||
"""Creates an MHLO `CustomCallOp` that calls back to the provided function."""
|
||||
platform = ctx.module_context.platform
|
||||
if platform in {"tpu"} and jax._src.lib.version < (0, 3, 15):
|
||||
@ -1433,15 +1434,6 @@ def emit_python_callback(
|
||||
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
|
||||
operand_shapes = util.flatten(
|
||||
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
|
||||
if isinstance(ctx.module_context.axis_context,
|
||||
(SPMDAxisContext, ShardingContext)):
|
||||
# Apply maximal sharding so pjit only executes the callback on device 0.
|
||||
sharding = xc.OpSharding()
|
||||
sharding.type = xc.OpSharding.Type.MAXIMAL
|
||||
sharding.tile_assignment_dimensions = [1]
|
||||
sharding.tile_assignment_devices = [0]
|
||||
else:
|
||||
sharding = None
|
||||
if platform == "tpu":
|
||||
if result_avals:
|
||||
raise NotImplementedError(
|
||||
|
Loading…
x
Reference in New Issue
Block a user