mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Change to internal dead code elimination. Now the functions in dce_rules
are responsible for checking if the equation has no used outputs or effects, and behaving appropriately in that case (which usually means eliminating said equation).
PiperOrigin-RevId: 695789033
This commit is contained in:
parent
3a5ac487a6
commit
310ff7347c
@ -716,6 +716,8 @@ batching.fancy_primitive_batchers[remat_p] = remat_vmap
|
||||
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
|
||||
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
) -> tuple[list[bool], core.JaxprEqn | None]:
|
||||
if not any(used_outputs) and not pe.has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
|
||||
new_params = dict(eqn.params, jaxpr=new_jaxpr)
|
||||
if (not any(used_inputs) and not any(used_outputs) and
|
||||
|
@ -1531,6 +1531,8 @@ def _remat_opt_transpose(
|
||||
"remat optimization for custom_vjp does not support higher-order AD")
|
||||
|
||||
def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
|
||||
if not any(used_outs) and not pe.has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
|
||||
outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
|
||||
if any(used_res):
|
||||
|
@ -360,7 +360,7 @@ class JaxprTrace(Trace['JaxprTracer']):
|
||||
staged_out_axes, _ = partition_list(out_knowns, out_axes)
|
||||
staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,)
|
||||
|
||||
# Create the input tracers for the staged-out (unkonwn-value) call.
|
||||
# Create the input tracers for the staged-out (unknown-value) call.
|
||||
const_tracers = map(self.new_instantiated_const, res)
|
||||
env_tracers = map(self.to_jaxpr_tracer, env)
|
||||
unknown_arg_tracers = [t for t in tracers if not t.is_known()]
|
||||
@ -1382,6 +1382,11 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
|
||||
return new_jaxpr, used_consts, used_inputs
|
||||
|
||||
|
||||
def has_effects(eqn: JaxprEqn) -> bool:
|
||||
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
|
||||
return bool(effs)
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
|
||||
instantiate: tuple[bool, ...]
|
||||
@ -1395,21 +1400,14 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
|
||||
if type(x) is Var:
|
||||
env[x] = read(x) or b
|
||||
|
||||
def has_effects(eqn: JaxprEqn) -> bool:
|
||||
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
|
||||
return bool(effs)
|
||||
|
||||
new_eqns = []
|
||||
map(write, jaxpr.outvars, used_outputs)
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
used_outs = map(read, eqn.outvars)
|
||||
if not any(used_outs) and not has_effects(eqn):
|
||||
used_ins = [False] * len(eqn.invars)
|
||||
else:
|
||||
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
|
||||
used_ins, new_eqn = rule(used_outs, eqn)
|
||||
if new_eqn is not None:
|
||||
new_eqns.append(new_eqn)
|
||||
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
|
||||
used_ins, new_eqn = rule(used_outs, eqn)
|
||||
if new_eqn is not None:
|
||||
new_eqns.append(new_eqn)
|
||||
map(write, eqn.invars, used_ins)
|
||||
used_inputs = map(read, jaxpr.invars)
|
||||
used_inputs = map(op.or_, instantiate, used_inputs)
|
||||
@ -1433,7 +1431,9 @@ DCERule = Callable[[list[bool], JaxprEqn],
|
||||
|
||||
def _default_dce_rule(
|
||||
used_outs: list[bool], eqn: JaxprEqn
|
||||
) -> tuple[list[bool], JaxprEqn]:
|
||||
) -> tuple[list[bool], JaxprEqn | None]:
|
||||
if not any(used_outs) and not has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
return [True] * len(eqn.invars), eqn
|
||||
|
||||
dce_rules: dict[Primitive, DCERule] = {}
|
||||
@ -1441,6 +1441,8 @@ dce_rules: dict[Primitive, DCERule] = {}
|
||||
|
||||
def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
|
||||
) -> tuple[list[bool], JaxprEqn | None]:
|
||||
if not any(used_outputs) and not has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
|
||||
new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
|
||||
update_params = call_param_updaters.get(eqn.primitive)
|
||||
@ -1454,6 +1456,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
|
||||
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx)
|
||||
return used_inputs, new_eqn
|
||||
|
||||
dce_rules[core.call_p] = dce_jaxpr_call_rule
|
||||
|
||||
|
||||
@ -1465,8 +1468,10 @@ def _cached_closed_call_dce(jaxpr_, used_outputs: tuple[bool, ...]
|
||||
return core.ClosedJaxpr(new_jaxpr, consts), used_inputs
|
||||
|
||||
def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn
|
||||
) -> tuple[list[bool], JaxprEqn]:
|
||||
) -> tuple[list[bool], JaxprEqn | None]:
|
||||
# TODO(mattjj): de-duplicate with above rule?
|
||||
if not any(used_outputs) and not has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
jaxpr_ = eqn.params['call_jaxpr']
|
||||
closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs))
|
||||
new_params = dict(eqn.params, call_jaxpr=closed_jaxpr)
|
||||
|
@ -1353,6 +1353,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):
|
||||
|
||||
def _pmap_dce_rule(used_outputs, eqn):
|
||||
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
|
||||
if not any(used_outputs) and not pe.has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
axis_name = eqn.params["axis_name"]
|
||||
with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]):
|
||||
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
|
||||
|
@ -642,7 +642,11 @@ def _ordered_unique(xs):
|
||||
return list(d.keys())
|
||||
|
||||
def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
|
||||
) -> tuple[list[bool], core.JaxprEqn]:
|
||||
) -> tuple[list[bool], core.JaxprEqn | None]:
|
||||
|
||||
if not any(used_outputs) and not pe.has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
|
||||
closed_branches = eqn.params['branches']
|
||||
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]
|
||||
|
||||
|
@ -944,7 +944,9 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
|
||||
return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params)
|
||||
|
||||
def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
) -> tuple[list[bool], core.JaxprEqn]:
|
||||
) -> tuple[list[bool], core.JaxprEqn | None]:
|
||||
if not any(used_outputs) and not pe.has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
|
||||
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
|
||||
|
@ -2326,6 +2326,10 @@ def _dce_jaxpr_pjit(
|
||||
|
||||
def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
) -> tuple[list[bool], core.JaxprEqn | None]:
|
||||
|
||||
if not any(used_outputs) and not pe.has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
|
||||
dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
|
||||
eqn.params['jaxpr'], tuple(used_outputs))
|
||||
|
||||
|
@ -1660,6 +1660,8 @@ def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
|
||||
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
|
||||
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
) -> tuple[list[bool], core.JaxprEqn | None]:
|
||||
if not any(used_outputs) and not pe.has_effects(eqn):
|
||||
return [False] * len(eqn.invars), None
|
||||
mesh = eqn.params["mesh"]
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
|
||||
|
@ -63,6 +63,7 @@ from jax._src.interpreters.partial_eval import (
|
||||
debug_info_final as debug_info_final,
|
||||
def_trivial_padding as def_trivial_padding,
|
||||
forwarding_rules as forwarding_rules,
|
||||
has_effects as has_effects,
|
||||
infer_lambda_input_type as infer_lambda_input_type,
|
||||
instantiate_const_at as instantiate_const_at,
|
||||
make_jaxpr_effects as make_jaxpr_effects,
|
||||
|
Loading…
x
Reference in New Issue
Block a user