mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add live-analysis memory optimization to more jaxpr interpreters.
Follow-up on 8a85e76a5cff0897eccbafc48da836b6f6704e5d.

PiperOrigin-RevId: 540857501
This commit is contained in:
parent
9fdaf5a100
commit
fbf8823da3
@ -401,11 +401,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
|
||||
err_vals, in_args = split_list(args, [err_tree.num_leaves])
|
||||
error = jtu.tree_unflatten(err_tree, err_vals)
|
||||
|
||||
last_used = {v: None for v in jaxpr.outvars if not isinstance(v, core.Literal)}
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
for v in eqn.invars:
|
||||
if not isinstance(v, core.Literal) and v not in last_used:
|
||||
last_used[v] = eqn
|
||||
last_used = core.last_used(jaxpr)
|
||||
|
||||
def read_env(var: core.Atom):
|
||||
if isinstance(var, core.Literal):
|
||||
@ -432,10 +428,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
|
||||
map(write_env, eqn.outvars, outvals)
|
||||
else:
|
||||
write_env(eqn.outvars[0], outvals)
|
||||
for v in set(v for v in eqn.invars if not isinstance(v, core.Literal)):
|
||||
if last_used[v] is eqn:
|
||||
# Delete ref to variable when it is no longer needed by next equations.
|
||||
del env[v]
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
|
||||
return error, map(read_env, jaxpr.outvars)
|
||||
|
||||
|
@ -439,6 +439,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
|
||||
env: Dict[Var, Any] = {}
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
lu = last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
@ -449,6 +450,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
|
||||
map(write, eqn.outvars, ans)
|
||||
else:
|
||||
write(eqn.outvars[0], ans)
|
||||
clean_up_dead_vars(eqn, env, lu)
|
||||
return map(read, jaxpr.outvars)
|
||||
|
||||
|
||||
@ -3148,3 +3150,23 @@ def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
|
||||
if hasattr(effect, "_pretty_print"):
|
||||
return effect._pretty_print(context)
|
||||
return pp.text(str(effect))
|
||||
|
||||
# ------------------- Jaxpr util -------------------
|
||||
|
||||
def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
  """Computes, for the vars of `jaxpr`, the final equation that reads each one.

  Args:
    jaxpr: the jaxpr whose `outvars` and `eqns` are scanned.

  Returns:
    A dict keyed by non-literal vars. Vars listed in `jaxpr.outvars` map to
    None; every other var that appears in some equation's `invars` maps to
    the latest such equation in `jaxpr.eqns`. Literals never appear as keys,
    and neither do vars that are read by no equation and are not outputs.
  """
  final_use: Dict[Var, Optional[JaxprEqn]] = {}

  # Outputs are registered first with None, so no equation is ever recorded
  # as their last use.
  for out in jaxpr.outvars:
    if isinstance(out, Literal):
      continue
    final_use[out] = None

  # Walk the equations back to front: the first equation seen reading a var
  # is therefore the last one to read it in program order.
  for eqn in reversed(jaxpr.eqns):
    for operand in eqn.invars:
      # The Literal check comes before the var is used as a dict key.
      if isinstance(operand, Literal):
        continue
      final_use.setdefault(operand, eqn)

  return final_use
|
||||
|
||||
def clean_up_dead_vars(eqn: JaxprEqn, env: Dict[Var, Any],
                       last_used: Dict[Var, Optional[JaxprEqn]]):
  """Deletes from `env` every input of `eqn` whose last recorded use is `eqn`.

  Args:
    eqn: the equation whose `invars` are examined.
    env: mapping from vars to values; mutated in place.
    last_used: mapping from vars to the last equation using them (or None).
      It is indexed directly, so it must have an entry for every non-literal
      input of `eqn`.
  """
  # Deduplicate so a var passed several times to `eqn` is deleted only once.
  operands = {v for v in eqn.invars if not isinstance(v, Literal)}
  for var in operands:
    if last_used[var] is not eqn:
      continue
    # `eqn` is the recorded last use of `var`: drop the env entry so it no
    # longer holds a reference to the value.
    del env[var]
|
||||
|
@ -1106,6 +1106,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
last_used = core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_nodes = map(read, eqn.invars)
|
||||
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
|
||||
@ -1168,6 +1169,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
ans, "lowering function returned a bad output", eqn)
|
||||
assert len(ans) == len(eqn.outvars), (ans, eqn)
|
||||
map(write, eqn.outvars, out_nodes)
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
return map(read, jaxpr.outvars), tokens
|
||||
|
||||
def _ir_consts(consts):
|
||||
|
@ -2533,12 +2533,14 @@ def _eval_jaxpr_padded(
|
||||
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
last_used = core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
|
||||
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
|
||||
rule = padding_rules[eqn.primitive]
|
||||
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
|
||||
map(write, eqn.outvars, outs)
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
return map(read, jaxpr.outvars)
|
||||
|
||||
def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
|
||||
|
@ -480,6 +480,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
|
||||
|
||||
map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
|
||||
map(write, jaxpr.invars, in_rep)
|
||||
last_used = core.last_used(jaxpr)
|
||||
for e in jaxpr.eqns:
|
||||
rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
|
||||
out_rep = rule(mesh, *map(read, e.invars), **e.params)
|
||||
@ -488,6 +489,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
|
||||
map(write, e.outvars, out_rep)
|
||||
else:
|
||||
write(e.outvars[0], out_rep)
|
||||
core.clean_up_dead_vars(e, env, last_used)
|
||||
return map(read, jaxpr.outvars)
|
||||
|
||||
def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool:
|
||||
|
Loading…
x
Reference in New Issue
Block a user