mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #8488 from mattjj:issue8469
PiperOrigin-RevId: 408712381
This commit is contained in:
commit
7b79497b44
@ -788,7 +788,10 @@ def cond(*args, **kwargs):
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
return _cond_with_per_branch_args(*ba.args)
|
||||
assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch
|
||||
_, _, maybe_true_fun, _, maybe_false_fun = ba.args
|
||||
if callable(maybe_true_fun) and callable(maybe_false_fun):
|
||||
return _cond_with_per_branch_args(*ba.args)
|
||||
|
||||
return _cond(*args, **kwargs)
|
||||
|
||||
|
@ -599,8 +599,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def testCondTwoOperands(self):
|
||||
# see https://github.com/google/jax/issues/8469
|
||||
self.skipTest("two-operand cond behavior is ambiguous (#8469)")
|
||||
|
||||
add, mul = lax.add, lax.mul
|
||||
|
||||
def fun(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user