This adds an experimental feature, exposed as an extra parameter: `scan(..., _split_transpose: bool)`.

If the parameter is true, transposing a scan generates three scans instead of two. Without it, you get the forward scan and the transpose of the linearized forward scan. With it, you get:

(i) the forward scan (as before);
(ii) a transposed scan that computes only the loop-carried state required for back-propagation, and saves the other intermediate gradients for the next pass instead of consuming them;
(iii) a scan (actually a map) that uses the saved activation gradients and the original residuals to compute the remaining gradients.

Warning: this feature is experimental and may evolve or be rolled back.

PiperOrigin-RevId: 619991098
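The split of work across passes (ii) and (iii) can be illustrated without JAX on a toy scalar recurrence `h_t = w * h_{t-1} + u * x_t`, differentiated with respect to `w`, `u`, `h0` and each `x_t`. The sketch below is a minimal pure-Python illustration under that assumption. The recurrence and all function names (`forward`, `transpose_carry_only`, `remaining_grads`, `fused_transpose`) are hypothetical and are not JAX's implementation. It only mirrors the decomposition described above and compares the result against a single fused reverse pass.

```python
from typing import List, Tuple

Residual = Tuple[float, float]  # (h_{t-1}, x_t) saved by the forward pass


def forward(w: float, u: float, h0: float, xs: List[float]) -> Tuple[float, List[Residual]]:
    """Pass (i): forward scan h_t = w*h_{t-1} + u*x_t, saving residuals."""
    h = h0
    residuals: List[Residual] = []
    for x in xs:
        residuals.append((h, x))
        h = w * h + u * x
    return h, residuals


def transpose_carry_only(w: float, ct_hT: float, num_steps: int) -> Tuple[float, List[float]]:
    """Pass (ii): reverse scan that propagates only the loop-carried cotangent.

    The cotangent of each step's output h_t is saved for pass (iii)
    instead of being used here to form the other gradients.
    """
    ct_h = ct_hT
    saved = [0.0] * num_steps
    for t in reversed(range(num_steps)):
        saved[t] = ct_h   # cotangent of h_t, the output of step t
        ct_h = w * ct_h   # cotangent of h_{t-1}: d(h_t)/d(h_{t-1}) = w
    return ct_h, saved    # ct_h is now the cotangent of h0


def remaining_grads(u: float, residuals: List[Residual], saved: List[float]) -> Tuple[float, float, List[float]]:
    """Pass (iii): a map over steps; no step depends on another.

    Per-step gradients come from the saved cotangents plus the original
    residuals; gradients of loop-invariant parameters are then summed.
    """
    per_step = [(ct * h_prev, ct * x, ct * u) for (h_prev, x), ct in zip(residuals, saved)]
    dw = sum(g[0] for g in per_step)
    du = sum(g[1] for g in per_step)
    dxs = [g[2] for g in per_step]
    return dw, du, dxs


def fused_transpose(w: float, u: float, residuals: List[Residual], ct_hT: float) -> Tuple[float, float, float, List[float]]:
    """Reference: one reverse scan that computes every gradient at once."""
    ct_h, dw, du = ct_hT, 0.0, 0.0
    dxs = [0.0] * len(residuals)
    for t in reversed(range(len(residuals))):
        h_prev, x = residuals[t]
        dw += ct_h * h_prev
        du += ct_h * x
        dxs[t] = ct_h * u
        ct_h = w * ct_h
    return ct_h, dw, du, dxs


if __name__ == "__main__":
    w, u, h0, xs = 0.5, 2.0, 1.0, [1.0, 2.0, 3.0]
    hT, res = forward(w, u, h0, xs)
    dh0, saved = transpose_carry_only(w, 1.0, len(res))
    dw, du, dxs = remaining_grads(u, res, saved)
    print(hT, dh0, dw, du, dxs)
    print(fused_transpose(w, u, res, 1.0))
```

Both routes yield the same gradients. The difference is that pass (ii) carries only what the recurrence needs, while pass (iii) has no loop-carried dependency, so it can run as a map. In JAX itself the option is requested by passing `_split_transpose=True` to `scan`.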