mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Add the len(arg) to the error message for static_argnums
Helps reduce the confusion on what is considered an argnum. Ideally there should be static_argkwg PiperOrigin-RevId: 734591856
This commit is contained in:
parent
9f37b5197f
commit
ccf7278292
@ -355,7 +355,7 @@ def _remat_static_argnums(fun, static_argnums, args):
|
||||
raise ValueError("the `static_argnums` argument to `jax.checkpoint` / "
|
||||
"`jax.remat` can only take integer values greater than or "
|
||||
"equal to `-len(args)` and less than `len(args)`, but got "
|
||||
f"{static_argnums}")
|
||||
f"{static_argnums}, while `len(args)` = {len(args)}")
|
||||
|
||||
if not static_argnums:
|
||||
return fun, args
|
||||
|
Loading…
x
Reference in New Issue
Block a user