Add memories support to remat.

This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only support host memory kind but the API allows to transfer to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.

Introduce two singletons called `Recompute`, `Saveable` and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which if True means saveable else recompute on backward pass. This is a backwards compatible change i.e. policies can still return a bool.

A very basic offloadable policy can look like this:

```
def policy(prim, *avals, **params):
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
This commit is contained in:
Yash Katariya 2023-09-12 20:49:25 -07:00 committed by jax authors
parent 7c0abb1c85
commit c41d271175
6 changed files with 113 additions and 11 deletions

View File

@ -1121,7 +1121,7 @@ def partial_eval_jaxpr_custom(
in_inst: bool | Sequence[bool],
ensure_out_unknowns: bool | Sequence[bool],
ensure_out_inst: bool | Sequence[bool],
saveable: Callable[..., bool],
saveable: Callable[..., RematCases_],
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
@ -1145,7 +1145,7 @@ def partial_eval_jaxpr_stateful(
in_inst: bool | Sequence[bool],
ensure_out_unknowns: bool | Sequence[bool],
ensure_out_inst: bool | Sequence[bool],
saveable: Callable[..., bool],
saveable: Callable[..., RematCases_],
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
@ -1167,7 +1167,7 @@ def _partial_eval_jaxpr_custom_cached(
in_inst: tuple[bool, ...],
ensure_out_unknowns: tuple[bool, ...],
ensure_out_inst: tuple[bool, ...],
saveable: Callable[..., bool],
saveable: Callable[..., RematCases_],
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
env: dict[Var, tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
@ -1187,6 +1187,7 @@ def _partial_eval_jaxpr_custom_cached(
residuals.add(x)
return x
newvar = core.gensym(suffix='_offload')
known_eqns, staged_eqns = [], []
map(write, in_unknowns, in_inst, jaxpr.invars)
map(partial(write, False, True), jaxpr.constvars)
@ -1209,10 +1210,30 @@ def _partial_eval_jaxpr_custom_cached(
else:
known_eqns.append(eqn)
# If it's an effectful primitive, we always to run and avoid staging it.
if eqn.effects or saveable(
eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params):
policy = ensure_enum(saveable(
eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params))
if eqn.effects or isinstance(policy, SaveableType):
map(partial(write, False, False), eqn.outvars)
elif isinstance(policy, Offloadable):
from jax._src.dispatch import device_put_p, TransferToMemoryKind # type: ignore
resvars = [newvar(v.aval) for v in eqn.outvars]
offload_eqn = core.JaxprEqn(
eqn.outvars, resvars, device_put_p,
dict(device=TransferToMemoryKind(policy.dst), src=None),
set(), source_info_util.new_source_info())
known_eqns.append(offload_eqn)
# resvars are known and available in the backward jaxpr.
map(partial(write, False, True), resvars)
residuals.update(resvars)
reload_eqn = core.JaxprEqn(
resvars, eqn.outvars, device_put_p, # type: ignore
dict(device=TransferToMemoryKind(policy.src), src=None),
set(), source_info_util.new_source_info())
staged_eqns.append(reload_eqn)
# outvars are known and available in the backward jaxpr.
map(partial(write, False, True), eqn.outvars)
else:
assert isinstance(policy, RecomputeType)
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(eqn.replace(invars=inputs))
map(partial(write, False, True), eqn.outvars)
@ -1249,6 +1270,27 @@ def _partial_eval_jaxpr_custom_cached(
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
len(non_input_res_refs))
MemoryKind = str
class RecomputeType: pass
Recompute = RecomputeType()
class SaveableType: pass
Saveable = SaveableType()
class Offloadable(NamedTuple):
src: MemoryKind
dst: MemoryKind
RematCases = Union[RecomputeType, SaveableType, Offloadable]
RematCases_ = Union[RematCases, bool]
def ensure_enum(case: bool | RematCases) -> RematCases:
if isinstance(case, bool):
return Saveable if case else Recompute
return case
# A primitive rule for policy-driven partial evaluation returns a 5-tuple
# with the components representing, respectively:
# * the JaxprEqn for the 'known' side (or None if there is no known component),
@ -1262,12 +1304,12 @@ def _partial_eval_jaxpr_custom_cached(
PartialEvalCustomResult = tuple[Optional[JaxprEqn], Optional[JaxprEqn],
Sequence[bool], Sequence[bool], list[Var]]
PartialEvalCustomRule = Callable[
[Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn],
[Callable[..., RematCases_], Sequence[bool], Sequence[bool], JaxprEqn],
PartialEvalCustomResult]
partial_eval_jaxpr_custom_rules: dict[Primitive, PartialEvalCustomRule] = {}
def partial_eval_jaxpr_custom_rule_not_implemented(
name: str, saveable: Callable[..., bool], unks_in: Sequence[bool],
name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool],
inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult:
msg = (f'custom-policy remat rule not implemented for {name}, '
'open a feature request at https://github.com/google/jax/issues!')
@ -1287,7 +1329,7 @@ def trivial_ctx(_): yield
def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx,
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
@ -1319,7 +1361,7 @@ def call_partial_eval_custom_rule(
def closed_call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.

View File

@ -467,7 +467,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
pe.custom_partial_eval_rules[run_state_p] = _run_state_partial_eval
def _run_state_partial_eval_custom(
saveable: Callable[..., bool],
saveable: Callable[..., pe.RematCases_],
in_unknowns: Sequence[bool],
in_inst: Sequence[bool],
eqn: core.JaxprEqn):

View File

@ -1326,3 +1326,9 @@ def set_env(**kwargs):
finally:
_ = [os.environ.pop(key, None) for key in kwargs]
os.environ.update({k: v for k, v in original.items() if v is not None})
def fwd_bwd_jaxprs(f, *example_args):
fwd_jaxpr, (y_shape, res_shape) = jax.make_jaxpr(
lambda *args: jax.vjp(f, *args), return_shape=True)(*example_args)
bwd_jaxpr = jax.make_jaxpr(lambda res, outs: res(outs))(res_shape, y_shape)
return fwd_jaxpr, bwd_jaxpr

View File

@ -19,3 +19,8 @@ from jax._src.ad_checkpoint import (
print_saved_residuals,
remat,
)
from jax._src.interpreters.partial_eval import (
Recompute,
Saveable,
Offloadable,
)

View File

@ -1414,7 +1414,7 @@ core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst
# Remat
def _partial_eval_jaxpr_custom_rule(
saveable: Callable[..., bool], unks_in: Sequence[bool],
saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool],
inst_in: Sequence[bool], eqn: core.JaxprEqn
) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
list[core.Var]]:

View File

@ -22,6 +22,7 @@ import jax
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.ad_checkpoint import Offloadable, remat
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
SingleDeviceSharding, GSPMDSharding,
TransferToMemoryKind,
@ -973,6 +974,54 @@ class MemoriesTest(jtu.JaxTestCase):
lowered_text = f.lower(x).as_text("hlo")
self.assertIn("input_output_alias", lowered_text)
def test_remat_jaxpr_offloadable(self):
def policy(prim, *avals, **params):
return Offloadable(src="tpu_hbm", dst="unpinned_host")
@functools.partial(remat, policy=policy)
def f(x):
x = jnp.sin(x)
x = jnp.sin(x)
x = jnp.sin(x)
return x
fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, jnp.ones((3)))
self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals
fwd_mem_kind_count = str(fwd_jaxpr).count(
"TransferToMemoryKind(memory_kind='unpinned_host')")
self.assertEqual(fwd_mem_kind_count, 3)
self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input
bwd_mem_kind_count = str(bwd_jaxpr).count(
"TransferToMemoryKind(memory_kind='tpu_hbm')")
self.assertEqual(bwd_mem_kind_count, 3)
def test_remat_scan_jaxpr_offloadable(self):
def policy(prim, *avals, **params):
return Offloadable(src="tpu_hbm", dst="unpinned_host")
@functools.partial(remat, policy=policy)
def f(x):
def g(y, _):
y = jnp.sin(y)
y = jnp.sin(y)
y = jnp.sin(y)
return y, None
return jax.lax.scan(g, x, None, length=1)[0]
fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, jnp.ones((3)))
self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals
fwd_mem_kind_count = str(fwd_jaxpr).count(
"TransferToMemoryKind(memory_kind='unpinned_host')")
self.assertEqual(fwd_mem_kind_count, 3)
self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input
bwd_mem_kind_count = str(bwd_jaxpr).count(
"TransferToMemoryKind(memory_kind='tpu_hbm')")
self.assertEqual(bwd_mem_kind_count, 3)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())