diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 95c09ae94..07c516fd9 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) -from jax._src.state.types import AbstractRef +from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, @@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], def has_effects(eqn: JaxprEqn) -> bool: - effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} + effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect) + and not isinstance(e, ReadEffect)} return bool(effs)