From 4d715753c45fbc09e02cbb1a5e254e364ee9b896 Mon Sep 17 00:00:00 2001 From: Sharad Vikram <sharadmv@google.com> Date: Tue, 18 Mar 2025 18:41:35 -0700 Subject: [PATCH] Make sure to DCE read effects PiperOrigin-RevId: 738215055 --- jax/_src/interpreters/partial_eval.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 95c09ae94..07c516fd9 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) -from jax._src.state.types import AbstractRef +from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, @@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], def has_effects(eqn: JaxprEqn) -> bool: - effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} + effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect) + and not isinstance(e, ReadEffect)} return bool(effs)