When we do run_scoped[jaxpr, R1, R2], we cannot assume that the references corresponding to R1 and R2 can be safely discharged. Sometimes they can (e.g. an Accumulator), but sometimes they cannot (e.g. SMEM scratch). Whether to discharge them should be up to the lowering rule. This further means that, during lowering, there is no guarantee that those references will not be used or returned by nested scoped blocks, so we also remove that check.

PiperOrigin-RevId: 722137352
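The policy change above can be illustrated with a small toy model. This is not JAX's actual code. MemorySpace, ScopedRef, default_should_discharge and lower_run_scoped are hypothetical names invented for this sketch. It assumes each scoped reference carries a label telling the lowering rule what kind of allocation it is. The point it shows is that the generic run_scoped machinery no longer discharges R1 and R2 wholesale. Instead, each lowering rule supplies its own predicate and decides per reference.

```python
from dataclasses import dataclass
from enum import Enum, auto


class MemorySpace(Enum):
    # Hypothetical labels for the two cases named in the commit message.
    ACCUMULATOR = auto()
    SMEM = auto()


@dataclass(frozen=True)
class ScopedRef:
    # Toy stand-in for a reference allocated by run_scoped (e.g. R1, R2).
    name: str
    space: MemorySpace


def default_should_discharge(ref: ScopedRef) -> bool:
    # Hypothetical policy: an accumulator may be turned into a plain value,
    # while SMEM scratch must remain a real allocation.
    return ref.space is MemorySpace.ACCUMULATOR


def lower_run_scoped(refs, should_discharge=default_should_discharge):
    """Toy lowering rule: split scoped refs into discharged and kept ones.

    The decision is delegated to the `should_discharge` predicate chosen by
    the lowering rule, instead of being made once for every backend.
    """
    discharged = [r.name for r in refs if should_discharge(r)]
    kept = [r.name for r in refs if not should_discharge(r)]
    return discharged, kept


if __name__ == "__main__":
    r1 = ScopedRef("R1", MemorySpace.ACCUMULATOR)
    r2 = ScopedRef("R2", MemorySpace.SMEM)
    print(lower_run_scoped([r1, r2]))  # (['R1'], ['R2'])
```

A lowering rule that cannot discharge anything would pass should_discharge=lambda r: False and keep every allocation. Because kept references stay live as real allocations, the toy does not assume they never escape into nested scopes, mirroring the removed check.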