Yash Katariya c41d271175 Add memories support to remat.
This PR adds basic support to remat for transferring intermediates (activations) to a destination memory in the forward pass. Currently JAX only supports the host memory kind, but the API allows transferring to other memories too. Remat automatically loads the residuals back to the source memory in the backward pass.

Introduce two singletons, `Recompute` and `Saveable`, and a NamedTuple, `Offloadable`, that each policy can return. Currently policies return a bool: True means the value is saveable, False means it is recomputed on the backward pass. This change is backwards compatible, i.e. policies can still return a bool.

A very basic offloadable policy can look like this:

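```
# Import path assumed; the original snippet does not show it.
from jax import ad_checkpoint

def policy(prim, *avals, **params):
  # Offload every residual from TPU HBM to unpinned host memory.
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```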
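
The return-value contract described above (a bool, one of the two singletons, or an `Offloadable`) can be illustrated with a small pure-Python sketch. It does not import JAX: the `Recompute`, `Saveable` and `Offloadable` objects below are hypothetical stand-ins that only mirror the names from this change, and `normalize` is an illustrative helper, not a function from the JAX codebase. It only shows how a bool returned by an older policy could be mapped onto the new values.

```python
from typing import NamedTuple


class _Singleton:
  """Stand-in for a named sentinel object (not JAX's real class)."""

  def __init__(self, name: str):
    self._name = name

  def __repr__(self) -> str:
    return self._name


# Hypothetical stand-ins for the names introduced by this change.
Recompute = _Singleton("Recompute")
Saveable = _Singleton("Saveable")


class Offloadable(NamedTuple):
  src: str  # memory kind the residual lives in during the forward pass
  dst: str  # memory kind the residual is transferred to


def normalize(result):
  """Map a legacy bool onto the singletons; pass other valid results through."""
  if isinstance(result, bool):
    return Saveable if result else Recompute
  if result is Saveable or result is Recompute or isinstance(result, Offloadable):
    return result
  raise TypeError(f"unsupported policy result: {result!r}")


def legacy_policy(prim, *avals, **params):
  return True  # old style: True means "save this residual"


def offload_policy(prim, *avals, **params):
  return Offloadable(src="tpu_hbm", dst="unpinned_host")
```

Under these assumptions, `normalize(legacy_policy(...))` yields the `Saveable` stand-in, while `normalize(offload_policy(...))` passes the `Offloadable` tuple through unchanged, so both old and new policy styles go through one code path.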

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
2023-09-12 20:50:05 -07:00