From 7156f20b44562d27229d8bb1a9301942505ba819 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 Feb 2024 18:23:02 -0800 Subject: [PATCH] Don't compute the log message for checkpointing residuals unless it is going to be logged. PiperOrigin-RevId: 606811594 --- jax/_src/ad_checkpoint.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index c11f01a4a..f397b9ee8 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -552,23 +552,24 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): source_info_util.current()) # log info about saved residuals - try: - _, staged_unk = partition_list(in_used_staged, in_unknowns) - res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:]) - res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:] - body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None) - logger.log(logging.WARNING if config.log_checkpoint_residuals.value - else logging.DEBUG, - 'remat-decorated function ' + - 'saving inputs with shapes:\n' * bool(res_invars) + - ' %s\n' * len(res_invars) + - 'and ' * bool(res_invars) * bool(body_res) + - 'saving these intermediates:\n' * bool(body_res) + - ' %s from %s\n' * len(body_res), - *[v.aval.str_short() for v in res_invars], - *[elt for (a, s) in body_res for elt in [a.str_short(), s]]) - except: - pass # just don't log anything on failure + log_level = logging.WARNING if config.log_checkpoint_residuals.value else logging.DEBUG + if logger.isEnabledFor(log_level): + try: + _, staged_unk = partition_list(in_used_staged, in_unknowns) + res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:]) + res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:] + body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None) + logger.log(log_level, + 'remat-decorated function ' + + 'saving inputs with shapes:\n' * bool(res_invars) + + ' %s\n' * len(res_invars) + + 'and ' * bool(res_invars) * bool(body_res) + + 'saving these intermediates:\n' * bool(body_res) + + ' %s from %s\n' * len(body_res), + *[v.aval.str_short() for v in res_invars], + *[elt for (a, s) in body_res for elt in [a.str_short(), s]]) + except: + pass # just don't log anything on failure for t in out_jaxpr_tracers: t.recipe = recipe