Change internal dead code elimination (DCE). The functions in dce_rules are now responsible for checking whether an equation has neither used outputs nor effects, and for behaving appropriately in that case (which usually means eliminating that equation).

PiperOrigin-RevId: 695789033
This commit is contained in:
James Martens 2024-11-12 10:36:14 -08:00 committed by jax authors
parent 3a5ac487a6
commit 310ff7347c
9 changed files with 40 additions and 16 deletions

View File

@ -716,6 +716,8 @@ batching.fancy_primitive_batchers[remat_p] = remat_vmap
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
new_params = dict(eqn.params, jaxpr=new_jaxpr)
if (not any(used_inputs) and not any(used_outputs) and

View File

@ -1531,6 +1531,8 @@ def _remat_opt_transpose(
"remat optimization for custom_vjp does not support higher-order AD")
def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
if not any(used_outs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
if any(used_res):

View File

@ -360,7 +360,7 @@ class JaxprTrace(Trace['JaxprTracer']):
staged_out_axes, _ = partition_list(out_knowns, out_axes)
staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,)
# Create the input tracers for the staged-out (unkonwn-value) call.
# Create the input tracers for the staged-out (unknown-value) call.
const_tracers = map(self.new_instantiated_const, res)
env_tracers = map(self.to_jaxpr_tracer, env)
unknown_arg_tracers = [t for t in tracers if not t.is_known()]
@ -1382,6 +1382,11 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
return new_jaxpr, used_consts, used_inputs
# Returns True when the equation carries at least one effect that is not a
# core.NamedAxisEffect; NamedAxisEffect entries are excluded from the check.
def has_effects(eqn: JaxprEqn) -> bool:
# Gather every effect on the equation except NamedAxisEffect instances.
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
return bool(effs)
@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
instantiate: tuple[bool, ...]
@ -1395,21 +1400,14 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
if type(x) is Var:
env[x] = read(x) or b
def has_effects(eqn: JaxprEqn) -> bool:
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
return bool(effs)
new_eqns = []
map(write, jaxpr.outvars, used_outputs)
for eqn in jaxpr.eqns[::-1]:
used_outs = map(read, eqn.outvars)
if not any(used_outs) and not has_effects(eqn):
used_ins = [False] * len(eqn.invars)
else:
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
used_ins, new_eqn = rule(used_outs, eqn)
if new_eqn is not None:
new_eqns.append(new_eqn)
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
used_ins, new_eqn = rule(used_outs, eqn)
if new_eqn is not None:
new_eqns.append(new_eqn)
map(write, eqn.invars, used_ins)
used_inputs = map(read, jaxpr.invars)
used_inputs = map(op.or_, instantiate, used_inputs)
@ -1433,7 +1431,9 @@ DCERule = Callable[[list[bool], JaxprEqn],
def _default_dce_rule(
used_outs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
) -> tuple[list[bool], JaxprEqn | None]:
if not any(used_outs) and not has_effects(eqn):
return [False] * len(eqn.invars), None
return [True] * len(eqn.invars), eqn
dce_rules: dict[Primitive, DCERule] = {}
@ -1441,6 +1441,8 @@ dce_rules: dict[Primitive, DCERule] = {}
def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn | None]:
if not any(used_outputs) and not has_effects(eqn):
return [False] * len(eqn.invars), None
new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
update_params = call_param_updaters.get(eqn.primitive)
@ -1454,6 +1456,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
@ -1465,8 +1468,10 @@ def _cached_closed_call_dce(jaxpr_, used_outputs: tuple[bool, ...]
return core.ClosedJaxpr(new_jaxpr, consts), used_inputs
def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
) -> tuple[list[bool], JaxprEqn | None]:
# TODO(mattjj): de-duplicate with above rule?
if not any(used_outputs) and not has_effects(eqn):
return [False] * len(eqn.invars), None
jaxpr_ = eqn.params['call_jaxpr']
closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs))
new_params = dict(eqn.params, call_jaxpr=closed_jaxpr)

View File

@ -1353,6 +1353,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):
def _pmap_dce_rule(used_outputs, eqn):
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
axis_name = eqn.params["axis_name"]
with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]):
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)

View File

@ -642,7 +642,11 @@ def _ordered_unique(xs):
return list(d.keys())
def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
) -> tuple[list[bool], core.JaxprEqn]:
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
closed_branches = eqn.params['branches']
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

View File

@ -944,7 +944,9 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params)
def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn]:
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_xs = len(jaxpr.in_avals) - num_consts - num_carry

View File

@ -2326,6 +2326,10 @@ def _dce_jaxpr_pjit(
def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
eqn.params['jaxpr'], tuple(used_outputs))

View File

@ -1660,6 +1660,8 @@ def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
mesh = eqn.params["mesh"]
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)

View File

@ -63,6 +63,7 @@ from jax._src.interpreters.partial_eval import (
debug_info_final as debug_info_final,
def_trivial_padding as def_trivial_padding,
forwarding_rules as forwarding_rules,
has_effects as has_effects,
infer_lambda_input_type as infer_lambda_input_type,
instantiate_const_at as instantiate_const_at,
make_jaxpr_effects as make_jaxpr_effects,