Merge pull request #20285 from mattjj:print-saved-residuals-more-names

PiperOrigin-RevId: 616896480
This commit is contained in:
jax authors 2024-03-18 11:36:52 -07:00
commit 56451d1f56
2 changed files with 14 additions and 1 deletions

View File

@ -465,11 +465,14 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
src = 'from the argument at flattened index {i}'
results.append((v.aval, src))
named_vars = {v: e for e in jaxpr.eqns if e.primitive is name_p
for v in e.invars}
for eqn in jaxpr.eqns:
src = source_info_util.summarize(eqn.source_info)
for v in eqn.outvars:
if v in res_vars:
if eqn.primitive is name_p:
if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]):
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
elif str(eqn.primitive) == 'xla_call':
results.append((v.aval,

View File

@ -5568,6 +5568,16 @@ class RematTest(jtu.JaxTestCase):
res_avals = saved_residuals(f, jnp.ones((2, 2)))
self.assertLen(res_avals, 1)
def test_name_saveable_input(self):
@partial(jax.remat, policy=lambda p, *_, **__: 'mul' in str(p))
def f(x):
x = checkpoint_name(x * x, 'foo')
x = x * x
return x
res = saved_residuals(f, 3.)
self.assertStartsWith(res[1][1], "named 'foo'")
def test_name_denylist(self):
def f(x):
y = checkpoint_name(jnp.multiply(2., 2.), 'y')