err on None predicate to lax.cond

This commit is contained in:
Roy Frostig 2022-07-26 13:12:16 -07:00
parent 9f96a0474e
commit 4cd0c68136
2 changed files with 14 additions and 1 deletions

View File

@ -148,7 +148,8 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.
``cond()`` has equivalent semantics to this Python implementation::
Provided arguments are correctly typed, ``cond()`` has equivalent
semantics to this Python implementation::
def cond(pred, true_fun, false_fun, *operands):
if pred:
@ -181,6 +182,8 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operands = (operand,)
del operand
if pred is None:
raise TypeError("cond predicate is None")
if isinstance(pred, Sequence) or np.ndim(pred) != 0:
raise TypeError(
f"Pred must be a scalar, got {pred} of " +

View File

@ -628,6 +628,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(fun(4), cfun(4))
self.assertEqual(fun(4), (8, 16))
def testCondPredIsNone(self):
# see https://github.com/google/jax/issues/11574
def f(pred, x):
return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x)
self.assertRaisesRegex(TypeError, "cond predicate is None",
lambda: f(None, 1.))
self.assertRaisesRegex(TypeError, "cond predicate is None",
lambda: jax.jit(f)(None, 1.))
def testCondTwoOperands(self):
# see https://github.com/google/jax/issues/8469
add, mul = lax.add, lax.mul