Don't compute the log message for checkpointing residuals unless it is going to be logged.

PiperOrigin-RevId: 606811594
This commit is contained in:
Peter Hawkins 2024-02-13 18:23:02 -08:00 committed by jax authors
parent ac45b8f73d
commit 7156f20b44

View File

@ -552,23 +552,24 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
source_info_util.current())
# log info about saved residuals
try:
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
logger.log(logging.WARNING if config.log_checkpoint_residuals.value
else logging.DEBUG,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
' %s\n' * len(res_invars) +
'and ' * bool(res_invars) * bool(body_res) +
'saving these intermediates:\n' * bool(body_res) +
' %s from %s\n' * len(body_res),
*[v.aval.str_short() for v in res_invars],
*[elt for (a, s) in body_res for elt in [a.str_short(), s]])
except:
pass # just don't log anything on failure
log_level = logging.WARNING if config.log_checkpoint_residuals.value else logging.DEBUG
if logger.isEnabledFor(log_level):
try:
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
logger.log(log_level,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
' %s\n' * len(res_invars) +
'and ' * bool(res_invars) * bool(body_res) +
'saving these intermediates:\n' * bool(body_res) +
' %s from %s\n' * len(body_res),
*[v.aval.str_short() for v in res_invars],
*[elt for (a, s) in body_res for elt in [a.str_short(), s]])
except:
pass # just don't log anything on failure
for t in out_jaxpr_tracers: t.recipe = recipe