mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #20285 from mattjj:print-saved-residuals-more-names
PiperOrigin-RevId: 616896480
This commit is contained in:
commit
56451d1f56
@ -465,11 +465,14 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
|
|||||||
src = 'from the argument at flattened index {i}'
|
src = 'from the argument at flattened index {i}'
|
||||||
results.append((v.aval, src))
|
results.append((v.aval, src))
|
||||||
|
|
||||||
|
named_vars = {v: e for e in jaxpr.eqns if e.primitive is name_p
|
||||||
|
for v in e.invars}
|
||||||
|
|
||||||
for eqn in jaxpr.eqns:
|
for eqn in jaxpr.eqns:
|
||||||
src = source_info_util.summarize(eqn.source_info)
|
src = source_info_util.summarize(eqn.source_info)
|
||||||
for v in eqn.outvars:
|
for v in eqn.outvars:
|
||||||
if v in res_vars:
|
if v in res_vars:
|
||||||
if eqn.primitive is name_p:
|
if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]):
|
||||||
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
|
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
|
||||||
elif str(eqn.primitive) == 'xla_call':
|
elif str(eqn.primitive) == 'xla_call':
|
||||||
results.append((v.aval,
|
results.append((v.aval,
|
||||||
|
@ -5568,6 +5568,16 @@ class RematTest(jtu.JaxTestCase):
|
|||||||
res_avals = saved_residuals(f, jnp.ones((2, 2)))
|
res_avals = saved_residuals(f, jnp.ones((2, 2)))
|
||||||
self.assertLen(res_avals, 1)
|
self.assertLen(res_avals, 1)
|
||||||
|
|
||||||
|
def test_name_saveable_input(self):
|
||||||
|
@partial(jax.remat, policy=lambda p, *_, **__: 'mul' in str(p))
|
||||||
|
def f(x):
|
||||||
|
x = checkpoint_name(x * x, 'foo')
|
||||||
|
x = x * x
|
||||||
|
return x
|
||||||
|
|
||||||
|
res = saved_residuals(f, 3.)
|
||||||
|
self.assertStartsWith(res[1][1], "named 'foo'")
|
||||||
|
|
||||||
def test_name_denylist(self):
|
def test_name_denylist(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
y = checkpoint_name(jnp.multiply(2., 2.), 'y')
|
y = checkpoint_name(jnp.multiply(2., 2.), 'y')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user