mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
err on None
predicate to lax.cond
This commit is contained in:
parent
9f96a0474e
commit
4cd0c68136
@ -148,7 +148,8 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
operand=_no_operand_sentinel, linear=None):
|
||||
"""Conditionally apply ``true_fun`` or ``false_fun``.
|
||||
|
||||
``cond()`` has equivalent semantics to this Python implementation::
|
||||
Provided arguments are correctly typed, ``cond()`` has equivalent
|
||||
semantics to this Python implementation::
|
||||
|
||||
def cond(pred, true_fun, false_fun, *operands):
|
||||
if pred:
|
||||
@ -181,6 +182,8 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
operands = (operand,)
|
||||
del operand
|
||||
|
||||
if pred is None:
|
||||
raise TypeError("cond predicate is None")
|
||||
if isinstance(pred, Sequence) or np.ndim(pred) != 0:
|
||||
raise TypeError(
|
||||
f"Pred must be a scalar, got {pred} of " +
|
||||
|
@ -628,6 +628,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertEqual(fun(4), cfun(4))
|
||||
self.assertEqual(fun(4), (8, 16))
|
||||
|
||||
def testCondPredIsNone(self):
|
||||
# see https://github.com/google/jax/issues/11574
|
||||
def f(pred, x):
|
||||
return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "cond predicate is None",
|
||||
lambda: f(None, 1.))
|
||||
self.assertRaisesRegex(TypeError, "cond predicate is None",
|
||||
lambda: jax.jit(f)(None, 1.))
|
||||
|
||||
def testCondTwoOperands(self):
|
||||
# see https://github.com/google/jax/issues/8469
|
||||
add, mul = lax.add, lax.mul
|
||||
|
Loading…
x
Reference in New Issue
Block a user