Make sure to DCE read effects

PiperOrigin-RevId: 738215055
This commit is contained in:
Sharad Vikram 2025-03-18 18:41:35 -07:00 committed by jax authors
parent 8c7a55ea82
commit 4d715753c4

View File

@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src.state.types import AbstractRef, ReadEffect
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
def has_effects(eqn: JaxprEqn) -> bool:
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
and not isinstance(e, ReadEffect)}
return bool(effs)