update checkpoint attributes according to functools.wraps

This updates the signature in addition to `__doc__`, and that gets
picked up by generated API docs.
This commit is contained in:
Roy Frostig 2022-08-10 13:04:24 -07:00
parent e81578a9fa
commit 7d494a3852

View File

@ -3050,6 +3050,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
return tree_unflatten(out_tree(), out)
@functools.wraps(new_checkpoint) # config.jax_new_checkpoint is True by default
def checkpoint(fun: Callable, *,
concrete: bool = False,
prevent_cse: bool = True,
@ -3121,7 +3122,6 @@ def checkpoint(fun: Callable, *,
differentiated=False, policy=policy)
return tree_unflatten(out_tree(), out_flat)
return remat_f
checkpoint.__doc__ = new_checkpoint.__doc__
remat = checkpoint # type: ignore