rocm_jax/jax/experimental
George Necula 3021d3e2e2 [hcb] Add support for remat2 to host_callback
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
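The double invocation follows from how rematerialization works. Under checkpointing, the forward pass drops its residuals to save memory, and the backward pass recomputes them by running the forward computation again. Any side effect inside that computation therefore fires a second time. The toy model below shows this in plain Python. It is only a sketch, not JAX's implementation, and it does not use the real host_callback API. `fake_host_callback`, `fwd`, `bwd`, `plain_vjp` and `checkpointed_vjp` are all hypothetical names chosen for illustration. The function being differentiated is f(x) = 3 * x**2, with a callback tapping the intermediate x**2.

```python
from typing import List, Tuple

# Records every invocation of the hypothetical host callback.
callback_log: List[float] = []


def fake_host_callback(value: float) -> float:
    """Hypothetical stand-in for a host callback: logs its argument and returns it."""
    callback_log.append(value)
    return value


def fwd(x: float) -> Tuple[float, Tuple[float, ...]]:
    """Forward pass of f(x) = 3 * x**2 that taps the intermediate x**2."""
    sq = fake_host_callback(x * x)
    return 3.0 * sq, (x,)  # output, plus the residuals the backward pass needs


def bwd(residuals: Tuple[float, ...], ct: float) -> float:
    """Backward pass: d/dx (3 * x**2) = 6 * x, scaled by the output cotangent."""
    (x,) = residuals
    return ct * 6.0 * x


def plain_vjp(x: float):
    """Ordinary reverse mode: residuals are kept alive until the backward pass."""
    out, residuals = fwd(x)
    return out, lambda ct: bwd(residuals, ct)


def checkpointed_vjp(x: float):
    """Checkpointed reverse mode: residuals are discarded, then recomputed."""
    out, _ = fwd(x)  # forward pass; residuals are dropped to save memory

    def pullback(ct: float) -> float:
        # Rematerialize the residuals by re-running fwd, so the callback fires again.
        _, residuals = fwd(x)
        return bwd(residuals, ct)

    return out, pullback


def grad(vjp, x: float) -> float:
    """Gradient of a scalar function via its vector-Jacobian product."""
    _, pullback = vjp(x)
    return pullback(1.0)


if __name__ == "__main__":
    callback_log.clear()
    print(grad(plain_vjp, 2.0), len(callback_log))  # 12.0 1
    callback_log.clear()
    print(grad(checkpointed_vjp, 2.0), len(callback_log))  # 12.0 2
```

Both variants return the same gradient. The checkpointed variant trades memory for recomputation, and as a result the callback runs once in the forward pass and once more when the residuals are rematerialized.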
2021-12-15 10:32:15 +02:00