mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[hcb] Add support for remat2 to host_callback
A callback under ad_checkpoint.checkpoint will be invoked twice when taking the gradient: once during the forward pass and once again during the backward pass when the residuals for the forward pass are rematerialized.
This commit is contained in:
parent
2c7db525f7
commit
3021d3e2e2
15
CHANGELOG.md
15
CHANGELOG.md
@ -14,12 +14,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
* Breaking changes:
|
||||
* The host_callback primitives have been simplified to drop the
|
||||
special autodiff handling for hcb.id_tap and id_print.
|
||||
From now on, only the primals are tapped. The old behavior can be
|
||||
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
|
||||
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
|
||||
Additionally, added documentation for how to implement the old behavior
|
||||
using JAX custom AD APIs ({jax-issue}`#7839`).
|
||||
special autodiff handling for hcb.id_tap and id_print.
|
||||
From now on, only the primals are tapped. The old behavior can be
|
||||
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
|
||||
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
|
||||
Additionally, added documentation for how to implement the old behavior
|
||||
using JAX custom AD APIs ({jax-issue}`#8678`).
|
||||
|
||||
* Bug fixes:
|
||||
* host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).
|
||||
|
||||
* New features:
|
||||
* add `jax.block_until_ready` ({jax-issue}`#8941)
|
||||
|
@ -321,32 +321,6 @@ def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
else:
|
||||
return jaxpr
|
||||
|
||||
outfeed_primitives: Set[core.Primitive] = set()
|
||||
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
|
||||
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
|
||||
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
|
||||
for eqn in jaxpr.eqns)
|
||||
|
||||
def _param_uses_outfeed(param):
|
||||
if type(param) is core.Jaxpr:
|
||||
if jaxpr_uses_outfeed(param):
|
||||
return True
|
||||
elif type(param) is core.ClosedJaxpr:
|
||||
if jaxpr_uses_outfeed(param.jaxpr):
|
||||
return True
|
||||
return False
|
||||
|
||||
def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
|
||||
if prim in outfeed_primitives:
|
||||
return True
|
||||
for param in params.values():
|
||||
if isinstance(param, tuple):
|
||||
if any(unsafe_map(_param_uses_outfeed, param)):
|
||||
return True
|
||||
elif _param_uses_outfeed(param):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def jaxpr_replicas(jaxpr) -> int:
|
||||
"""The number of replicas needed for a jaxpr.
|
||||
|
26
jax/core.py
26
jax/core.py
@ -1678,6 +1678,32 @@ call_p.def_impl(call_impl)
|
||||
named_call_p: CallPrimitive = CallPrimitive('named_call')
|
||||
named_call_p.def_impl(call_impl)
|
||||
|
||||
outfeed_primitives: Set[Primitive] = set()
|
||||
def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
|
||||
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
|
||||
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
|
||||
for eqn in jaxpr.eqns)
|
||||
|
||||
def _param_uses_outfeed(param):
|
||||
if type(param) is Jaxpr:
|
||||
if jaxpr_uses_outfeed(param):
|
||||
return True
|
||||
elif type(param) is ClosedJaxpr:
|
||||
if jaxpr_uses_outfeed(param.jaxpr):
|
||||
return True
|
||||
return False
|
||||
|
||||
def primitive_uses_outfeed(prim: Primitive, params: Dict) -> bool:
|
||||
if prim in outfeed_primitives:
|
||||
return True
|
||||
for param in params.values():
|
||||
if isinstance(param, tuple):
|
||||
if any(unsafe_map(_param_uses_outfeed, param)):
|
||||
return True
|
||||
elif _param_uses_outfeed(param):
|
||||
return True
|
||||
return False
|
||||
|
||||
# ------------------- Map -------------------
|
||||
|
||||
def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
|
||||
|
@ -313,6 +313,20 @@ its argument::
|
||||
# what: x,x^2 : (3., 9.)
|
||||
# what: cotangents : (9., 3.)
|
||||
|
||||
If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals
|
||||
for the backward pass, then the callbacks from the primal computation will
|
||||
be called twice::
|
||||
|
||||
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
|
||||
# what: x,x^2 : (3., 9.)
|
||||
# what: x,x^2 : (27., 729.)
|
||||
# what: x,x^2 : (3., 9.)
|
||||
|
||||
The callbacks are, in order from: the primal computation of the inner ``power3``,
|
||||
the primal computation of the outer ``power3``, and the rematerialization
|
||||
of the residuals for the inner ``power3``.
|
||||
|
||||
|
||||
Behavior under jax.vmap
|
||||
-----------------------
|
||||
|
||||
@ -900,7 +914,7 @@ It takes the following parameters:
|
||||
"""
|
||||
outside_call_p = core.Primitive("outside_call")
|
||||
outside_call_p.multiple_results = True
|
||||
dispatch.outfeed_primitives.add(outside_call_p)
|
||||
core.outfeed_primitives.add(outside_call_p)
|
||||
|
||||
|
||||
def _outside_call_abstract_eval(*args_a: pe.AbstractValue,
|
||||
@ -1385,7 +1399,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
|
||||
"""Rewrite a Jaxpr to thread the token, if needed."""
|
||||
assert has_input_token or not has_output_token
|
||||
|
||||
if not has_input_token and not dispatch.jaxpr_uses_outfeed(jaxpr):
|
||||
if not has_input_token and not core.jaxpr_uses_outfeed(jaxpr):
|
||||
return jaxpr
|
||||
|
||||
mk_new_var = core.gensym([jaxpr])
|
||||
@ -1407,7 +1421,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
|
||||
lax.create_token_p, {}, source_info_util.current()))
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if not dispatch.primitive_uses_outfeed(eqn.primitive, eqn.params):
|
||||
if not core.primitive_uses_outfeed(eqn.primitive, eqn.params):
|
||||
eqns.append(eqn)
|
||||
else:
|
||||
output_token_var = mk_new_var(last_token_var.aval)
|
||||
@ -1445,7 +1459,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
|
||||
eqn.params,
|
||||
["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
|
||||
if dispatch.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
|
||||
if core.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
|
||||
_rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
|
||||
input_itoken_var, output_itoken_var,
|
||||
mk_new_var)
|
||||
|
@ -1029,7 +1029,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: List[bool]
|
||||
used_outs = map(read, eqn.outvars)
|
||||
# If any outputs are used, then we need to keep a version of the eqn and
|
||||
# potentially mark some inputs as used. Otherwise mark all inputs as unused.
|
||||
if any(used_outs):
|
||||
if any(used_outs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params):
|
||||
# If there's a rule for modifying the eqn and computing used inputs, apply
|
||||
# it. Otherwise, keep the eqn unmodified and mark all inputs as used.
|
||||
rule = dce_rules.get(eqn.primitive)
|
||||
|
@ -27,6 +27,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import ad_checkpoint
|
||||
from jax import core
|
||||
from jax.config import config
|
||||
from jax import dtypes
|
||||
@ -1458,6 +1459,19 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
|
||||
( [4. 9.] [2. 3.] )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}")
|
||||
hcb.barrier_wait()
|
||||
expected = """
|
||||
what: x,x^2
|
||||
( 3. 9. )
|
||||
what: x,x^2
|
||||
( 27. 729. )
|
||||
what: x,x^2
|
||||
( 3. 9. )"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_tap_pmap(self):
|
||||
if len(local_devices()) < 2:
|
||||
@ -2024,8 +2038,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
use_result=use_result, use_remat=use_remat, grad_func=grad_func)
|
||||
for use_result in [True, False]
|
||||
for grad_func in ["grad", "value_and_grad"]
|
||||
for use_remat in ["old", "none"]))
|
||||
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="old"):
|
||||
for use_remat in ["old", "new", "none"]))
|
||||
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
|
||||
def f(x):
|
||||
id_print_result = hcb.id_print(x, output_stream=testing_stream)
|
||||
if use_result:
|
||||
@ -2034,6 +2048,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad
|
||||
if use_remat == "old":
|
||||
trans_f = jax.remat(f)
|
||||
elif use_remat == "new":
|
||||
trans_f = ad_checkpoint.checkpoint(f)
|
||||
else:
|
||||
assert use_remat == "none"
|
||||
trans_f = f
|
||||
@ -2068,8 +2084,14 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
2.
|
||||
2."""
|
||||
else:
|
||||
# TODO: we should see two callbacks
|
||||
expected = ""
|
||||
if use_remat == "old":
|
||||
# TODO: we should see two callbacks
|
||||
expected = ""
|
||||
else:
|
||||
# Good: we see two callbacks, whether or not we use the result.
|
||||
expected = """
|
||||
2.
|
||||
2."""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
def test_tap_named_call(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user