mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #12540 from sharadmv:cond-lowering-fix
PiperOrigin-RevId: 477358889
This commit is contained in:
commit
96abd9ac75
@ -1439,7 +1439,11 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
core.ordered_effects]
|
||||
if cond_ordered_effects:
|
||||
def cond(args):
|
||||
return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
|
||||
# Pred can be batched
|
||||
pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
|
||||
if batched:
|
||||
pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape))))
|
||||
return pred
|
||||
def body(args):
|
||||
return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
|
||||
def new_cond(pred_args):
|
||||
|
@ -575,6 +575,66 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
x: 10
|
||||
"""))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
|
||||
for ordered in [False, True]))
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
def test_can_print_in_batched_while_cond(self, ordered):
|
||||
def f(x):
|
||||
def _cond(x):
|
||||
debug_print("x: {x}", x=x, ordered=ordered)
|
||||
return x < 5
|
||||
def _body(x):
|
||||
return x + 1
|
||||
return lax.while_loop(_cond, _body, x)
|
||||
with jtu.capture_stdout() as output:
|
||||
jax.vmap(f)(jnp.arange(2))
|
||||
jax.effects_barrier()
|
||||
if ordered:
|
||||
expected = _format_multiline("""
|
||||
x: 0
|
||||
x: 1
|
||||
x: 1
|
||||
x: 2
|
||||
x: 2
|
||||
x: 3
|
||||
x: 3
|
||||
x: 4
|
||||
x: 4
|
||||
x: 5
|
||||
x: 5
|
||||
x: 6
|
||||
""")
|
||||
self.assertEqual(output(), expected)
|
||||
else:
|
||||
# When the print is unordered, the `cond` is called an additional time
|
||||
# after the `_body` runs, so we get more prints.
|
||||
expected = _format_multiline("""
|
||||
x: 0
|
||||
x: 1
|
||||
x: 0
|
||||
x: 1
|
||||
x: 1
|
||||
x: 2
|
||||
x: 1
|
||||
x: 2
|
||||
x: 2
|
||||
x: 3
|
||||
x: 2
|
||||
x: 3
|
||||
x: 3
|
||||
x: 4
|
||||
x: 3
|
||||
x: 4
|
||||
x: 4
|
||||
x: 5
|
||||
x: 4
|
||||
x: 5
|
||||
x: 5
|
||||
x: 5
|
||||
""")
|
||||
self._assertLinesEqual(output(), expected)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
|
||||
for ordered in [False, True]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user