Merge pull request #25798 from gnecula:fix_fori_error

PiperOrigin-RevId: 715258789
This commit is contained in:
jax authors 2025-01-14 00:01:30 -08:00
commit 4f2f5fa53a
3 changed files with 36 additions and 6 deletions

View File

@ -608,8 +608,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
except (ValueError, TypeError):
return None
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
sourceinfo = fun_sourceinfo(wrapped)
if sourceinfo is not None:
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str | None:
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
fun = fun.func
fun = inspect.unwrap(fun)

View File

@ -25,6 +25,7 @@ import weakref
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
@ -1965,18 +1966,19 @@ def _fori_cond_fun(loop_carry):
@weakref_lru_cache
def _fori_body_fun(body_fun):
body_fun = weakref.ref(body_fun)
body_fun_ref = weakref.ref(body_fun)
def while_body_fun(loop_carry):
i, upper, x = loop_carry
return lax.add(i, lax._const(i, 1)), upper, body_fun()(i, x)
return lax.add(i, lax._const(i, 1)), upper, body_fun_ref()(i, x)
return while_body_fun
@weakref_lru_cache
def _fori_scan_body_fun(body_fun):
body_fun = weakref.ref(body_fun)
body_fun_ref = weakref.ref(body_fun)
def scanned_fun(loop_carry, _):
i, x = loop_carry
return (i + 1, body_fun()(i, x)), None
return (i + 1, body_fun_ref()(i, x)), None
return scanned_fun
@api_boundary
@ -2085,8 +2087,10 @@ def fori_loop(lower, upper, body_fun, init_val,
# non-jit implementation of scan does not support length=0
return init_val
scan_body = _fori_scan_body_fun(body_fun)
api_util.save_wrapped_fun_sourceinfo(scan_body, body_fun)
(_, result), _ = scan(
_fori_scan_body_fun(body_fun),
scan_body,
(lower_, init_val),
None,
length=length,
@ -2101,7 +2105,9 @@ def fori_loop(lower, upper, body_fun, init_val,
lower = lax.convert_element_type(lower, dtype) # type: ignore
if upper_dtype != dtype:
upper = lax.convert_element_type(upper, dtype) # type: ignore
_, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
while_body_fun = _fori_body_fun(body_fun)
api_util.save_wrapped_fun_sourceinfo(while_body_fun, body_fun)
_, _, result = while_loop(_fori_cond_fun, while_body_fun,
(lower, upper, init_val))
return result

View File

@ -589,6 +589,21 @@ class LaxControlFlowTest(jtu.JaxTestCase):
init = jnp.float32(10)
self.assertEqual(fori_loop_with_static_upper_and_lower(init), init)
def test_fori_error_points_to_user_code(self):
# See https://github.com/jax-ml/jax/issues/23637
def my_body(_, c):
return bool(c)
with self.assertRaisesRegex(
jax.errors.TracerBoolConversionError,
"occurred while tracing the function my_body at .*control_flow_test.py.* for scan"):
jax.lax.fori_loop(0, 5, my_body, 3.)
with self.assertRaisesRegex(
jax.errors.TracerBoolConversionError,
"occurred while tracing the function my_body at .*control_flow_test.py.* for while_loop"):
jax.jit(lambda ubound: jax.lax.fori_loop(0, ubound, my_body, 3.))(5)
def testForiLoopBatched(self):
def body_fun(i, loop_carry):
x, y = loop_carry