mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jax.debug.callback now requires a Callable[..., None]
This makes the "return value is ignored" behavior explicit in the type. PiperOrigin-RevId: 626430448
This commit is contained in:
parent
b2375fa7e9
commit
32922f61e9
@ -73,7 +73,8 @@ map, unsafe_map = util.safe_map, map
|
||||
def debug_callback_impl(*args, callback: Callable[..., Any],
|
||||
effect: DebugEffect):
|
||||
del effect
|
||||
return callback(*args)
|
||||
callback(*args)
|
||||
return ()
|
||||
|
||||
@debug_callback_p.def_effectful_abstract_eval
|
||||
def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any],
|
||||
@ -136,13 +137,13 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
|
||||
sharding = None
|
||||
|
||||
def _callback(*flat_args):
|
||||
return tuple(
|
||||
debug_callback_p.impl(
|
||||
*flat_args, effect=effect, callback=callback, **params))
|
||||
*flat_args, effect=effect, callback=callback, **params)
|
||||
return ()
|
||||
if effects.ordered_effects.contains(effect):
|
||||
token = ctx.tokens_in.get(effect)[0]
|
||||
result, token, _ = mlir.emit_python_callback(
|
||||
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
|
||||
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, has_side_effect=True)
|
||||
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
|
||||
else:
|
||||
result, token, _ = mlir.emit_python_callback(
|
||||
@ -187,7 +188,7 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = (
|
||||
_debug_callback_partial_eval_custom)
|
||||
|
||||
def debug_callback(callback: Callable[..., Any], *args: Any,
|
||||
def debug_callback(callback: Callable[..., None], *args: Any,
|
||||
ordered: bool = False, **kwargs: Any) -> None:
|
||||
"""Calls a stageable Python callback.
|
||||
|
||||
@ -206,7 +207,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
|
||||
of the computation are duplicated or dropped.
|
||||
|
||||
Args:
|
||||
callback: A Python callable. Its return value will be ignored.
|
||||
callback: A Python callable returning None.
|
||||
*args: The positional arguments to the callback.
|
||||
ordered: A keyword only argument used to indicate whether or not the
|
||||
staged out computation will enforce ordering of this callback w.r.t.
|
||||
@ -231,7 +232,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
|
||||
def _flat_callback(*flat_args):
|
||||
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
|
||||
callback(*args, **kwargs)
|
||||
return []
|
||||
return ()
|
||||
debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect)
|
||||
|
||||
class _DebugPrintFormatChecker(string.Formatter):
|
||||
|
Loading…
x
Reference in New Issue
Block a user