jax.debug.callback now requires a Callable[..., None]

This makes the "return value is ignored" behavior explicit in the type.

PiperOrigin-RevId: 626430448
This commit is contained in:
Sergei Lebedev 2024-04-19 11:54:19 -07:00 committed by jax authors
parent b2375fa7e9
commit 32922f61e9

View File

@ -73,7 +73,8 @@ map, unsafe_map = util.safe_map, map
def debug_callback_impl(*args, callback: Callable[..., Any],
effect: DebugEffect):
del effect
return callback(*args)
callback(*args)
return ()
@debug_callback_p.def_effectful_abstract_eval
def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any],
@ -136,13 +137,13 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
sharding = None
def _callback(*flat_args):
return tuple(
debug_callback_p.impl(
*flat_args, effect=effect, callback=callback, **params))
*flat_args, effect=effect, callback=callback, **params)
return ()
if effects.ordered_effects.contains(effect):
token = ctx.tokens_in.get(effect)[0]
result, token, _ = mlir.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, has_side_effect=True)
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
else:
result, token, _ = mlir.emit_python_callback(
@ -187,7 +188,7 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn):
pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = (
_debug_callback_partial_eval_custom)
def debug_callback(callback: Callable[..., Any], *args: Any,
def debug_callback(callback: Callable[..., None], *args: Any,
ordered: bool = False, **kwargs: Any) -> None:
"""Calls a stageable Python callback.
@ -206,7 +207,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
of the computation are duplicated or dropped.
Args:
callback: A Python callable. Its return value will be ignored.
callback: A Python callable returning None.
*args: The positional arguments to the callback.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this callback w.r.t.
@ -231,7 +232,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
callback(*args, **kwargs)
return []
return ()
debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect)
class _DebugPrintFormatChecker(string.Formatter):