mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add memories support to remat.
This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only support host memory kind but the API allows to transfer to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass. Introduce two singletons called `Recompute`, `Saveable` and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which if True means saveable else recompute on backward pass. This is a backwards compatible change i.e. policies can still return a bool. A very basic offloadable policy can look like this: ``` def policy(prim, *avals, **params): return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host') ``` Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 564914301
This commit is contained in:
parent
7c0abb1c85
commit
c41d271175
@ -1121,7 +1121,7 @@ def partial_eval_jaxpr_custom(
|
||||
in_inst: bool | Sequence[bool],
|
||||
ensure_out_unknowns: bool | Sequence[bool],
|
||||
ensure_out_inst: bool | Sequence[bool],
|
||||
saveable: Callable[..., bool],
|
||||
saveable: Callable[..., RematCases_],
|
||||
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]:
|
||||
if type(in_inst) is bool:
|
||||
in_inst = (in_inst,) * len(jaxpr.invars)
|
||||
@ -1145,7 +1145,7 @@ def partial_eval_jaxpr_stateful(
|
||||
in_inst: bool | Sequence[bool],
|
||||
ensure_out_unknowns: bool | Sequence[bool],
|
||||
ensure_out_inst: bool | Sequence[bool],
|
||||
saveable: Callable[..., bool],
|
||||
saveable: Callable[..., RematCases_],
|
||||
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
|
||||
if type(in_inst) is bool:
|
||||
in_inst = (in_inst,) * len(jaxpr.invars)
|
||||
@ -1167,7 +1167,7 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
in_inst: tuple[bool, ...],
|
||||
ensure_out_unknowns: tuple[bool, ...],
|
||||
ensure_out_inst: tuple[bool, ...],
|
||||
saveable: Callable[..., bool],
|
||||
saveable: Callable[..., RematCases_],
|
||||
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
|
||||
env: dict[Var, tuple[bool, bool]] = {}
|
||||
residuals: OrderedSet[Var] = OrderedSet()
|
||||
@ -1187,6 +1187,7 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
residuals.add(x)
|
||||
return x
|
||||
|
||||
newvar = core.gensym(suffix='_offload')
|
||||
known_eqns, staged_eqns = [], []
|
||||
map(write, in_unknowns, in_inst, jaxpr.invars)
|
||||
map(partial(write, False, True), jaxpr.constvars)
|
||||
@ -1209,10 +1210,30 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
else:
|
||||
known_eqns.append(eqn)
|
||||
# If it's an effectful primitive, we always to run and avoid staging it.
|
||||
if eqn.effects or saveable(
|
||||
eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params):
|
||||
policy = ensure_enum(saveable(
|
||||
eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params))
|
||||
if eqn.effects or isinstance(policy, SaveableType):
|
||||
map(partial(write, False, False), eqn.outvars)
|
||||
elif isinstance(policy, Offloadable):
|
||||
from jax._src.dispatch import device_put_p, TransferToMemoryKind # type: ignore
|
||||
resvars = [newvar(v.aval) for v in eqn.outvars]
|
||||
offload_eqn = core.JaxprEqn(
|
||||
eqn.outvars, resvars, device_put_p,
|
||||
dict(device=TransferToMemoryKind(policy.dst), src=None),
|
||||
set(), source_info_util.new_source_info())
|
||||
known_eqns.append(offload_eqn)
|
||||
# resvars are known and available in the backward jaxpr.
|
||||
map(partial(write, False, True), resvars)
|
||||
residuals.update(resvars)
|
||||
reload_eqn = core.JaxprEqn(
|
||||
resvars, eqn.outvars, device_put_p, # type: ignore
|
||||
dict(device=TransferToMemoryKind(policy.src), src=None),
|
||||
set(), source_info_util.new_source_info())
|
||||
staged_eqns.append(reload_eqn)
|
||||
# outvars are known and available in the backward jaxpr.
|
||||
map(partial(write, False, True), eqn.outvars)
|
||||
else:
|
||||
assert isinstance(policy, RecomputeType)
|
||||
inputs = map(ensure_instantiated, inst_in, eqn.invars)
|
||||
staged_eqns.append(eqn.replace(invars=inputs))
|
||||
map(partial(write, False, True), eqn.outvars)
|
||||
@ -1249,6 +1270,27 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
|
||||
len(non_input_res_refs))
|
||||
|
||||
|
||||
MemoryKind = str
|
||||
|
||||
class RecomputeType: pass
|
||||
Recompute = RecomputeType()
|
||||
|
||||
class SaveableType: pass
|
||||
Saveable = SaveableType()
|
||||
|
||||
class Offloadable(NamedTuple):
|
||||
src: MemoryKind
|
||||
dst: MemoryKind
|
||||
|
||||
RematCases = Union[RecomputeType, SaveableType, Offloadable]
|
||||
RematCases_ = Union[RematCases, bool]
|
||||
|
||||
def ensure_enum(case: bool | RematCases) -> RematCases:
|
||||
if isinstance(case, bool):
|
||||
return Saveable if case else Recompute
|
||||
return case
|
||||
|
||||
# A primitive rule for policy-driven partial evaluation returns a 5-tuple
|
||||
# with the components representing, respectively:
|
||||
# * the JaxprEqn for the 'known' side (or None if there is no known component),
|
||||
@ -1262,12 +1304,12 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
PartialEvalCustomResult = tuple[Optional[JaxprEqn], Optional[JaxprEqn],
|
||||
Sequence[bool], Sequence[bool], list[Var]]
|
||||
PartialEvalCustomRule = Callable[
|
||||
[Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn],
|
||||
[Callable[..., RematCases_], Sequence[bool], Sequence[bool], JaxprEqn],
|
||||
PartialEvalCustomResult]
|
||||
partial_eval_jaxpr_custom_rules: dict[Primitive, PartialEvalCustomRule] = {}
|
||||
|
||||
def partial_eval_jaxpr_custom_rule_not_implemented(
|
||||
name: str, saveable: Callable[..., bool], unks_in: Sequence[bool],
|
||||
name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool],
|
||||
inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult:
|
||||
msg = (f'custom-policy remat rule not implemented for {name}, '
|
||||
'open a feature request at https://github.com/google/jax/issues!')
|
||||
@ -1287,7 +1329,7 @@ def trivial_ctx(_): yield
|
||||
|
||||
def call_partial_eval_custom_rule(
|
||||
jaxpr_param_name: str, params_updater: ParamsUpdater,
|
||||
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
|
||||
saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
|
||||
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
|
||||
ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx,
|
||||
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
|
||||
@ -1319,7 +1361,7 @@ def call_partial_eval_custom_rule(
|
||||
|
||||
def closed_call_partial_eval_custom_rule(
|
||||
jaxpr_param_name: str, params_updater: ParamsUpdater,
|
||||
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
|
||||
saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
|
||||
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
|
||||
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
|
||||
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
|
||||
|
@ -467,7 +467,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
|
||||
pe.custom_partial_eval_rules[run_state_p] = _run_state_partial_eval
|
||||
|
||||
def _run_state_partial_eval_custom(
|
||||
saveable: Callable[..., bool],
|
||||
saveable: Callable[..., pe.RematCases_],
|
||||
in_unknowns: Sequence[bool],
|
||||
in_inst: Sequence[bool],
|
||||
eqn: core.JaxprEqn):
|
||||
|
@ -1326,3 +1326,9 @@ def set_env(**kwargs):
|
||||
finally:
|
||||
_ = [os.environ.pop(key, None) for key in kwargs]
|
||||
os.environ.update({k: v for k, v in original.items() if v is not None})
|
||||
|
||||
def fwd_bwd_jaxprs(f, *example_args):
|
||||
fwd_jaxpr, (y_shape, res_shape) = jax.make_jaxpr(
|
||||
lambda *args: jax.vjp(f, *args), return_shape=True)(*example_args)
|
||||
bwd_jaxpr = jax.make_jaxpr(lambda res, outs: res(outs))(res_shape, y_shape)
|
||||
return fwd_jaxpr, bwd_jaxpr
|
||||
|
@ -19,3 +19,8 @@ from jax._src.ad_checkpoint import (
|
||||
print_saved_residuals,
|
||||
remat,
|
||||
)
|
||||
from jax._src.interpreters.partial_eval import (
|
||||
Recompute,
|
||||
Saveable,
|
||||
Offloadable,
|
||||
)
|
||||
|
@ -1414,7 +1414,7 @@ core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst
|
||||
# Remat
|
||||
|
||||
def _partial_eval_jaxpr_custom_rule(
|
||||
saveable: Callable[..., bool], unks_in: Sequence[bool],
|
||||
saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool],
|
||||
inst_in: Sequence[bool], eqn: core.JaxprEqn
|
||||
) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
|
||||
list[core.Var]]:
|
||||
|
@ -22,6 +22,7 @@ import jax
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.ad_checkpoint import Offloadable, remat
|
||||
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
|
||||
SingleDeviceSharding, GSPMDSharding,
|
||||
TransferToMemoryKind,
|
||||
@ -973,6 +974,54 @@ class MemoriesTest(jtu.JaxTestCase):
|
||||
lowered_text = f.lower(x).as_text("hlo")
|
||||
self.assertIn("input_output_alias", lowered_text)
|
||||
|
||||
def test_remat_jaxpr_offloadable(self):
|
||||
def policy(prim, *avals, **params):
|
||||
return Offloadable(src="tpu_hbm", dst="unpinned_host")
|
||||
|
||||
@functools.partial(remat, policy=policy)
|
||||
def f(x):
|
||||
x = jnp.sin(x)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.sin(x)
|
||||
return x
|
||||
|
||||
fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, jnp.ones((3)))
|
||||
|
||||
self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals
|
||||
fwd_mem_kind_count = str(fwd_jaxpr).count(
|
||||
"TransferToMemoryKind(memory_kind='unpinned_host')")
|
||||
self.assertEqual(fwd_mem_kind_count, 3)
|
||||
|
||||
self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input
|
||||
bwd_mem_kind_count = str(bwd_jaxpr).count(
|
||||
"TransferToMemoryKind(memory_kind='tpu_hbm')")
|
||||
self.assertEqual(bwd_mem_kind_count, 3)
|
||||
|
||||
def test_remat_scan_jaxpr_offloadable(self):
|
||||
def policy(prim, *avals, **params):
|
||||
return Offloadable(src="tpu_hbm", dst="unpinned_host")
|
||||
|
||||
@functools.partial(remat, policy=policy)
|
||||
def f(x):
|
||||
def g(y, _):
|
||||
y = jnp.sin(y)
|
||||
y = jnp.sin(y)
|
||||
y = jnp.sin(y)
|
||||
return y, None
|
||||
return jax.lax.scan(g, x, None, length=1)[0]
|
||||
|
||||
fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, jnp.ones((3)))
|
||||
|
||||
self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals
|
||||
fwd_mem_kind_count = str(fwd_jaxpr).count(
|
||||
"TransferToMemoryKind(memory_kind='unpinned_host')")
|
||||
self.assertEqual(fwd_mem_kind_count, 3)
|
||||
|
||||
self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input
|
||||
bwd_mem_kind_count = str(bwd_jaxpr).count(
|
||||
"TransferToMemoryKind(memory_kind='tpu_hbm')")
|
||||
self.assertEqual(bwd_mem_kind_count, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user