A callback under ad_checkpoint.checkpoint will be invoked twice when taking the gradient: once during the forward pass and once again during the backward pass when the residuals for the forward pass are rematerialized.
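A minimal sketch of how one might observe this, assuming `jax.debug.callback` as a stand-in for whichever callback mechanism the note refers to (the excerpt does not say which). The `record` helper and the `calls` list are hypothetical names introduced only for illustration. The script counts host-side invocations while differentiating a checkpointed function; the count is printed rather than asserted, since it can depend on the JAX version and on whether the function is also jitted.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint

calls = []  # hypothetical host-side log of callback invocations


def record(value):
    # Runs on the host each time the callback fires.
    calls.append(float(value))


@checkpoint
def f(x):
    y = jnp.sin(x)
    # The callback sits inside the rematerialized region.
    jax.debug.callback(record, y)
    return y * y


g = jax.grad(f)(1.0)
jax.effects_barrier()  # wait for any pending callbacks before inspecting `calls`

print("gradient:", g)
print("callback invocations:", len(calls))
```

If the callback has side effects that must happen only once per step (logging, metrics, I/O), one option is to place it outside the checkpointed function.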
jax
.
fast_path
contiguous_submeshes
mesh_utils.create_device_mesh()
jax.experimental
jax.example_libraries