We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.)

In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
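The sketch below shows how this looks from the user's side. It assumes the present-day public name for the feature, `jax.checkpoint` (also exported as `jax.remat`), rather than whatever spelling this commit used internally. The functions `layer` and `loss` are hypothetical. Wrapping `layer` tells partial evaluation not to store its intermediates as residuals, so they are recomputed from the saved input during the backward pass. The gradients match the un-checkpointed version. The last part spells out reverse mode as the composition described above: `jax.linearize` performs forward mode plus partial evaluation, and `jax.linear_transpose` performs the transposition.

```python
import jax
import jax.numpy as jnp

def layer(x):
    # Hypothetical layer: without checkpointing, intermediates such as
    # sin(x) and cos(sin(x)) would be saved as residuals for the backward pass.
    return jnp.sin(jnp.cos(jnp.sin(x)))

# Checkpointed variant: only the input to `layer` is kept; its intermediates
# are recomputed from that input when the backward pass needs them.
layer_remat = jax.checkpoint(layer)

def loss(f, x):
    # Apply the layer twice and reduce to a scalar so jax.grad applies.
    return jnp.sum(f(f(x)))

x = jnp.linspace(0.0, 1.0, 8)

# Ordinary reverse mode vs. reverse mode with rematerialization.
g_plain = jax.grad(lambda v: loss(layer, v))(x)
g_remat = jax.grad(lambda v: loss(layer_remat, v))(x)

# Reverse mode written out explicitly: forward mode + partial evaluation
# (jax.linearize) followed by transposition (jax.linear_transpose).
y, f_jvp = jax.linearize(lambda v: loss(layer_remat, v), x)
f_vjp = jax.linear_transpose(f_jvp, x)
(g_manual,) = f_vjp(jnp.ones_like(y))  # one cotangent per primal input
```

Recomputation trades extra FLOPs for lower memory. Printing `jax.make_jaxpr(jax.grad(lambda v: loss(layer_remat, v)))(x)` is one way to see the forward operations repeated inside the backward computation.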