Merge pull request #10262 from jakevdp:while-loop-error

PiperOrigin-RevId: 441527861
This commit is contained in:
jax authors 2022-04-13 11:08:25 -07:00
commit e5f19138d6
2 changed files with 29 additions and 0 deletions

View File

@ -193,6 +193,8 @@ def fori_loop(lower, upper, body_fun, init_val):
Returns:
Loop value from the final iteration, of type ``a``.
"""
if not callable(body_fun):
raise TypeError("lax.fori_loop: body_fun argument should be callable.")
# TODO(phawkins): perhaps do more type checking here, better error messages.
lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
@ -275,6 +277,8 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
Returns:
The output from the final iteration of body_fun, of type ``a``.
"""
if not (callable(body_fun) and callable(cond_fun)):
raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
if config.jax_disable_jit:
try:
val = init_val
@ -753,6 +757,8 @@ def switch(index, branches: Sequence[Callable], *operands,
Value (B) of ``branch(*operands)`` for the branch that was selected based
on ``index``.
"""
if not all(callable(branch) for branch in branches):
raise TypeError("lax.switch: branches argument should be a sequence of callables.")
if operand is not _no_operand_sentinel:
if operands:
raise TypeError("if 'operand' keyword is passed then no positional "
@ -837,6 +843,8 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
depending on the value of ``pred``. The type can be a scalar, array, or any
pytree (nested Python tuple/list/dict) thereof.
"""
if not (callable(true_fun) and callable(false_fun)):
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
if operand is not _no_operand_sentinel:
if operands:
raise TypeError("if 'operand' keyword is passed then no positional "
@ -929,6 +937,8 @@ def _cond_with_per_branch_args(pred,
Pred has to be a scalar type, collection types (list, tuple) are not supported
"""
if not (callable(true_fun) and callable(false_fun)):
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
return _cond(pred,
lambda op: true_fun(op[0]),
lambda op: false_fun(op[1]),
@ -1462,6 +1472,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
loop carry value and the second element represents the stacked outputs of
the second output of ``f`` when scanned over the leading axis of the inputs.
"""
if not callable(f):
raise TypeError("lax.scan: f argument should be a callable.")
xs_flat, xs_tree = tree_flatten(xs)
try:
@ -2771,6 +2783,8 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
University.
"""
if not callable(fn):
raise TypeError("lax.associative_scan: fn argument should be callable.")
elems_flat, tree = tree_flatten(elems)
if reverse:

View File

@ -96,6 +96,21 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jax._src.lax.control_flow._initial_style_jaxpr.cache_clear()
jax._src.lax.control_flow._initial_style_jaxprs_with_common_consts.cache_clear()
def testCallableErrors(self):
not_callable = 42
with self.assertRaisesRegex(TypeError, "lax.fori_loop.*callable.*"):
lax.fori_loop(0, 1, not_callable, 0)
with self.assertRaisesRegex(TypeError, "lax.while_loop.*callable.*"):
lax.while_loop(not_callable, not_callable, 0)
with self.assertRaisesRegex(TypeError, "lax.switch:.*callable.*"):
lax.switch(0, [not_callable])
with self.assertRaisesRegex(TypeError, "lax.cond.*callable.*"):
lax.cond(0, not_callable, not_callable)
with self.assertRaisesRegex(TypeError, "lax.scan.*callable.*"):
lax.scan(not_callable, 0, 1)
with self.assertRaisesRegex(TypeError, "lax.associative_scan.*callable.*"):
lax.associative_scan(not_callable, 0)
def testWhileWithTuple(self):
limit = 10