From 4d715753c45fbc09e02cbb1a5e254e364ee9b896 Mon Sep 17 00:00:00 2001
From: Sharad Vikram <sharadmv@google.com>
Date: Tue, 18 Mar 2025 18:41:35 -0700
Subject: [PATCH] Make sure to DCE read effects

PiperOrigin-RevId: 738215055
---
 jax/_src/interpreters/partial_eval.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 95c09ae94..07c516fd9 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
                            JaxprEqn, Primitive, ShapedArray, DShapedArray,
                            mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
                            InputType, OutputType, get_referent, JaxprEqnContext)
-from jax._src.state.types import AbstractRef
+from jax._src.state.types import AbstractRef, ReadEffect
 from jax._src.tree_util import (PyTreeDef, treedef_tuple,
                                 tree_flatten, tree_structure)
 from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
 
 
 def has_effects(eqn: JaxprEqn) -> bool:
-  effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
+  effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
+          and not isinstance(e, ReadEffect)}
   return bool(effs)