Don't compute the log message for checkpointing residuals unless it is going to be logged.

PiperOrigin-RevId: 606811594
This commit is contained in:
Peter Hawkins 2024-02-13 18:23:02 -08:00 committed by jax authors
parent ac45b8f73d
commit 7156f20b44

View File

@ -552,13 +552,14 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
source_info_util.current())
# log info about saved residuals
log_level = logging.WARNING if config.log_checkpoint_residuals.value else logging.DEBUG
if logger.isEnabledFor(log_level):
try:
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
logger.log(logging.WARNING if config.log_checkpoint_residuals.value
else logging.DEBUG,
logger.log(log_level,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
' %s\n' * len(res_invars) +