jax authors e03f1d4fd1 Allows for splitting the transpose of a scan into a scan and a map.
This is an experimental feature exposed as an extra parameter: `scan(..., _split_transpose:bool)`.
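A minimal usage sketch, assuming a JAX release that includes this change. The keyword is private and experimental, so it may be renamed or removed. The sketch differentiates a small tanh recurrence twice, once with the default transpose and once with `_split_transpose=True`. The two sets of gradients are expected to agree up to floating-point error. The loss, the shapes, and the `make_loss` helper are illustrative only and are not part of the commit.

```python
import jax
import jax.numpy as jnp


def make_loss(split_transpose: bool):
    """Builds a toy loss over a tanh recurrence (illustrative, not from the commit)."""
    def loss(W, h0, xs):
        def body(h, x):
            h = jnp.tanh(W @ h + x)  # loop-carried state
            return h, h              # also emit h as a per-step output
        # `_split_transpose` is the experimental, underscore-prefixed flag;
        # assumed to be accepted by this JAX version's lax.scan.
        _, hs = jax.lax.scan(body, h0, xs, _split_transpose=split_transpose)
        return jnp.sum(hs)
    return loss


key_w, key_h, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
W = 0.5 * jax.random.normal(key_w, (4, 4))
h0 = jax.random.normal(key_h, (4,))
xs = jax.random.normal(key_x, (6, 4))

# Gradients w.r.t. the weights W and the initial carry h0, computed both ways.
grads_default = jax.grad(make_loss(False), argnums=(0, 1))(W, h0, xs)
grads_split = jax.grad(make_loss(True), argnums=(0, 1))(W, h0, xs)
```

The flag only changes how the backward pass is structured, not what it computes. Comparing it against the default is therefore a reasonable sanity check before relying on it.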

If the parameter is true, the transpose of scan generates not just 2 scans
(the forward scan and the transpose of the linearized forward scan) but 3:
(i) the forward scan (as before); (ii) a transposed scan that computes only the
loop-carried state required for back-propagation and saves the other
intermediate gradients; (iii) a scan (actually a map) that uses the saved
activation gradients and the original residuals to compute the remaining
gradients.
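The three phases can be illustrated without JAX on a hypothetical scalar recurrence z_t = w * h_{t-1} + x_t, h_t = tanh(z_t), with loss L = sum_t h_t. This is a hand-written sketch of the decomposition, not JAX's implementation, and all function names are made up for the example. Phase (ii) carries only dL/dz_t backward and records it at every step. Phase (iii) combines those records with the forward residuals h_{t-1} to produce dL/dw and dL/dx_t, using independent per-step work followed by a sum.

```python
import math
from typing import List, Tuple


def forward(w: float, h0: float, xs: List[float]) -> Tuple[List[float], List[float]]:
    """Phase (i): forward scan. Returns outputs ys and residuals h_{t-1}."""
    h = h0
    ys, h_prev = [], []
    for x in xs:
        h_prev.append(h)         # residual needed later for the weight gradient
        h = math.tanh(w * h + x)
        ys.append(h)
    return ys, h_prev


def loss(w: float, h0: float, xs: List[float]) -> float:
    ys, _ = forward(w, h0, xs)
    return sum(ys)


def carry_only_transpose(w: float, ys: List[float]) -> Tuple[List[float], float]:
    """Phase (ii): reverse scan propagating only the loop-carried cotangent.

    Saves g_t = dL/dz_t per step instead of accumulating dL/dw in the loop.
    """
    c = 0.0                      # dL/dz_{t+1}; zero past the last step
    gs = [0.0] * len(ys)
    for t in reversed(range(len(ys))):
        a = 1.0 + w * c          # dL/dh_t: direct term from sum(ys) plus future path
        c = a * (1.0 - ys[t] ** 2)  # multiply by tanh'(z_t) = 1 - h_t**2
        gs[t] = c
    dh0 = w * c                  # z_1 = w * h0 + x_1, so dL/dh0 = w * g_1
    return gs, dh0


def weight_grad_map(gs: List[float], h_prev: List[float]) -> Tuple[float, List[float]]:
    """Phase (iii): per-step work with no carried state (a map), then a sum."""
    per_step = [g * h for g, h in zip(gs, h_prev)]  # dz_t/dw = h_{t-1}
    dxs = list(gs)                                   # dz_t/dx_t = 1
    return sum(per_step), dxs


def split_grad(w: float, h0: float, xs: List[float]):
    ys, h_prev = forward(w, h0, xs)
    gs, dh0 = carry_only_transpose(w, ys)
    dw, dxs = weight_grad_map(gs, h_prev)
    return dw, dh0, dxs


if __name__ == "__main__":
    w, h0, xs = 0.7, 0.3, [0.5, -1.0, 0.25]
    dw, dh0, dxs = split_grad(w, h0, xs)
    eps = 1e-6
    fd_w = (loss(w + eps, h0, xs) - loss(w - eps, h0, xs)) / (2 * eps)
    print(abs(dw - fd_w) < 1e-6)
```

Phase (iii) has no loop-carried dependency, so each step's contribution could in principle be computed in parallel. Only phase (ii) has to remain sequential.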

Warning: this feature is somewhat experimental and may evolve or be rolled back.
PiperOrigin-RevId: 619991098
2024-03-28 10:54:50 -07:00