make print_saved_residuals show more checkpoint_name names

Previously, we would only show the `checkpoint_name` name for a saved residual
if the _output_ of `checkpoint_name` was saved. That meant that a policy like
`save_dots` could allow a value to be saved, but because the saved version was
the _input_ to a `checkpoint_name`, we wouldn't print the name.

This fix is kind of a quick hack, and doesn't handle all cases. For example, if
a value is an input to two different `checkpoint_name` calls, this code
arbitrarily uses the latter name (I think). Moreover there are probably ways
the name doesn't get picked up even if it should. A more thorough version might
check all uses, even ones that occur outside a higher-order primitive or
something. But for now this change makes the output slightly nicer for some
useful cases.
This commit is contained in:
Matthew Johnson 2024-03-17 17:48:12 -07:00
parent aaeeaf5f0c
commit 566af12aca
2 changed files with 14 additions and 1 deletions

View File

@ -465,11 +465,14 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
src = 'from the argument at flattened index {i}'
results.append((v.aval, src))
named_vars = {v: e for e in jaxpr.eqns if e.primitive is name_p
for v in e.invars}
for eqn in jaxpr.eqns:
src = source_info_util.summarize(eqn.source_info)
for v in eqn.outvars:
if v in res_vars:
if eqn.primitive is name_p:
if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]):
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
elif str(eqn.primitive) == 'xla_call':
results.append((v.aval,

View File

@ -5568,6 +5568,16 @@ class RematTest(jtu.JaxTestCase):
res_avals = saved_residuals(f, jnp.ones((2, 2)))
self.assertLen(res_avals, 1)
def test_name_saveable_input(self):
@partial(jax.remat, policy=lambda p, *_, **__: 'mul' in str(p))
def f(x):
x = checkpoint_name(x * x, 'foo')
x = x * x
return x
res = saved_residuals(f, 3.)
self.assertStartsWith(res[1][1], "named 'foo'")
def test_name_denylist(self):
def f(x):
y = checkpoint_name(jnp.multiply(2., 2.), 'y')