Enable print in while loops/scan

This commit is contained in:
Sharad Vikram 2022-05-04 11:11:02 -07:00
parent 97b7fd7315
commit dc42d7bb8e
5 changed files with 199 additions and 37 deletions

View File

@ -23,12 +23,15 @@ from jax import tree_util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.lax import control_flow as lcf
DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])
core.ordered_effects.add(DebugEffect.ORDERED_PRINT)
mlir.lowerable_effects.add(DebugEffect.PRINT)
mlir.lowerable_effects.add(DebugEffect.ORDERED_PRINT)
lcf.allowed_effects.add(DebugEffect.PRINT)
lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT)
# `debug_callback_p` is the main primitive for staging out Python callbacks.
debug_callback_p = core.Primitive('debug_callback')

View File

@ -24,7 +24,7 @@ import inspect
import itertools
import operator
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, List
from typing import Any, Callable, Optional, Sequence, Set, Tuple, TypeVar, List
import numpy as np
@ -66,6 +66,8 @@ T = TypeVar('T')
Array = Any
BooleanNumeric = Any # A bool, or a Boolean array.
allowed_effects: Set[core.Effect] = set()
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: Optional[str] = None):
@ -321,8 +323,11 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
_check_tree_and_avals("body_fun output and input",
body_tree, body_jaxpr.out_avals,
in_tree_children[0], init_avals)
if cond_jaxpr.effects or body_jaxpr.effects:
raise NotImplementedError('Effects not supported in `while`.')
effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
disallowed_effects = effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
@ -330,9 +335,11 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
del args, kwargs
if cond_jaxpr.effects or body_jaxpr.effects:
raise NotImplementedError('Effects not supported in `while_loop`.')
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
disallowed_effects = joined_effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
@ -572,9 +579,20 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
pred_aval = cond_jaxpr.out_avals[0]
batched = bool(pred_aval.shape)
if cond_jaxpr.effects:
# TODO(sharadmv): enable effects in cond
raise NotImplementedError(
'`while_loop` with effects in `cond` not supported.')
loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
body_effects = [eff for eff in body_jaxpr.effects
if eff in core.ordered_effects]
num_tokens = len(body_effects)
tokens = [ctx.tokens_in.get(eff) for eff in body_effects]
token_types = [mlir.token_type() for _ in tokens]
loop_carry_types = [*token_types, *loop_carry_types]
flat_loop_carry_types = util.flatten(loop_carry_types)
args = [*tokens, *args]
flat_args = mlir.flatten_lowering_ir_args(args)
while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)
@ -582,13 +600,13 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
# Loop condition
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
if cond_jaxpr.effects:
raise NotImplementedError('`while_loop` with effects in `cond` not supported.')
with ir.InsertionPoint(cond_block):
flat_cond_args = [
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
]
cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
# Remove tokens from cond args
cond_args = cond_args[num_tokens:]
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack, 'cond'))
@ -618,14 +636,15 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
body_block.arguments[i] for i in range(len(flat_loop_carry_types))
]
body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
# Tokens are at the front of the args list to the while loop
token_args, body_args = util.split_list(body_args, [num_tokens])
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack, 'body'))
if body_jaxpr.effects:
raise NotImplementedError('`while_loop` with effects in `body` not supported.')
new_z, _ = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr, mlir.TokenSet(),
_map(mlir.ir_constants, body_jaxpr.consts),
*(y + z))
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
tokens_in, _map(mlir.ir_constants, body_jaxpr.consts), *(y + z))
out_tokens = [tokens_out.get(eff) for eff in body_effects]
if batched:
body_pred_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack,
@ -637,10 +656,13 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
body_jaxpr.out_avals)
mhlo.ReturnOp([*util.flatten(x), *util.flatten(y), *util.flatten(new_z)])
mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
*util.flatten(y), *util.flatten(new_z)])
outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
_, _, z = util.split_list(outputs, [cond_nconsts, body_nconsts])
tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
if tokens:
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
return z
mlir.register_lowering(while_p, _while_lowering)
@ -1419,8 +1441,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
# Extract the subtree and avals for the first element of the return tuple
out_tree_children[0], carry_avals_out,
init_tree, carry_avals)
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `scan`.')
disallowed_effects = jaxpr.effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `scan`: {disallowed_effects}')
out = scan_p.bind(*consts, *in_flat,
reverse=reverse, length=length, jaxpr=jaxpr,
@ -1613,11 +1637,9 @@ def _prepend_dim_to_aval(sz, aval):
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `scan`.')
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
return carry_avals + ys_avals, core.no_effects
return carry_avals + ys_avals, jaxpr.effects
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
linear, unroll):
@ -2053,7 +2075,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
raise core.JaxprTypeError(
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
f'called with sequence of type\n{_avals_short(x_avals)}')
return None, core.no_effects
return None, jaxpr.effects
def scan_bind(*args, **params):
if config.jax_enable_checks:

