From fbf8823da3f0cc0e32255783db0c8070613372ef Mon Sep 17 00:00:00 2001
From: Lena Martens
Date: Fri, 16 Jun 2023 06:07:54 -0700
Subject: [PATCH] Add live-analysis memory optimization to more jaxpr interpreters.

Follow-up on 8a85e76a5cff0897eccbafc48da836b6f6704e5d

PiperOrigin-RevId: 540857501
---
NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Hunk line counts were checked against
the hunk headers, but the exact indentation of context lines is inferred.
Apply with `git apply --ignore-whitespace` or confirm against the upstream
commit.

 jax/_src/checkify.py                  | 11 ++---------
 jax/_src/core.py                      | 22 ++++++++++++++++++++++
 jax/_src/interpreters/mlir.py         |  2 ++
 jax/_src/interpreters/partial_eval.py |  2 ++
 jax/experimental/shard_map.py         |  2 ++
 5 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index c3f9a376b..ecf56407b 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -401,11 +401,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
   err_vals, in_args = split_list(args, [err_tree.num_leaves])
   error = jtu.tree_unflatten(err_tree, err_vals)
 
-  last_used = {v: None for v in jaxpr.outvars if not isinstance(v, core.Literal)}
-  for eqn in jaxpr.eqns[::-1]:
-    for v in eqn.invars:
-      if not isinstance(v, core.Literal) and v not in last_used:
-        last_used[v] = eqn
+  last_used = core.last_used(jaxpr)
 
   def read_env(var: core.Atom):
     if isinstance(var, core.Literal):
@@ -432,10 +428,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
       map(write_env, eqn.outvars, outvals)
     else:
       write_env(eqn.outvars[0], outvals)
-    for v in set(v for v in eqn.invars if not isinstance(v, core.Literal)):
-      if last_used[v] is eqn:
-        # Delete ref to variable when it is no longer needed by next equations.
-        del env[v]
+    core.clean_up_dead_vars(eqn, env, last_used)
 
   return error, map(read_env, jaxpr.outvars)
 
diff --git a/jax/_src/core.py b/jax/_src/core.py
index d0dfc0726..f2d7ad86a 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -439,6 +439,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
   env: Dict[Var, Any] = {}
   map(write, jaxpr.constvars, consts)
   map(write, jaxpr.invars, args)
+  lu = last_used(jaxpr)
   for eqn in jaxpr.eqns:
     subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
     name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
@@ -449,6 +450,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
       map(write, eqn.outvars, ans)
     else:
       write(eqn.outvars[0], ans)
+    clean_up_dead_vars(eqn, env, lu)
   return map(read, jaxpr.outvars)
 
 
@@ -3148,3 +3150,23 @@ def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
   if hasattr(effect, "_pretty_print"):
     return effect._pretty_print(context)
   return pp.text(str(effect))
+
+# ------------------- Jaxpr util -------------------
+
+def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
+  """Returns a mapping from every var in jaxpr to what equation uses it last."""
+  last_used: Dict[Var, Optional[JaxprEqn]] = {
+      v: None for v in jaxpr.outvars if not isinstance(v, Literal)}
+  for eqn in reversed(jaxpr.eqns):
+    for v in eqn.invars:
+      if not isinstance(v, Literal) and v not in last_used:
+        last_used[v] = eqn
+  return last_used
+
+def clean_up_dead_vars(eqn: JaxprEqn, env: Dict[Var, Any],
+                       last_used: Dict[Var, Optional[JaxprEqn]]):
+  """Remove all eqn.invars from env if eqn is the last time they were used."""
+  for v in set(v for v in eqn.invars if not isinstance(v, Literal)):
+    if last_used[v] is eqn:
+      # Delete ref to variable when it is no longer needed by next equations.
+      del env[v]
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 22109d1fc..86f21a6cb 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1106,6 +1106,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
   map(write, jaxpr.constvars, consts)
   map(write, jaxpr.invars, args)
+  last_used = core.last_used(jaxpr)
   for eqn in jaxpr.eqns:
     in_nodes = map(read, eqn.invars)
     assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
@@ -1168,6 +1169,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       ans, "lowering function returned a bad output", eqn)
     assert len(ans) == len(eqn.outvars), (ans, eqn)
     map(write, eqn.outvars, out_nodes)
+    core.clean_up_dead_vars(eqn, env, last_used)
   return map(read, jaxpr.outvars), tokens
 
 def _ir_consts(consts):
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index e0a07d221..1b57a9d3c 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -2533,12 +2533,14 @@ def _eval_jaxpr_padded(
 
   map(write, jaxpr.constvars, consts)
   map(write, jaxpr.invars, args)
+  last_used = core.last_used(jaxpr)
   for eqn in jaxpr.eqns:
     in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
     out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
     rule = padding_rules[eqn.primitive]
     outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
     map(write, eqn.outvars, outs)
+    core.clean_up_dead_vars(eqn, env, last_used)
   return map(read, jaxpr.outvars)
 
 def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 8dbbddf26..7840b0c85 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -480,6 +480,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
 
   map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
   map(write, jaxpr.invars, in_rep)
+  last_used = core.last_used(jaxpr)
   for e in jaxpr.eqns:
     rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
     out_rep = rule(mesh, *map(read, e.invars), **e.params)
@@ -488,6 +489,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
       map(write, e.outvars, out_rep)
     else:
       write(e.outvars[0], out_rep)
+    core.clean_up_dead_vars(e, env, last_used)
   return map(read, jaxpr.outvars)
 
 def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool: