Merge pull request #8488 from mattjj:issue8469

PiperOrigin-RevId: 408712381
This commit is contained in:
jax authors 2021-11-09 14:52:15 -08:00
commit 7b79497b44
2 changed files with 4 additions and 3 deletions

View File

@ -788,7 +788,10 @@ def cond(*args, **kwargs):
except TypeError:
pass
else:
return _cond_with_per_branch_args(*ba.args)
assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch
_, _, maybe_true_fun, _, maybe_false_fun = ba.args
if callable(maybe_true_fun) and callable(maybe_false_fun):
return _cond_with_per_branch_args(*ba.args)
return _cond(*args, **kwargs)

View File

@ -599,8 +599,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def testCondTwoOperands(self):
# see https://github.com/google/jax/issues/8469
self.skipTest("two-operand cond behavior is ambiguous (#8469)")
add, mul = lax.add, lax.mul
def fun(x):