1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

Add the len(arg) to the error message for static_argnums

Helps reduce the confusion on what is considered an argnum.
Ideally there should be static_argkwg

PiperOrigin-RevId: 734591856
This commit is contained in:
jax authors 2025-03-07 09:49:07 -08:00
parent 9f37b5197f
commit ccf7278292

@ -355,7 +355,7 @@ def _remat_static_argnums(fun, static_argnums, args):
raise ValueError("the `static_argnums` argument to `jax.checkpoint` / "
"`jax.remat` can only take integer values greater than or "
"equal to `-len(args)` and less than `len(args)`, but got "
f"{static_argnums}")
f"{static_argnums}, while `len(args)` = {len(args)}")
if not static_argnums:
return fun, args