Add live-analysis memory optimization to more jaxpr interpreters.

Follow-up on 8a85e76a5cff0897eccbafc48da836b6f6704e5d

PiperOrigin-RevId: 540857501
This commit is contained in:
Lena Martens 2023-06-16 06:07:54 -07:00 committed by jax authors
parent 9fdaf5a100
commit fbf8823da3
5 changed files with 30 additions and 9 deletions

View File

@ -401,11 +401,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
err_vals, in_args = split_list(args, [err_tree.num_leaves])
error = jtu.tree_unflatten(err_tree, err_vals)
last_used = {v: None for v in jaxpr.outvars if not isinstance(v, core.Literal)}
for eqn in jaxpr.eqns[::-1]:
for v in eqn.invars:
if not isinstance(v, core.Literal) and v not in last_used:
last_used[v] = eqn
last_used = core.last_used(jaxpr)
def read_env(var: core.Atom):
if isinstance(var, core.Literal):
@ -432,10 +428,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
map(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)
for v in set(v for v in eqn.invars if not isinstance(v, core.Literal)):
if last_used[v] is eqn:
# Delete ref to variable when it is no longer needed by next equations.
del env[v]
core.clean_up_dead_vars(eqn, env, last_used)
return error, map(read_env, jaxpr.outvars)

View File

@ -439,6 +439,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
env: Dict[Var, Any] = {}
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
lu = last_used(jaxpr)
for eqn in jaxpr.eqns:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
@ -449,6 +450,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
map(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
clean_up_dead_vars(eqn, env, lu)
return map(read, jaxpr.outvars)
@ -3148,3 +3150,23 @@ def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
if hasattr(effect, "_pretty_print"):
return effect._pretty_print(context)
return pp.text(str(effect))
# ------------------- Jaxpr util -------------------
def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
"""Returns a mapping from every var in jaxpr to what equation uses it last."""
last_used: Dict[Var, Optional[JaxprEqn]] = {
v: None for v in jaxpr.outvars if not isinstance(v, Literal)}
for eqn in reversed(jaxpr.eqns):
for v in eqn.invars:
if not isinstance(v, Literal) and v not in last_used:
last_used[v] = eqn
return last_used
def clean_up_dead_vars(eqn: JaxprEqn, env: Dict[Var, Any],
last_used: Dict[Var, Optional[JaxprEqn]]):
"""Remove all eqn.invars from env if eqn is the last time they were used."""
for v in set(v for v in eqn.invars if not isinstance(v, Literal)):
if last_used[v] is eqn:
# Delete ref to variable when it is no longer needed by next equations.
del env[v]

View File

@ -1106,6 +1106,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
@ -1168,6 +1169,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
ans, "lowering function returned a bad output", eqn)
assert len(ans) == len(eqn.outvars), (ans, eqn)
map(write, eqn.outvars, out_nodes)
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars), tokens
def _ir_consts(consts):

View File

@ -2533,12 +2533,14 @@ def _eval_jaxpr_padded(
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
rule = padding_rules[eqn.primitive]
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
map(write, eqn.outvars, outs)
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars)
def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:

View File

@ -480,6 +480,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
map(write, jaxpr.invars, in_rep)
last_used = core.last_used(jaxpr)
for e in jaxpr.eqns:
rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
out_rep = rule(mesh, *map(read, e.invars), **e.params)
@ -488,6 +489,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
map(write, e.outvars, out_rep)
else:
write(e.outvars[0], out_rep)
core.clean_up_dead_vars(e, env, last_used)
return map(read, jaxpr.outvars)
def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool: