mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
update checkpoint
attributes according to functools.wraps
This updates the signature in addition to `__doc__`, and that gets picked up by generated API docs.
This commit is contained in:
parent
e81578a9fa
commit
7d494a3852
@ -3050,6 +3050,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
return tree_unflatten(out_tree(), out)
|
||||
|
||||
|
||||
@functools.wraps(new_checkpoint) # config.jax_new_checkpoint is True by default
|
||||
def checkpoint(fun: Callable, *,
|
||||
concrete: bool = False,
|
||||
prevent_cse: bool = True,
|
||||
@ -3121,7 +3122,6 @@ def checkpoint(fun: Callable, *,
|
||||
differentiated=False, policy=policy)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
return remat_f
|
||||
checkpoint.__doc__ = new_checkpoint.__doc__
|
||||
remat = checkpoint # type: ignore
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user