Fix bug in argnums_partial_except when static_argnums is unsorted.

This commit is contained in:
Dan Foreman-Mackey 2025-04-16 16:18:10 -04:00
parent 310fafa9da
commit 9afc047bf0
2 changed files with 16 additions and 1 deletions

View File

@ -259,7 +259,7 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
dyn_args = tuple(args[i] for i in dyn_argnums)
fixed_args = []
for i in static_argnums:
for i in sorted(static_argnums):
# TODO(shoyer): set allow_invalid=True permanently after static_argnames.
if allow_invalid and i >= len(args):
continue

View File

@ -4313,6 +4313,21 @@ class APITest(jtu.JaxTestCase):
for i in range(3): # Loop verifies we exercise both Python and C++ dispatch
self.assertEqual(2 * i, g(2, i), msg=i)
def test_make_jaxpr_static_argnums_order(self):
# https://github.com/jax-ml/jax/issues/28065
def f(a, b, c):
x = a + c
y = b * c
z = x - y
return z
for static_argnums in [(1, 0), (0, 1)]:
val = jax.jit(f, static_argnums=static_argnums)(1, 2, 3)
self.assertEqual(val, -2)
jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(1, 2, 3)
self.assertEqual(jaxpr.eqns[0].invars[0].val, 1)
self.assertEqual(jaxpr.eqns[1].invars[0].val, 2)
def test_fastpath_cache_confusion(self):
# https://github.com/jax-ml/jax/issues/12542
@jax.jit