Merge pull request #11717 from sharadmv:move-sharding

PiperOrigin-RevId: 465369371
jax authors 2022-08-04 11:55:57 -07:00
commit 5ec9226845
3 changed files with 25 additions and 12 deletions


@@ -27,6 +27,7 @@ from jax.interpreters import ad
 from jax.interpreters import batching
 from jax.interpreters import mlir
 from jax._src.lax import control_flow as lcf
+from jax._src.lib import xla_client as xc
 import jax.numpy as jnp

 DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])
@@ -88,6 +89,16 @@ ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
 def debug_callback_lowering(ctx, *args, effect, callback, **params):
+  if isinstance(ctx.module_context.axis_context,
+                (mlir.SPMDAxisContext, mlir.ShardingContext)):
+    # Apply maximal sharding so pjit only executes the callback on device 0.
+    sharding = xc.OpSharding()
+    sharding.type = xc.OpSharding.Type.MAXIMAL
+    sharding.tile_assignment_dimensions = [1]
+    sharding.tile_assignment_devices = [0]
+  else:
+    sharding = None
+
   def _callback(*flat_args):
     return tuple(
         debug_callback_p.impl(
@@ -99,7 +110,8 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
     ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
   else:
     result, token, keepalive = mlir.emit_python_callback(
-        ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True)
+        ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True,
+        sharding=sharding)
   ctx.module_context.add_keepalive(keepalive)
   return result
 mlir.register_lowering(debug_callback_p, debug_callback_lowering,
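For context, a minimal usage sketch of what the hunk above enables (not part of this change; pjit, Mesh, PartitionSpec and jax.debug.print are assumed from the jax 0.3.x API): with MAXIMAL sharding attached to the callback custom call, a debug print under pjit fires once, on device 0, instead of once per device.

# Hypothetical usage sketch; everything here other than the lowering behavior
# described above is assumed jax 0.3.x API, not something this commit adds.
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental import PartitionSpec as P
from jax.experimental.pjit import pjit

def f(x):
  jax.debug.print("x = {x}", x=x)   # lowered through debug_callback_p
  return x + 1

with maps.Mesh(jax.devices(), ("d",)):
  pjit(f, in_axis_resources=P("d"), out_axis_resources=P("d"))(
      jnp.arange(jax.device_count()))
# The callback custom call is pinned to device 0, so the line prints once.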


@@ -1173,10 +1173,19 @@ def _outside_call_lowering(
       result_arrays = ()
     return result_arrays

+  if isinstance(ctx.module_context.axis_context,
+                (mlir.SPMDAxisContext, mlir.ShardingContext)):
+    # Apply maximal sharding so pjit only executes the callback on device 0.
+    sharding = xla_client.OpSharding()
+    sharding.type = xla_client.OpSharding.Type.MAXIMAL
+    sharding.tile_assignment_dimensions = [1]
+    sharding.tile_assignment_devices = [0]
+  else:
+    sharding = None
   results, next_token, keep_alive = mlir.emit_python_callback(ctx,
       wrapped_callback, current_token, callback_operands,
       callback_operand_avals, callback_flat_results_aval,  # type: ignore[arg-type]
-      has_side_effect=True)
+      has_side_effect=True, sharding=sharding)
   _callback_handler_data.keep_alives.append(keep_alive)
   # We must put the two tokens at the end
   if identity:
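The same treatment for host_callback, again as a hedged sketch rather than part of the diff: an id_print tap lowered through _outside_call_lowering now carries the same MAXIMAL sharding under pjit.

# Hypothetical sketch; hcb.id_print, pjit, Mesh and PartitionSpec are assumed
# jax 0.3.x API and are not introduced by this change.
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb
from jax.experimental import maps
from jax.experimental import PartitionSpec as P
from jax.experimental.pjit import pjit

def g(x):
  x = hcb.id_print(x, what="sharded input")  # lowered via outside_call
  return x * 2.0

with maps.Mesh(jax.devices(), ("d",)):
  pjit(g, in_axis_resources=P("d"), out_axis_resources=P("d"))(
      jnp.arange(jax.device_count(), dtype=jnp.float32))
# The custom call emitted for the tap is assigned MAXIMAL sharding, so the
# host callback runs once (on device 0) rather than once per participating device.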


@@ -1419,7 +1419,8 @@ def emit_python_callback(
     ctx: LoweringRuleContext, callback, token: Optional[Any],
     operands: List[ir.Value], operand_avals: List[core.AbstractValue],
     result_avals: List[core.AbstractValue],
-    has_side_effect: bool) -> Tuple[List[ir.Value], Any, Any]:
+    has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None
+    ) -> Tuple[List[ir.Value], Any, Any]:
   """Creates an MHLO `CustomCallOp` that calls back to the provided function."""
   platform = ctx.module_context.platform
   if platform in {"tpu"} and jax._src.lib.version < (0, 3, 15):
@@ -1433,15 +1434,6 @@ def emit_python_callback(
       [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
   operand_shapes = util.flatten(
       [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
-  if isinstance(ctx.module_context.axis_context,
-                (SPMDAxisContext, ShardingContext)):
-    # Apply maximal sharding so pjit only executes the callback on device 0.
-    sharding = xc.OpSharding()
-    sharding.type = xc.OpSharding.Type.MAXIMAL
-    sharding.tile_assignment_dimensions = [1]
-    sharding.tile_assignment_devices = [0]
-  else:
-    sharding = None
   if platform == "tpu":
     if result_avals:
       raise NotImplementedError(
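Net effect of the last two hunks: emit_python_callback no longer hard-codes maximal sharding; callers opt in through the new sharding keyword. A sketch of the construction the two callers above now perform themselves (it mirrors the removed block; the helper name is illustrative, not an API added by this commit):

# Illustrative only: mirrors the OpSharding construction shown in the hunks
# above; _maximal_sharding_on_device_0 is a hypothetical local helper name.
from jax._src.lib import xla_client as xc

def _maximal_sharding_on_device_0() -> xc.OpSharding:
  # MAXIMAL sharding assigns the op to exactly one device (device 0 here), so
  # under pjit the Python callback executes once instead of once per shard.
  sharding = xc.OpSharding()
  sharding.type = xc.OpSharding.Type.MAXIMAL
  sharding.tile_assignment_dimensions = [1]
  sharding.tile_assignment_devices = [0]
  return sharding

# Callers then pass it through the new keyword-only argument:
#   mlir.emit_python_callback(ctx, callback, None, operands, avals_in,
#                             avals_out, True,
#                             sharding=_maximal_sharding_on_device_0())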