mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix bug in argnums_partial_except when static_argnums is unsorted.
This commit is contained in:
parent
310fafa9da
commit
9afc047bf0
@ -259,7 +259,7 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
|
||||
dyn_args = tuple(args[i] for i in dyn_argnums)
|
||||
|
||||
fixed_args = []
|
||||
for i in static_argnums:
|
||||
for i in sorted(static_argnums):
|
||||
# TODO(shoyer): set allow_invalid=True permanently after static_argnames.
|
||||
if allow_invalid and i >= len(args):
|
||||
continue
|
||||
|
@ -4313,6 +4313,21 @@ class APITest(jtu.JaxTestCase):
|
||||
for i in range(3): # Loop verifies we exercise both Python and C++ dispatch
|
||||
self.assertEqual(2 * i, g(2, i), msg=i)
|
||||
|
||||
def test_make_jaxpr_static_argnums_order(self):
|
||||
# https://github.com/jax-ml/jax/issues/28065
|
||||
def f(a, b, c):
|
||||
x = a + c
|
||||
y = b * c
|
||||
z = x - y
|
||||
return z
|
||||
|
||||
for static_argnums in [(1, 0), (0, 1)]:
|
||||
val = jax.jit(f, static_argnums=static_argnums)(1, 2, 3)
|
||||
self.assertEqual(val, -2)
|
||||
jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(1, 2, 3)
|
||||
self.assertEqual(jaxpr.eqns[0].invars[0].val, 1)
|
||||
self.assertEqual(jaxpr.eqns[1].invars[0].val, 2)
|
||||
|
||||
def test_fastpath_cache_confusion(self):
|
||||
# https://github.com/jax-ml/jax/issues/12542
|
||||
@jax.jit
|
||||
|
Loading…
x
Reference in New Issue
Block a user