don't automatically use new checkpoint implementation

There's a bug we're struggling to repro.

To use the new checkpoint, just use

```python
from jax.ad_checkpoint import checkpoint
```

rather than `from jax import checkpoint.
This commit is contained in:
Matthew Johnson 2021-10-14 07:09:06 -07:00
parent 2f7e125c70
commit 725fe3abd4
2 changed files with 38 additions and 31 deletions

View File

@ -76,7 +76,7 @@ from ..interpreters import invertible_ad as iad
from ..interpreters.invertible_ad import custom_ivjp
from ..custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from ..ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
from ..ad_checkpoint import checkpoint_policies
from .._src.config import (flags, config, bool_env, disable_jit as _disable_jit,
debug_nans as config_debug_nans,
@ -2896,10 +2896,6 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
... return lambda x: f1(jax.checkpoint(f2)(x))
...
"""
# TODO(mattjj): we temporarily have parallel code paths
if policy is not None:
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy)
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):

View File

@ -55,7 +55,7 @@ from jax import tree_util
from jax import linear_util as lu
import jax._src.util
from jax._src.ad_checkpoint import saved_residuals
from jax.ad_checkpoint import checkpoint_name
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
from jax.config import config
config.parse_flags_with_absl()
@ -3153,7 +3153,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_basic(self, remat):
@remat
@ -3194,7 +3195,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_freevars(self, remat):
def f1(x):
@ -3239,7 +3241,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_jit(self, remat):
@remat
@ -3266,7 +3269,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_vmap(self, remat):
@remat
@ -3291,7 +3295,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_higher_order_autodiff(self, remat):
def f(x):
@ -3333,7 +3338,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_no_redundant_flops(self, remat):
# see https://github.com/google/jax/pull/1749#issuecomment-558267584
@ -3361,7 +3367,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_binomial_checkpointing(self, remat):
def binom_checkpoint(funs):
@ -3409,7 +3416,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_jit2(self, remat):
@api.jit
@ -3455,7 +3463,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_jit3(self, remat):
# https://github.com/google/jax/issues/2180
@ -3518,7 +3527,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_eval_counter(self, remat):
# https://github.com/google/jax/issues/2737
@ -3578,7 +3588,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_escaped_tracer_remat(self, remat):
# b/169779185
@ -3598,7 +3609,8 @@ class RematTest(jtu.JaxTestCase):
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False))
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_no_cse_widget_on_primals(self, remat):
@remat
@ -3857,25 +3869,24 @@ class RematTest(jtu.JaxTestCase):
# The old implementation of remat worked by data dependence, and so
# (potentially large) constants would not be rematerialized and could be
# wastefully instantiated. This test checks that the newer remat
# implementation avoids that. We engage the newer implementation by passing
# an explicit policy. See https://github.com/google/jax/pull/8191.
# implementation avoids that. See https://github.com/google/jax/pull/8191.
# no residuals from constants created inside jnp.einsum
@partial(jax.checkpoint, policy=lambda *_, **__: False)
@partial(new_checkpoint, policy=lambda *_, **__: False)
def f(x):
return jnp.einsum('ii->i', x)
res_avals = saved_residuals(f, jnp.ones((2, 2)))
self.assertLen(res_avals, 0)
# no residuals from jnp.zeros
@partial(jax.checkpoint, policy=lambda *_, **__: False)
@partial(new_checkpoint, policy=lambda *_, **__: False)
def f(x):
return jnp.zeros_like(x) * x
res_avals = saved_residuals(f, jnp.ones((2, 2)))
self.assertLen(res_avals, 0)
# no residuals from jnp.zeros, but input must be saved
@partial(jax.checkpoint, policy=lambda *_, **__: False)
@partial(new_checkpoint, policy=lambda *_, **__: False)
def f(x):
return jnp.zeros_like(x) * jnp.sin(x)
res_avals = saved_residuals(f, jnp.ones((2, 2)))
@ -3890,19 +3901,19 @@ class RematTest(jtu.JaxTestCase):
return (((x * y) * z) * w) * u
policy = jax.checkpoint_policies.save_any_names_but_these('y', 'z', 'w')
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 0) # can't save anything
policy = jax.checkpoint_policies.save_any_names_but_these('z', 'w')
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 1) # can save only y
policy = jax.checkpoint_policies.save_any_names_but_these('w')
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 2) # can save y and z
policy = jax.checkpoint_policies.save_any_names_but_these()
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 3) # can save y, z, and w
def test_name_allowlist(self):
@ -3914,19 +3925,19 @@ class RematTest(jtu.JaxTestCase):
return (((x * y) * z) * w) * u
policy = jax.checkpoint_policies.save_only_these_names('y', 'z', 'w')
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 3) # can save y, z, and w
policy = jax.checkpoint_policies.save_only_these_names('z', 'w')
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 2) # can save z and w
policy = jax.checkpoint_policies.save_only_these_names('w')
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 1) # can save w
policy = jax.checkpoint_policies.save_only_these_names()
res = saved_residuals(jax.checkpoint(f, policy=policy), 1.)
res = saved_residuals(new_checkpoint(f, policy=policy), 1.)
self.assertLen(res, 0) # can't save anything!
def test_saved_residuals_utility(self):