mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Don't compute the log message for checkpointing residuals unless it is going to be logged.
PiperOrigin-RevId: 606811594
This commit is contained in:
parent
ac45b8f73d
commit
7156f20b44
@ -552,23 +552,24 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
source_info_util.current())
|
||||
|
||||
# log info about saved residuals
|
||||
try:
|
||||
_, staged_unk = partition_list(in_used_staged, in_unknowns)
|
||||
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
|
||||
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
|
||||
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
|
||||
logger.log(logging.WARNING if config.log_checkpoint_residuals.value
|
||||
else logging.DEBUG,
|
||||
'remat-decorated function ' +
|
||||
'saving inputs with shapes:\n' * bool(res_invars) +
|
||||
' %s\n' * len(res_invars) +
|
||||
'and ' * bool(res_invars) * bool(body_res) +
|
||||
'saving these intermediates:\n' * bool(body_res) +
|
||||
' %s from %s\n' * len(body_res),
|
||||
*[v.aval.str_short() for v in res_invars],
|
||||
*[elt for (a, s) in body_res for elt in [a.str_short(), s]])
|
||||
except:
|
||||
pass # just don't log anything on failure
|
||||
log_level = logging.WARNING if config.log_checkpoint_residuals.value else logging.DEBUG
|
||||
if logger.isEnabledFor(log_level):
|
||||
try:
|
||||
_, staged_unk = partition_list(in_used_staged, in_unknowns)
|
||||
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
|
||||
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
|
||||
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
|
||||
logger.log(log_level,
|
||||
'remat-decorated function ' +
|
||||
'saving inputs with shapes:\n' * bool(res_invars) +
|
||||
' %s\n' * len(res_invars) +
|
||||
'and ' * bool(res_invars) * bool(body_res) +
|
||||
'saving these intermediates:\n' * bool(body_res) +
|
||||
' %s from %s\n' * len(body_res),
|
||||
*[v.aval.str_short() for v in res_invars],
|
||||
*[elt for (a, s) in body_res for elt in [a.str_short(), s]])
|
||||
except:
|
||||
pass # just don't log anything on failure
|
||||
|
||||
for t in out_jaxpr_tracers: t.recipe = recipe
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user