diff --git a/CHANGELOG.md b/CHANGELOG.md index 43a052189..8255a4ca0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,12 +14,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * Breaking changes: * The host_callback primitives have been simplified to drop the - special autodiff handling for hcb.id_tap and id_print. - From now on, only the primals are tapped. The old behavior can be - obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS`` - environment variable, or the ```--flax_host_callback_ad_transforms``` flag. - Additionally, added documentation for how to implement the old behavior - using JAX custom AD APIs ({jax-issue}`#7839`). + special autodiff handling for hcb.id_tap and id_print. + From now on, only the primals are tapped. The old behavior can be + obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS`` + environment variable, or the ```--jax_host_callback_ad_transforms``` flag. + Additionally, added documentation for how to implement the old behavior + using JAX custom AD APIs ({jax-issue}`#8678`). + +* Bug fixes: + * host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`). 
* New features: * add `jax.block_until_ready` ({jax-issue}`#8941) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 332eaae30..52603b3ce 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -321,32 +321,6 @@ def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr: else: return jaxpr -outfeed_primitives: Set[core.Primitive] = set() -def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool: - """Finds if there are outfeed primitives anywhere inside a Jaxpr.""" - return any(primitive_uses_outfeed(eqn.primitive, eqn.params) - for eqn in jaxpr.eqns) - -def _param_uses_outfeed(param): - if type(param) is core.Jaxpr: - if jaxpr_uses_outfeed(param): - return True - elif type(param) is core.ClosedJaxpr: - if jaxpr_uses_outfeed(param.jaxpr): - return True - return False - -def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool: - if prim in outfeed_primitives: - return True - for param in params.values(): - if isinstance(param, tuple): - if any(unsafe_map(_param_uses_outfeed, param)): - return True - elif _param_uses_outfeed(param): - return True - return False - def jaxpr_replicas(jaxpr) -> int: """The number of replicas needed for a jaxpr. 
diff --git a/jax/core.py b/jax/core.py index 884986eac..10961b75b 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1678,6 +1678,32 @@ call_p.def_impl(call_impl) named_call_p: CallPrimitive = CallPrimitive('named_call') named_call_p.def_impl(call_impl) +outfeed_primitives: Set[Primitive] = set() +def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool: + """Finds if there are outfeed primitives anywhere inside a Jaxpr.""" + return any(primitive_uses_outfeed(eqn.primitive, eqn.params) + for eqn in jaxpr.eqns) + +def _param_uses_outfeed(param): + if type(param) is Jaxpr: + if jaxpr_uses_outfeed(param): + return True + elif type(param) is ClosedJaxpr: + if jaxpr_uses_outfeed(param.jaxpr): + return True + return False + +def primitive_uses_outfeed(prim: Primitive, params: Dict) -> bool: + if prim in outfeed_primitives: + return True + for param in params.values(): + if isinstance(param, tuple): + if any(unsafe_map(_param_uses_outfeed, param)): + return True + elif _param_uses_outfeed(param): + return True + return False + # ------------------- Map ------------------- def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue: diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index f3499863c..7c14bb758 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -313,6 +313,20 @@ its argument:: # what: x,x^2 : (3., 9.) # what: cotangents : (9., 3.) +If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals +for the backward pass, then the callbacks from the primal computation will +be called twice:: + + jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.) + # what: x,x^2 : (3., 9.) + # what: x,x^2 : (27., 729.) + # what: x,x^2 : (3., 9.) + +The callbacks are, in order, from: the primal computation of the inner ``power3``, +the primal computation of the outer ``power3``, and the rematerialization +of the residuals for the inner ``power3``. 
+ + Behavior under jax.vmap ----------------------- @@ -900,7 +914,7 @@ It takes the following parameters: """ outside_call_p = core.Primitive("outside_call") outside_call_p.multiple_results = True -dispatch.outfeed_primitives.add(outside_call_p) +core.outfeed_primitives.add(outside_call_p) def _outside_call_abstract_eval(*args_a: pe.AbstractValue, @@ -1385,7 +1399,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool, """Rewrite a Jaxpr to thread the token, if needed.""" assert has_input_token or not has_output_token - if not has_input_token and not dispatch.jaxpr_uses_outfeed(jaxpr): + if not has_input_token and not core.jaxpr_uses_outfeed(jaxpr): return jaxpr mk_new_var = core.gensym([jaxpr]) @@ -1407,7 +1421,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool, lax.create_token_p, {}, source_info_util.current())) for eqn in jaxpr.eqns: - if not dispatch.primitive_uses_outfeed(eqn.primitive, eqn.params): + if not core.primitive_uses_outfeed(eqn.primitive, eqn.params): eqns.append(eqn) else: output_token_var = mk_new_var(last_token_var.aval) @@ -1445,7 +1459,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], cond_jaxpr, _, body_jaxpr, _ = util.split_dict( eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) - if dispatch.jaxpr_uses_outfeed(cond_jaxpr.jaxpr): + if core.jaxpr_uses_outfeed(cond_jaxpr.jaxpr): _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var, input_itoken_var, output_itoken_var, mk_new_var) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 93e62fe4d..26ccfd096 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1029,7 +1029,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: List[bool] used_outs = map(read, eqn.outvars) # If any outputs are used, then we need to keep a version of the eqn and # potentially mark some inputs as used. Otherwise mark all inputs as unused. 
- if any(used_outs): + if any(used_outs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params): # If there's a rule for modifying the eqn and computing used inputs, apply # it. Otherwise, keep the eqn unmodified and mark all inputs as used. rule = dce_rules.get(eqn.primitive) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index d8dac83b5..e2ac4a0fe 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -27,6 +27,7 @@ from absl.testing import absltest from absl.testing import parameterized import jax +from jax import ad_checkpoint from jax import core from jax.config import config from jax import dtypes @@ -1458,6 +1459,19 @@ class HostCallbackTapTest(jtu.JaxTestCase): transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents ( [4. 9.] [2. 3.] )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}") + hcb.barrier_wait() + expected = """ + what: x,x^2 + ( 3. 9. ) + what: x,x^2 + ( 27. 729. ) + what: x,x^2 + ( 3. 9. 
)""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() def test_tap_pmap(self): if len(local_devices()) < 2: @@ -2024,8 +2038,8 @@ class HostCallbackTapTest(jtu.JaxTestCase): use_result=use_result, use_remat=use_remat, grad_func=grad_func) for use_result in [True, False] for grad_func in ["grad", "value_and_grad"] - for use_remat in ["old", "none"])) - def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="old"): + for use_remat in ["old", "new", "none"])) + def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): def f(x): id_print_result = hcb.id_print(x, output_stream=testing_stream) if use_result: @@ -2034,6 +2048,8 @@ class HostCallbackTapTest(jtu.JaxTestCase): grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad if use_remat == "old": trans_f = jax.remat(f) + elif use_remat == "new": + trans_f = ad_checkpoint.checkpoint(f) else: assert use_remat == "none" trans_f = f @@ -2068,8 +2084,14 @@ class HostCallbackTapTest(jtu.JaxTestCase): 2. 2.""" else: - # TODO: we should see two callbacks - expected = "" + if use_remat == "old": + # TODO: we should see two callbacks + expected = "" + else: + # Good: we see two callbacks, whether or not we use the result. + expected = """ + 2. + 2.""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) def test_tap_named_call(self):