mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #20285 from mattjj:print-saved-residuals-more-names
PiperOrigin-RevId: 616896480
This commit is contained in:
commit
56451d1f56
@ -465,11 +465,14 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
|
||||
src = 'from the argument at flattened index {i}'
|
||||
results.append((v.aval, src))
|
||||
|
||||
named_vars = {v: e for e in jaxpr.eqns if e.primitive is name_p
|
||||
for v in e.invars}
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
src = source_info_util.summarize(eqn.source_info)
|
||||
for v in eqn.outvars:
|
||||
if v in res_vars:
|
||||
if eqn.primitive is name_p:
|
||||
if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]):
|
||||
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
|
||||
elif str(eqn.primitive) == 'xla_call':
|
||||
results.append((v.aval,
|
||||
|
@ -5568,6 +5568,16 @@ class RematTest(jtu.JaxTestCase):
|
||||
res_avals = saved_residuals(f, jnp.ones((2, 2)))
|
||||
self.assertLen(res_avals, 1)
|
||||
|
||||
def test_name_saveable_input(self):
|
||||
@partial(jax.remat, policy=lambda p, *_, **__: 'mul' in str(p))
|
||||
def f(x):
|
||||
x = checkpoint_name(x * x, 'foo')
|
||||
x = x * x
|
||||
return x
|
||||
|
||||
res = saved_residuals(f, 3.)
|
||||
self.assertStartsWith(res[1][1], "named 'foo'")
|
||||
|
||||
def test_name_denylist(self):
|
||||
def f(x):
|
||||
y = checkpoint_name(jnp.multiply(2., 2.), 'y')
|
||||
|
Loading…
x
Reference in New Issue
Block a user