This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
* add benchmarks for new remat eager performance, and some caching to make those
benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an
actionable message pointing to the new recommended approach involving static_argnums
* add the static_argnums parameter to both new and old remt
* update docstrings (and de-duplicate them to)
* add new tests, especially around caching and errors/warnings
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.
PiperOrigin-RevId: 406162068