mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10262 from jakevdp:while-loop-error
PiperOrigin-RevId: 441527861
This commit is contained in:
commit
e5f19138d6
@ -193,6 +193,8 @@ def fori_loop(lower, upper, body_fun, init_val):
|
||||
Returns:
|
||||
Loop value from the final iteration, of type ``a``.
|
||||
"""
|
||||
if not callable(body_fun):
|
||||
raise TypeError("lax.fori_loop: body_fun argument should be callable.")
|
||||
# TODO(phawkins): perhaps do more type checking here, better error messages.
|
||||
lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
|
||||
upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
|
||||
@ -275,6 +277,8 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
Returns:
|
||||
The output from the final iteration of body_fun, of type ``a``.
|
||||
"""
|
||||
if not (callable(body_fun) and callable(cond_fun)):
|
||||
raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
|
||||
if config.jax_disable_jit:
|
||||
try:
|
||||
val = init_val
|
||||
@ -753,6 +757,8 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
Value (B) of ``branch(*operands)`` for the branch that was selected based
|
||||
on ``index``.
|
||||
"""
|
||||
if not all(callable(branch) for branch in branches):
|
||||
raise TypeError("lax.switch: branches argument should be a sequence of callables.")
|
||||
if operand is not _no_operand_sentinel:
|
||||
if operands:
|
||||
raise TypeError("if 'operand' keyword is passed then no positional "
|
||||
@ -837,6 +843,8 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
depending on the value of ``pred``. The type can be a scalar, array, or any
|
||||
pytree (nested Python tuple/list/dict) thereof.
|
||||
"""
|
||||
if not (callable(true_fun) and callable(false_fun)):
|
||||
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
|
||||
if operand is not _no_operand_sentinel:
|
||||
if operands:
|
||||
raise TypeError("if 'operand' keyword is passed then no positional "
|
||||
@ -929,6 +937,8 @@ def _cond_with_per_branch_args(pred,
|
||||
|
||||
Pred has to be a scalar type, collection types (list, tuple) are not supported
|
||||
"""
|
||||
if not (callable(true_fun) and callable(false_fun)):
|
||||
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
|
||||
return _cond(pred,
|
||||
lambda op: true_fun(op[0]),
|
||||
lambda op: false_fun(op[1]),
|
||||
@ -1462,6 +1472,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
loop carry value and the second element represents the stacked outputs of
|
||||
the second output of ``f`` when scanned over the leading axis of the inputs.
|
||||
"""
|
||||
if not callable(f):
|
||||
raise TypeError("lax.scan: f argument should be a callable.")
|
||||
xs_flat, xs_tree = tree_flatten(xs)
|
||||
|
||||
try:
|
||||
@ -2771,6 +2783,8 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
|
||||
University.
|
||||
"""
|
||||
if not callable(fn):
|
||||
raise TypeError("lax.associative_scan: fn argument should be callable.")
|
||||
elems_flat, tree = tree_flatten(elems)
|
||||
|
||||
if reverse:
|
||||
|
@ -96,6 +96,21 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
jax._src.lax.control_flow._initial_style_jaxpr.cache_clear()
|
||||
jax._src.lax.control_flow._initial_style_jaxprs_with_common_consts.cache_clear()
|
||||
|
||||
def testCallableErrors(self):
|
||||
not_callable = 42
|
||||
with self.assertRaisesRegex(TypeError, "lax.fori_loop.*callable.*"):
|
||||
lax.fori_loop(0, 1, not_callable, 0)
|
||||
with self.assertRaisesRegex(TypeError, "lax.while_loop.*callable.*"):
|
||||
lax.while_loop(not_callable, not_callable, 0)
|
||||
with self.assertRaisesRegex(TypeError, "lax.switch:.*callable.*"):
|
||||
lax.switch(0, [not_callable])
|
||||
with self.assertRaisesRegex(TypeError, "lax.cond.*callable.*"):
|
||||
lax.cond(0, not_callable, not_callable)
|
||||
with self.assertRaisesRegex(TypeError, "lax.scan.*callable.*"):
|
||||
lax.scan(not_callable, 0, 1)
|
||||
with self.assertRaisesRegex(TypeError, "lax.associative_scan.*callable.*"):
|
||||
lax.associative_scan(not_callable, 0)
|
||||
|
||||
def testWhileWithTuple(self):
|
||||
limit = 10
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user