From 310ff7347c0c4e2487f14f558d47a044be20effb Mon Sep 17 00:00:00 2001 From: James Martens Date: Tue, 12 Nov 2024 10:36:14 -0800 Subject: [PATCH] Change to internal dead code elimination. Now the functions in `dce_rules` are responsible for checking if the equation has no used outputs or effects, and behaving appropriately in that case (which usually means eliminating said equation). PiperOrigin-RevId: 695789033 --- jax/_src/ad_checkpoint.py | 2 ++ jax/_src/custom_derivatives.py | 2 ++ jax/_src/interpreters/partial_eval.py | 33 +++++++++++++---------- jax/_src/interpreters/pxla.py | 2 ++ jax/_src/lax/control_flow/conditionals.py | 6 ++++- jax/_src/lax/control_flow/loops.py | 4 ++- jax/_src/pjit.py | 4 +++ jax/experimental/shard_map.py | 2 ++ jax/interpreters/partial_eval.py | 1 + 9 files changed, 40 insertions(+), 16 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 5ed0b0192..fc135ac8f 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -716,6 +716,8 @@ batching.fancy_primitive_batchers[remat_p] = remat_vmap # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) new_params = dict(eqn.params, jaxpr=new_jaxpr) if (not any(used_inputs) and not any(used_outputs) and diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 375efeb71..e37494c4f 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1531,6 +1531,8 @@ def _remat_opt_transpose( "remat optimization for custom_vjp does not support higher-order AD") def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): + if not any(used_outs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]]) outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] if any(used_res): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ad97ef325..5431762d6 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -360,7 +360,7 @@ class JaxprTrace(Trace['JaxprTracer']): staged_out_axes, _ = partition_list(out_knowns, out_axes) staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,) - # Create the input tracers for the staged-out (unkonwn-value) call. + # Create the input tracers for the staged-out (unknown-value) call. const_tracers = map(self.new_instantiated_const, res) env_tracers = map(self.to_jaxpr_tracer, env) unknown_arg_tracers = [t for t in tracers if not t.is_known()] @@ -1382,6 +1382,11 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], return new_jaxpr, used_consts, used_inputs +def has_effects(eqn: JaxprEqn) -> bool: + effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} + return bool(effs) + + @weakref_lru_cache def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], instantiate: tuple[bool, ...] @@ -1395,21 +1400,14 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], if type(x) is Var: env[x] = read(x) or b - def has_effects(eqn: JaxprEqn) -> bool: - effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} - return bool(effs) - new_eqns = [] map(write, jaxpr.outvars, used_outputs) for eqn in jaxpr.eqns[::-1]: used_outs = map(read, eqn.outvars) - if not any(used_outs) and not has_effects(eqn): - used_ins = [False] * len(eqn.invars) - else: - rule = dce_rules.get(eqn.primitive, _default_dce_rule) - used_ins, new_eqn = rule(used_outs, eqn) - if new_eqn is not None: - new_eqns.append(new_eqn) + rule = dce_rules.get(eqn.primitive, _default_dce_rule) + used_ins, new_eqn = rule(used_outs, eqn) + if new_eqn is not None: + new_eqns.append(new_eqn) map(write, eqn.invars, used_ins) used_inputs = map(read, jaxpr.invars) used_inputs = map(op.or_, instantiate, used_inputs) @@ -1433,7 +1431,9 @@ DCERule = Callable[[list[bool], JaxprEqn], def _default_dce_rule( used_outs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outs) and not has_effects(eqn): + return [False] * len(eqn.invars), None return [True] * len(eqn.invars), eqn dce_rules: dict[Primitive, DCERule] = {} @@ -1441,6 +1441,8 @@ dce_rules: dict[Primitive, DCERule] = {} def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) new_params = dict(eqn.params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(eqn.primitive) @@ -1454,6 +1456,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn [v for v, used in zip(eqn.outvars, used_outputs) if used], eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn + dce_rules[core.call_p] = dce_jaxpr_call_rule @@ -1465,8 +1468,10 @@ def _cached_closed_call_dce(jaxpr_, used_outputs: tuple[bool, ...] return core.ClosedJaxpr(new_jaxpr, consts), used_inputs def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: # TODO(mattjj): de-duplicate with above rule? + if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr_ = eqn.params['call_jaxpr'] closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs)) new_params = dict(eqn.params, call_jaxpr=closed_jaxpr) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9a17194d4..316fbc077 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1353,6 +1353,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval): def _pmap_dce_rule(used_outputs, eqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None axis_name = eqn.params["axis_name"] with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]): new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 6333638de..9e1f7e04c 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -642,7 +642,11 @@ def _ordered_unique(xs): return list(d.keys()) def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + closed_branches = eqn.params['branches'] branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index b5bb8658e..d15917b8b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -944,7 +944,9 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params) def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_xs = len(jaxpr.in_avals) - num_consts - num_carry diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index e2c50f2dc..f1844c7ba 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2326,6 +2326,10 @@ def _dce_jaxpr_pjit( def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + dced_jaxpr, used_inputs = _dce_jaxpr_pjit( eqn.params['jaxpr'], tuple(used_outputs)) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 7ddd3805b..3a9446862 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1660,6 +1660,8 @@ def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]: # TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] with core.extend_axis_env_nd(mesh.shape.items()): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 1aa3ebc67..dca438996 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -63,6 +63,7 @@ from jax._src.interpreters.partial_eval import ( debug_info_final as debug_info_final, def_trivial_padding as def_trivial_padding, forwarding_rules as forwarding_rules, + has_effects as has_effects, infer_lambda_input_type as infer_lambda_input_type, instantiate_const_at as instantiate_const_at, make_jaxpr_effects as make_jaxpr_effects,