mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

This commit involves a few things, all united in being about landing the new remat (aka new checkpoint) implementation:

* add benchmarks for new remat eager performance, and some caching to make those benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an actionable message pointing to the new recommended approach involving `static_argnums`
* add the `static_argnums` parameter to both new and old remat
* update docstrings (and de-duplicate them too)
* add new tests, especially around caching and errors/warnings
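The migration that the new warning points to can be sketched as follows. With old remat, `concrete=True` was the way to allow Python control flow that depends on an argument's value; the replacement is to mark that argument as static with `static_argnums`. This is a minimal sketch against the public `jax.checkpoint` API as it appears in later JAX releases (an assumption: details at this exact commit may differ), and the function `layer` and its `use_sin` flag are hypothetical names chosen for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example function. The `if use_sin:` branch needs a concrete
# Python value at trace time, which old remat handled via `concrete=True`.
# Here argument 1 is instead declared static, so jax.checkpoint treats it
# as a trace-time constant rather than a traced array.
@partial(jax.checkpoint, static_argnums=(1,))
def layer(x, use_sin):
    if use_sin:
        return jnp.sin(x)
    return jnp.cos(x)


def loss(x):
    # `True` is a plain (hashable) Python bool in the static position.
    return layer(x, True).sum()


x = jnp.arange(3.0)
# d/dx sum(sin(x)) = cos(x); the checkpointed layer is rematerialized
# on the backward pass instead of storing its intermediates.
grads = jax.grad(loss)(x)
```

The static argument should be a hashable Python value (here a `bool`), not a traced array; only the non-static arguments are differentiated.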