mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Make sure to DCE read effects
PiperOrigin-RevId: 738215055
This commit is contained in:
parent
8c7a55ea82
commit
4d715753c4
@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
|
||||
JaxprEqn, Primitive, ShapedArray, DShapedArray,
|
||||
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent, JaxprEqnContext)
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.state.types import AbstractRef, ReadEffect
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
|
||||
tree_flatten, tree_structure)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
|
||||
|
||||
|
||||
def has_effects(eqn: JaxprEqn) -> bool:
|
||||
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
|
||||
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
|
||||
and not isinstance(e, ReadEffect)}
|
||||
return bool(effs)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user