View File

@ -2285,10 +2285,13 @@ def _check_jaxpr(
else:
out_avals, effects = check_eqn(prim, in_avals, eqn.params)
if eqn.effects != effects:
raise JaxprTypeError("Inferred effects do not match equation effects.")
raise JaxprTypeError("Inferred effects do not match equation effects. "
f"Equation effects: {eqn.effects}. "
f"Jaxpr effects: {effects}")
if not eqn.effects.issubset(jaxpr.effects):
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects. "
f"Equation effects: {eqn.effects}. Jaxpr effects: {jaxpr.effects}")
f"Equation effects: {eqn.effects}. "
f"Jaxpr effects: {jaxpr.effects}")
map(write, eqn.outvars, out_avals)
except JaxprTypeError as e:
ctx, settings = ctx_factory()

View File

@ -13,17 +13,20 @@
# limitations under the License.
import contextlib
import io
import unittest
import textwrap
from unittest import mock
from typing import Callable, Generator
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax.config import config
from jax._src import debugging
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
import jax.numpy as jnp
config.parse_flags_with_absl()
@ -36,6 +39,9 @@ def capture_stdout() -> Generator[Callable[[], str], None, None]:
return fp.getvalue()
yield _read
def _format_multiline(text):
return textwrap.dedent(text).lstrip()
class DebugPrintTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
@ -65,9 +71,6 @@ class DebugPrintTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_can_stage_out_debug_print(self):
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest(
"`emit_python_callback` only supported in jaxlib >= 0.3.8")
@jax.jit
def f(x):
debug_print('x: {x}', x=x)
@ -77,9 +80,6 @@ class DebugPrintTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_can_stage_out_ordered_print(self):
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest(
"`emit_python_callback` only supported in jaxlib >= 0.3.8")
@jax.jit
def f(x):
debug_print('x: {x}', x=x, ordered=True)
@ -88,9 +88,74 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\n")
class DebugPrintControlFlowTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
for ordered in [False, True]))
@jtu.skip_on_devices("tpu", "gpu")
def test_can_print_inside_scan(self, ordered):
def f(xs):
def _body(carry, x):
debug_print("carry: {carry}", carry=carry, ordered=ordered)
debug_print("x: {x}", x=x, ordered=ordered)
return carry + 1, x + 1
return lax.scan(_body, 2, xs)
with capture_stdout() as output:
f(jnp.arange(2))
self.assertEqual(output(), _format_multiline("""
carry: 2
x: 0
carry: 3
x: 1
"""))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
for ordered in [False, True]))
@jtu.skip_on_devices("tpu", "gpu")
def test_can_print_inside_for_loop(self, ordered):
def f(x):
def _body(i, x):
debug_print("x: {x}", x=x, ordered=ordered)
return x + 1
return lax.fori_loop(0, 5, _body, x)
with capture_stdout() as output:
f(2)
self.assertEqual(output(), _format_multiline("""
x: 2
x: 3
x: 4
x: 5
x: 6
"""))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
for ordered in [False, True]))
@jtu.skip_on_devices("tpu", "gpu")
def test_can_print_inside_while_loop(self, ordered):
def f(x):
def _cond(x):
return x < 10
def _body(x):
debug_print("x: {x}", x=x, ordered=ordered)
return x + 1
return lax.while_loop(_cond, _body, x)
with capture_stdout() as output:
f(5)
self.assertEqual(output(), _format_multiline("""
x: 5
x: 6
x: 7
x: 8
x: 9
"""))
if jaxlib.version < (0, 3, 8):
# No lowering for `emit_python_callback` in older jaxlibs.
del DebugPrintTest
del DebugPrintControlFlowTest
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -31,6 +31,7 @@ from jax._src import lib as jaxlib
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import control_flow as lcf
import numpy as np
config.parse_flags_with_absl()
@ -45,8 +46,17 @@ def _(*, effect):
mlir.lowerable_effects.add('foo')
mlir.lowerable_effects.add('foo2')
mlir.lowerable_effects.add('bar')
mlir.lowerable_effects.add('while')
mlir.lowerable_effects.add('while1')
mlir.lowerable_effects.add('while2')
core.ordered_effects.add('foo')
core.ordered_effects.add('foo2')
core.ordered_effects.add('while1')
core.ordered_effects.add('while2')
lcf.allowed_effects.add('while')
lcf.allowed_effects.add('while1')
lcf.allowed_effects.add('while2')
def trivial_effect_lowering(ctx, *, effect):
@ -257,24 +267,27 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_should_pass_tokens_into_ordered_effect(self):
def _effect_lowering(ctx, *, effect):
self.assertListEqual(list(ctx.tokens_in.effects()), ['foo'])
ctx.set_tokens_out(ctx.tokens_in)
return []
mlir.register_lowering(effect_p, _effect_lowering)
def setUp(self):
super().setUp()
self.old_x64 = config.jax_enable_x64
config.update('jax_enable_x64', False)
self._old_lowering = mlir._lowerings[effect_p]
def _effect_lowering(ctx, *, effect):
if effect in core.ordered_effects:
expected_effects = [effect]
else:
expected_effects = []
self.assertListEqual(list(ctx.tokens_in.effects()), expected_effects)
ctx.set_tokens_out(ctx.tokens_in)
return []
mlir.register_lowering(effect_p, _effect_lowering)
dispatch.runtime_tokens.clear()
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
config.update('jax_enable_x64', self.old_x64)
mlir.register_lowering(effect_p, self._old_lowering)
def test_cannot_lower_unlowerable_effect(self):
@jax.jit
@ -631,6 +644,37 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f2)(2.)
def test_allowed_effect_in_while_body(self):
def f(x):
def cond_fun(x):
return False
def body_fun(x):
effect_p.bind(effect='while')
return x
return lax.while_loop(cond_fun, body_fun, x)
f(2)
def test_allowed_ordered_effect_in_while_body(self):
def f(x):
def cond_fun(x):
return False
def body_fun(x):
effect_p.bind(effect='while1')
return x
return lax.while_loop(cond_fun, body_fun, x)
f(2)
def test_multiple_allowed_ordered_effect_in_while_body(self):
def f(x):
def cond_fun(x):
return False
def body_fun(x):
effect_p.bind(effect='while1')
effect_p.bind(effect='while2')
return x
return lax.while_loop(cond_fun, body_fun, x)
f(2)
def test_effects_disallowed_in_while(self):
def f1(x):
def cond_fun(x):
@ -654,6 +698,31 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f2)(2.)
def test_allowed_effect_in_scan(self):
def f(x):
def body_fun(carry, x):
effect_p.bind(effect='while')
return carry, x
return lax.scan(body_fun, x, jnp.arange(5))
f(2)
def test_allowed_ordered_effect_in_scan(self):
def f(x):
def body_fun(carry, x):
effect_p.bind(effect='while1')
return carry, x
return lax.scan(body_fun, x, jnp.arange(5))
f(2)
def test_multiple_allowed_ordered_effect_in_scan(self):
def f(x):
def body_fun(carry, x):
effect_p.bind(effect='while1')
effect_p.bind(effect='while2')
return carry, x
return lax.scan(body_fun, x, jnp.arange(5))
f(2)
def test_effects_disallowed_in_scan(self):
def f(x):