mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
make print_saved_residuals show more checkpoint_name names
Previously, we would only show the `checkpoint_name` name for a saved residual if the _output_ of `checkpoint_name` was saved. That meant that a policy like `save_dots` could allow a value to be saved, but because the saved version was the _input_ to a `checkpoint_name`, we wouldn't print the name. This fix is kind of a quick hack, and doesn't handle all cases. For example, if a value is an input to two different `checkpoint_name` calls, this code arbitrarily uses the latter name (I think). Moreover there are probably ways the name doesn't get picked up even if it should. A more thorough version might check all uses, even ones that occur outside a higher-order primitive or something. But for now this change makes the output slightly nicer for some useful cases.
This commit is contained in:
parent
aaeeaf5f0c
commit
566af12aca
@ -465,11 +465,14 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
|
||||
src = 'from the argument at flattened index {i}'
|
||||
results.append((v.aval, src))
|
||||
|
||||
named_vars = {v: e for e in jaxpr.eqns if e.primitive is name_p
|
||||
for v in e.invars}
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
src = source_info_util.summarize(eqn.source_info)
|
||||
for v in eqn.outvars:
|
||||
if v in res_vars:
|
||||
if eqn.primitive is name_p:
|
||||
if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]):
|
||||
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
|
||||
elif str(eqn.primitive) == 'xla_call':
|
||||
results.append((v.aval,
|
||||
|
@ -5568,6 +5568,16 @@ class RematTest(jtu.JaxTestCase):
|
||||
res_avals = saved_residuals(f, jnp.ones((2, 2)))
|
||||
self.assertLen(res_avals, 1)
|
||||
|
||||
def test_name_saveable_input(self):
|
||||
@partial(jax.remat, policy=lambda p, *_, **__: 'mul' in str(p))
|
||||
def f(x):
|
||||
x = checkpoint_name(x * x, 'foo')
|
||||
x = x * x
|
||||
return x
|
||||
|
||||
res = saved_residuals(f, 3.)
|
||||
self.assertStartsWith(res[1][1], "named 'foo'")
|
||||
|
||||
def test_name_denylist(self):
|
||||
def f(x):
|
||||
y = checkpoint_name(jnp.multiply(2., 2.), 'y')
|
||||
|
Loading…
x
Reference in New Issue
Block a user