mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
don't automatically use new checkpoint implementation
There's a bug we're struggling to reproduce. To use the new checkpoint implementation, import it with `from jax.ad_checkpoint import checkpoint` rather than `from jax import checkpoint`.
This commit is contained in:
parent
2f7e125c70
commit
725fe3abd4
@ -76,7 +76,7 @@ from ..interpreters import invertible_ad as iad
|
||||
from ..interpreters.invertible_ad import custom_ivjp
|
||||
from ..custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
|
||||
custom_vjp, linear_call)
|
||||
from ..ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
|
||||
from ..ad_checkpoint import checkpoint_policies
|
||||
|
||||
from .._src.config import (flags, config, bool_env, disable_jit as _disable_jit,
|
||||
debug_nans as config_debug_nans,
|
||||
@ -2896,10 +2896,6 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
|
||||
... return lambda x: f1(jax.checkpoint(f2)(x))
|
||||
...
|
||||
"""
|
||||
# TODO(mattjj): we temporarily have parallel code paths
|
||||
if policy is not None:
|
||||
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy)
|
||||
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def fun_remat(*args, **kwargs):
|
||||
|
@ -55,7 +55,7 @@ from jax import tree_util
|
||||
from jax import linear_util as lu
|
||||
import jax._src.util
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax.ad_checkpoint import checkpoint_name
|
||||
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -3153,7 +3153,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_basic(self, remat):
|
||||
@remat
|
||||
@ -3194,7 +3195,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_freevars(self, remat):
|
||||
def f1(x):
|
||||
@ -3239,7 +3241,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_jit(self, remat):
|
||||
@remat
|
||||
@ -3266,7 +3269,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_vmap(self, remat):
|
||||
@remat
|
||||
@ -3291,7 +3295,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_higher_order_autodiff(self, remat):
|
||||
def f(x):
|
||||
@ -3333,7 +3338,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_no_redundant_flops(self, remat):
|
||||
# see https://github.com/google/jax/pull/1749#issuecomment-558267584
|
||||
@ -3361,7 +3367,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_binomial_checkpointing(self, remat):
|
||||
def binom_checkpoint(funs):
|
||||
@ -3409,7 +3416,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_jit2(self, remat):
|
||||
@api.jit
|
||||
@ -3455,7 +3463,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_jit3(self, remat):
|
||||
# https://github.com/google/jax/issues/2180
|
||||
@ -3518,7 +3527,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_remat_eval_counter(self, remat):
|
||||
# https://github.com/google/jax/issues/2737
|
||||
@ -3578,7 +3588,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_escaped_tracer_remat(self, remat):
|
||||
# b/169779185
|
||||
@ -3598,7 +3609,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False))
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_no_cse_widget_on_primals(self, remat):
|
||||
@remat
|
||||
@ -3857,25 +3869,24 @@ class RematTest(jtu.JaxTestCase):
|
||||
# The old implementation of remat worked by data dependence, and so
|
||||
# (potentially large) constants would not be rematerialized and could be
|
||||
# wastefully instantiated. This test checks that the newer remat
|
||||
# implementation avoids that. We engage the newer implementation by passing
|
||||
# an explicit policy. See https://github.com/google/jax/pull/8191.
|
||||
# implementation avoids that. See https://github.com/google/jax/pull/8191.
|
||||
|
||||
# no residuals from constants created inside jnp.einsum
|
||||
@partial(jax.checkpoint, policy=lambda *_, **__: False)
|
||||
@partial(new_checkpoint, policy=lambda *_, **__: False)
|
||||
def f(x):
|
||||
return jnp.einsum('ii->i', x)
|
||||
res_avals = saved_residuals(f, jnp.ones((2, 2)))
|
||||
self.assertLen(res_avals, 0)
|
||||
|
||||
# no residuals from jnp.zeros
|
||||
@partial(jax.checkpoint, policy=lambda *_, **__: False)
|
||||
@partial(new_checkpoint, policy=lambda *_, **__: False)
|
||||
def f(x):
|
||||
return jnp.zeros_like(x) * x
|
||||
res_avals = saved_residuals(f, jnp.ones((2, 2)))
|
||||
self.assertLen(res_avals, 0)
|
||||
|
||||
# no residuals from jnp.zeros, but input must be saved
|
||||
@partial(jax.checkpoint, policy=lambda *_, **__: False)
|
||||
@partial(new_checkpoint, policy=lambda *_, **__: False)
|
||||
def f(x):
|
||||
return jnp.zeros_like(x) * jnp.sin(x)
|
||||
res_avals = saved_residuals(f, jnp.ones((2, 2)))
|
||||
@ -3890,19 +3901,19 @@ class RematTest(jtu.JaxTestCase):
|
||||
return (((x * y) * z) * w) * u
|
||||
|
||||
policy = jax.checkpoint_policies.save_any_names_but_these('y', 'z', 'w')
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 0) # can't save anything
|
||||
|
||||
policy = jax.checkpoint_policies.save_any_names_but_these('z', 'w')
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 1) # can save only y
|
||||
|
||||
policy = jax.checkpoint_policies.save_any_names_but_these('w')
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 2) # can save y and z
|
||||
|
||||
policy = jax.checkpoint_policies.save_any_names_but_these()
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 3) # can save y, z, and w
|
||||
|
||||
def test_name_allowlist(self):
|
||||
@ -3914,19 +3925,19 @@ class RematTest(jtu.JaxTestCase):
|
||||
return (((x * y) * z) * w) * u
|
||||
|
||||
policy = jax.checkpoint_policies.save_only_these_names('y', 'z', 'w')
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 3) # can save y, z, and w
|
||||
|
||||
policy = jax.checkpoint_policies.save_only_these_names('z', 'w')
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 2) # can save z and w
|
||||
|
||||
policy = jax.checkpoint_policies.save_only_these_names('w')
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 1) # can save w
|
||||
|
||||
policy = jax.checkpoint_policies.save_only_these_names()
|
||||
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
|
||||
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
|
||||
self.assertLen(res, 0) # can't save anything!
|
||||
|
||||
def test_saved_residuals_utility(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user