mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Don't compute the log message for checkpointing residuals unless it is going to be logged.
PiperOrigin-RevId: 606811594
This commit is contained in:
parent
ac45b8f73d
commit
7156f20b44
@ -552,13 +552,14 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
source_info_util.current())
|
||||
|
||||
# log info about saved residuals
|
||||
log_level = logging.WARNING if config.log_checkpoint_residuals.value else logging.DEBUG
|
||||
if logger.isEnabledFor(log_level):
|
||||
try:
|
||||
_, staged_unk = partition_list(in_used_staged, in_unknowns)
|
||||
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
|
||||
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
|
||||
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
|
||||
logger.log(logging.WARNING if config.log_checkpoint_residuals.value
|
||||
else logging.DEBUG,
|
||||
logger.log(log_level,
|
||||
'remat-decorated function ' +
|
||||
'saving inputs with shapes:\n' * bool(res_invars) +
|
||||
' %s\n' * len(res_invars) +
|
||||
|
Loading…
x
Reference in New Issue
Block a user