Merge pull request #12540 from sharadmv:cond-lowering-fix

PiperOrigin-RevId: 477358889
This commit is contained in:
jax authors 2022-09-27 22:33:12 -07:00
commit 96abd9ac75
2 changed files with 65 additions and 1 deletions

View File

@ -1439,7 +1439,11 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
core.ordered_effects]
if cond_ordered_effects:
def cond(args):
return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
# Pred can be batched
pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
if batched:
pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape))))
return pred
def body(args):
return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
def new_cond(pred_args):

View File

@ -575,6 +575,66 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
x: 10
"""))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
for ordered in [False, True]))
@jtu.skip_on_devices(*disabled_backends)
def test_can_print_in_batched_while_cond(self, ordered):
def f(x):
def _cond(x):
debug_print("x: {x}", x=x, ordered=ordered)
return x < 5
def _body(x):
return x + 1
return lax.while_loop(_cond, _body, x)
with jtu.capture_stdout() as output:
jax.vmap(f)(jnp.arange(2))
jax.effects_barrier()
if ordered:
expected = _format_multiline("""
x: 0
x: 1
x: 1
x: 2
x: 2
x: 3
x: 3
x: 4
x: 4
x: 5
x: 5
x: 6
""")
self.assertEqual(output(), expected)
else:
# When the print is unordered, the `cond` is called an additional time
# after the `_body` runs, so we get more prints.
expected = _format_multiline("""
x: 0
x: 1
x: 0
x: 1
x: 1
x: 2
x: 1
x: 2
x: 2
x: 3
x: 2
x: 3
x: 3
x: 4
x: 3
x: 4
x: 4
x: 5
x: 4
x: 5
x: 5
x: 5
""")
self._assertLinesEqual(output(), expected)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
for ordered in [False, True]))