mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #25798 from gnecula:fix_fori_error
PiperOrigin-RevId: 715258789
This commit is contained in:
commit
4f2f5fa53a
@ -608,8 +608,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
|
||||
# Prefer this to functools.wraps because it does not create a reference to
|
||||
# the wrapped function.
|
||||
sourceinfo = fun_sourceinfo(wrapped)
|
||||
if sourceinfo is not None:
|
||||
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
|
||||
|
||||
# TODO(mattjj): make this function internal to this module
|
||||
def fun_sourceinfo(fun: Callable) -> str | None:
|
||||
res = getattr(fun, "__fun_sourceinfo__", None)
|
||||
if res is not None: return res
|
||||
while isinstance(fun, partial):
|
||||
fun = fun.func
|
||||
fun = inspect.unwrap(fun)
|
||||
|
@ -25,6 +25,7 @@ import weakref
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -1965,18 +1966,19 @@ def _fori_cond_fun(loop_carry):
|
||||
|
||||
@weakref_lru_cache
|
||||
def _fori_body_fun(body_fun):
|
||||
body_fun = weakref.ref(body_fun)
|
||||
body_fun_ref = weakref.ref(body_fun)
|
||||
|
||||
def while_body_fun(loop_carry):
|
||||
i, upper, x = loop_carry
|
||||
return lax.add(i, lax._const(i, 1)), upper, body_fun()(i, x)
|
||||
return lax.add(i, lax._const(i, 1)), upper, body_fun_ref()(i, x)
|
||||
return while_body_fun
|
||||
|
||||
@weakref_lru_cache
|
||||
def _fori_scan_body_fun(body_fun):
|
||||
body_fun = weakref.ref(body_fun)
|
||||
body_fun_ref = weakref.ref(body_fun)
|
||||
def scanned_fun(loop_carry, _):
|
||||
i, x = loop_carry
|
||||
return (i + 1, body_fun()(i, x)), None
|
||||
return (i + 1, body_fun_ref()(i, x)), None
|
||||
return scanned_fun
|
||||
|
||||
@api_boundary
|
||||
@ -2085,8 +2087,10 @@ def fori_loop(lower, upper, body_fun, init_val,
|
||||
# non-jit implementation of scan does not support length=0
|
||||
return init_val
|
||||
|
||||
scan_body = _fori_scan_body_fun(body_fun)
|
||||
api_util.save_wrapped_fun_sourceinfo(scan_body, body_fun)
|
||||
(_, result), _ = scan(
|
||||
_fori_scan_body_fun(body_fun),
|
||||
scan_body,
|
||||
(lower_, init_val),
|
||||
None,
|
||||
length=length,
|
||||
@ -2101,7 +2105,9 @@ def fori_loop(lower, upper, body_fun, init_val,
|
||||
lower = lax.convert_element_type(lower, dtype) # type: ignore
|
||||
if upper_dtype != dtype:
|
||||
upper = lax.convert_element_type(upper, dtype) # type: ignore
|
||||
_, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
|
||||
while_body_fun = _fori_body_fun(body_fun)
|
||||
api_util.save_wrapped_fun_sourceinfo(while_body_fun, body_fun)
|
||||
_, _, result = while_loop(_fori_cond_fun, while_body_fun,
|
||||
(lower, upper, init_val))
|
||||
return result
|
||||
|
||||
|
@ -589,6 +589,21 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
init = jnp.float32(10)
|
||||
self.assertEqual(fori_loop_with_static_upper_and_lower(init), init)
|
||||
|
||||
def test_fori_error_points_to_user_code(self):
|
||||
# See https://github.com/jax-ml/jax/issues/23637
|
||||
def my_body(_, c):
|
||||
return bool(c)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
jax.errors.TracerBoolConversionError,
|
||||
"occurred while tracing the function my_body at .*control_flow_test.py.* for scan"):
|
||||
jax.lax.fori_loop(0, 5, my_body, 3.)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
jax.errors.TracerBoolConversionError,
|
||||
"occurred while tracing the function my_body at .*control_flow_test.py.* for while_loop"):
|
||||
jax.jit(lambda ubound: jax.lax.fori_loop(0, ubound, my_body, 3.))(5)
|
||||
|
||||
def testForiLoopBatched(self):
|
||||
def body_fun(i, loop_carry):
|
||||
x, y = loop_carry
|
||||
|
Loading…
x
Reference in New Issue
Block a user