mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Enable print in while loops/scan
This commit is contained in:
parent
97b7fd7315
commit
dc42d7bb8e
@ -23,12 +23,15 @@ from jax import tree_util
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lax import control_flow as lcf
|
||||
|
||||
DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])
|
||||
|
||||
core.ordered_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
mlir.lowerable_effects.add(DebugEffect.PRINT)
|
||||
mlir.lowerable_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
lcf.allowed_effects.add(DebugEffect.PRINT)
|
||||
lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
|
||||
# `debug_callback_p` is the main primitive for staging out Python callbacks.
|
||||
debug_callback_p = core.Primitive('debug_callback')
|
||||
|
@ -24,7 +24,7 @@ import inspect
|
||||
import itertools
|
||||
import operator
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, List
|
||||
from typing import Any, Callable, Optional, Sequence, Set, Tuple, TypeVar, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -66,6 +66,8 @@ T = TypeVar('T')
|
||||
Array = Any
|
||||
BooleanNumeric = Any # A bool, or a Boolean array.
|
||||
|
||||
allowed_effects: Set[core.Effect] = set()
|
||||
|
||||
@cache()
|
||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
|
||||
primitive_name: Optional[str] = None):
|
||||
@ -321,8 +323,11 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
_check_tree_and_avals("body_fun output and input",
|
||||
body_tree, body_jaxpr.out_avals,
|
||||
in_tree_children[0], init_avals)
|
||||
if cond_jaxpr.effects or body_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `while`.')
|
||||
effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
|
||||
disallowed_effects = effects - allowed_effects
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `while`: {disallowed_effects}')
|
||||
outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
|
||||
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
|
||||
@ -330,9 +335,11 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
|
||||
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
|
||||
del args, kwargs
|
||||
if cond_jaxpr.effects or body_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `while_loop`.')
|
||||
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
|
||||
disallowed_effects = joined_effects - allowed_effects
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `while`: {disallowed_effects}')
|
||||
return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
|
||||
|
||||
|
||||
@ -572,9 +579,20 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
body_nconsts):
|
||||
pred_aval = cond_jaxpr.out_avals[0]
|
||||
batched = bool(pred_aval.shape)
|
||||
if cond_jaxpr.effects:
|
||||
# TODO(sharadmv): enable effects in cond
|
||||
raise NotImplementedError(
|
||||
'`while_loop` with effects in `cond` not supported.')
|
||||
|
||||
loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
|
||||
body_effects = [eff for eff in body_jaxpr.effects
|
||||
if eff in core.ordered_effects]
|
||||
num_tokens = len(body_effects)
|
||||
tokens = [ctx.tokens_in.get(eff) for eff in body_effects]
|
||||
token_types = [mlir.token_type() for _ in tokens]
|
||||
loop_carry_types = [*token_types, *loop_carry_types]
|
||||
flat_loop_carry_types = util.flatten(loop_carry_types)
|
||||
args = [*tokens, *args]
|
||||
|
||||
flat_args = mlir.flatten_lowering_ir_args(args)
|
||||
while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)
|
||||
@ -582,13 +600,13 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
# Loop condition
|
||||
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
|
||||
name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
|
||||
if cond_jaxpr.effects:
|
||||
raise NotImplementedError('`while_loop` with effects in `cond` not supported.')
|
||||
with ir.InsertionPoint(cond_block):
|
||||
flat_cond_args = [
|
||||
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
|
||||
]
|
||||
cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
|
||||
# Remove tokens from cond args
|
||||
cond_args = cond_args[num_tokens:]
|
||||
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
|
||||
cond_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(name_stack, 'cond'))
|
||||
@ -618,14 +636,15 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
body_block.arguments[i] for i in range(len(flat_loop_carry_types))
|
||||
]
|
||||
body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
|
||||
# Tokens are at the front of the args list to the while loop
|
||||
token_args, body_args = util.split_list(body_args, [num_tokens])
|
||||
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
|
||||
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
|
||||
body_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(name_stack, 'body'))
|
||||
if body_jaxpr.effects:
|
||||
raise NotImplementedError('`while_loop` with effects in `body` not supported.')
|
||||
new_z, _ = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr, mlir.TokenSet(),
|
||||
_map(mlir.ir_constants, body_jaxpr.consts),
|
||||
*(y + z))
|
||||
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
|
||||
tokens_in, _map(mlir.ir_constants, body_jaxpr.consts), *(y + z))
|
||||
out_tokens = [tokens_out.get(eff) for eff in body_effects]
|
||||
if batched:
|
||||
body_pred_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(name_stack,
|
||||
@ -637,10 +656,13 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
|
||||
body_jaxpr.out_avals)
|
||||
|
||||
mhlo.ReturnOp([*util.flatten(x), *util.flatten(y), *util.flatten(new_z)])
|
||||
mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
|
||||
*util.flatten(y), *util.flatten(new_z)])
|
||||
|
||||
outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
|
||||
_, _, z = util.split_list(outputs, [cond_nconsts, body_nconsts])
|
||||
tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
|
||||
if tokens:
|
||||
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
|
||||
return z
|
||||
|
||||
mlir.register_lowering(while_p, _while_lowering)
|
||||
@ -1419,8 +1441,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
# Extract the subtree and avals for the first element of the return tuple
|
||||
out_tree_children[0], carry_avals_out,
|
||||
init_tree, carry_avals)
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `scan`.')
|
||||
disallowed_effects = jaxpr.effects - allowed_effects
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `scan`: {disallowed_effects}')
|
||||
|
||||
out = scan_p.bind(*consts, *in_flat,
|
||||
reverse=reverse, length=length, jaxpr=jaxpr,
|
||||
@ -1613,11 +1637,9 @@ def _prepend_dim_to_aval(sz, aval):
|
||||
|
||||
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
|
||||
linear, unroll):
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `scan`.')
|
||||
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
|
||||
return carry_avals + ys_avals, core.no_effects
|
||||
return carry_avals + ys_avals, jaxpr.effects
|
||||
|
||||
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
|
||||
linear, unroll):
|
||||
@ -2053,7 +2075,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
|
||||
raise core.JaxprTypeError(
|
||||
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
|
||||
f'called with sequence of type\n{_avals_short(x_avals)}')
|
||||
return None, core.no_effects
|
||||
return None, jaxpr.effects
|
||||
|
||||
def scan_bind(*args, **params):
|
||||
if config.jax_enable_checks:
|
||||
|
@ -2285,10 +2285,13 @@ def _check_jaxpr(
|
||||
else:
|
||||
out_avals, effects = check_eqn(prim, in_avals, eqn.params)
|
||||
if eqn.effects != effects:
|
||||
raise JaxprTypeError("Inferred effects do not match equation effects.")
|
||||
raise JaxprTypeError("Inferred effects do not match equation effects. "
|
||||
f"Equation effects: {eqn.effects}. "
|
||||
f"Jaxpr effects: {effects}")
|
||||
if not eqn.effects.issubset(jaxpr.effects):
|
||||
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects. "
|
||||
f"Equation effects: {eqn.effects}. Jaxpr effects: {jaxpr.effects}")
|
||||
f"Equation effects: {eqn.effects}. "
|
||||
f"Jaxpr effects: {jaxpr.effects}")
|
||||
map(write, eqn.outvars, out_avals)
|
||||
except JaxprTypeError as e:
|
||||
ctx, settings = ctx_factory()
|
||||
|
@ -13,17 +13,20 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import io
|
||||
import unittest
|
||||
import textwrap
|
||||
from unittest import mock
|
||||
|
||||
from typing import Callable, Generator
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax._src import debugging
|
||||
from jax._src import lib as jaxlib
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -36,6 +39,9 @@ def capture_stdout() -> Generator[Callable[[], str], None, None]:
|
||||
return fp.getvalue()
|
||||
yield _read
|
||||
|
||||
def _format_multiline(text):
|
||||
return textwrap.dedent(text).lstrip()
|
||||
|
||||
class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
@ -65,9 +71,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_can_stage_out_debug_print(self):
|
||||
if jaxlib.version < (0, 3, 8):
|
||||
raise unittest.SkipTest(
|
||||
"`emit_python_callback` only supported in jaxlib >= 0.3.8")
|
||||
@jax.jit
|
||||
def f(x):
|
||||
debug_print('x: {x}', x=x)
|
||||
@ -77,9 +80,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_can_stage_out_ordered_print(self):
|
||||
if jaxlib.version < (0, 3, 8):
|
||||
raise unittest.SkipTest(
|
||||
"`emit_python_callback` only supported in jaxlib >= 0.3.8")
|
||||
@jax.jit
|
||||
def f(x):
|
||||
debug_print('x: {x}', x=x, ordered=True)
|
||||
@ -88,9 +88,74 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\n")
|
||||
|
||||
|
||||
class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
|
||||
for ordered in [False, True]))
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_can_print_inside_scan(self, ordered):
|
||||
def f(xs):
|
||||
def _body(carry, x):
|
||||
debug_print("carry: {carry}", carry=carry, ordered=ordered)
|
||||
debug_print("x: {x}", x=x, ordered=ordered)
|
||||
return carry + 1, x + 1
|
||||
return lax.scan(_body, 2, xs)
|
||||
with capture_stdout() as output:
|
||||
f(jnp.arange(2))
|
||||
self.assertEqual(output(), _format_multiline("""
|
||||
carry: 2
|
||||
x: 0
|
||||
carry: 3
|
||||
x: 1
|
||||
"""))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
|
||||
for ordered in [False, True]))
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_can_print_inside_for_loop(self, ordered):
|
||||
def f(x):
|
||||
def _body(i, x):
|
||||
debug_print("x: {x}", x=x, ordered=ordered)
|
||||
return x + 1
|
||||
return lax.fori_loop(0, 5, _body, x)
|
||||
with capture_stdout() as output:
|
||||
f(2)
|
||||
self.assertEqual(output(), _format_multiline("""
|
||||
x: 2
|
||||
x: 3
|
||||
x: 4
|
||||
x: 5
|
||||
x: 6
|
||||
"""))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
|
||||
for ordered in [False, True]))
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_can_print_inside_while_loop(self, ordered):
|
||||
def f(x):
|
||||
def _cond(x):
|
||||
return x < 10
|
||||
def _body(x):
|
||||
debug_print("x: {x}", x=x, ordered=ordered)
|
||||
return x + 1
|
||||
return lax.while_loop(_cond, _body, x)
|
||||
with capture_stdout() as output:
|
||||
f(5)
|
||||
self.assertEqual(output(), _format_multiline("""
|
||||
x: 5
|
||||
x: 6
|
||||
x: 7
|
||||
x: 8
|
||||
x: 9
|
||||
"""))
|
||||
|
||||
if jaxlib.version < (0, 3, 8):
|
||||
# No lowering for `emit_python_callback` in older jaxlibs.
|
||||
del DebugPrintTest
|
||||
del DebugPrintControlFlowTest
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -31,6 +31,7 @@ from jax._src import lib as jaxlib
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.lax import control_flow as lcf
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -45,8 +46,17 @@ def _(*, effect):
|
||||
mlir.lowerable_effects.add('foo')
|
||||
mlir.lowerable_effects.add('foo2')
|
||||
mlir.lowerable_effects.add('bar')
|
||||
mlir.lowerable_effects.add('while')
|
||||
mlir.lowerable_effects.add('while1')
|
||||
mlir.lowerable_effects.add('while2')
|
||||
core.ordered_effects.add('foo')
|
||||
core.ordered_effects.add('foo2')
|
||||
core.ordered_effects.add('while1')
|
||||
core.ordered_effects.add('while2')
|
||||
|
||||
lcf.allowed_effects.add('while')
|
||||
lcf.allowed_effects.add('while1')
|
||||
lcf.allowed_effects.add('while2')
|
||||
|
||||
|
||||
def trivial_effect_lowering(ctx, *, effect):
|
||||
@ -257,24 +267,27 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
|
||||
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
def test_should_pass_tokens_into_ordered_effect(self):
|
||||
|
||||
def _effect_lowering(ctx, *, effect):
|
||||
self.assertListEqual(list(ctx.tokens_in.effects()), ['foo'])
|
||||
ctx.set_tokens_out(ctx.tokens_in)
|
||||
return []
|
||||
mlir.register_lowering(effect_p, _effect_lowering)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.old_x64 = config.jax_enable_x64
|
||||
config.update('jax_enable_x64', False)
|
||||
self._old_lowering = mlir._lowerings[effect_p]
|
||||
def _effect_lowering(ctx, *, effect):
|
||||
if effect in core.ordered_effects:
|
||||
expected_effects = [effect]
|
||||
else:
|
||||
expected_effects = []
|
||||
self.assertListEqual(list(ctx.tokens_in.effects()), expected_effects)
|
||||
ctx.set_tokens_out(ctx.tokens_in)
|
||||
return []
|
||||
mlir.register_lowering(effect_p, _effect_lowering)
|
||||
dispatch.runtime_tokens.clear()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
config.update('jax_enable_x64', self.old_x64)
|
||||
mlir.register_lowering(effect_p, self._old_lowering)
|
||||
|
||||
def test_cannot_lower_unlowerable_effect(self):
|
||||
@jax.jit
|
||||
@ -631,6 +644,37 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f2)(2.)
|
||||
|
||||
def test_allowed_effect_in_while_body(self):
|
||||
def f(x):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='while')
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
f(2)
|
||||
|
||||
def test_allowed_ordered_effect_in_while_body(self):
|
||||
def f(x):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
f(2)
|
||||
|
||||
def test_multiple_allowed_ordered_effect_in_while_body(self):
|
||||
def f(x):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect='while2')
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
f(2)
|
||||
|
||||
def test_effects_disallowed_in_while(self):
|
||||
def f1(x):
|
||||
def cond_fun(x):
|
||||
@ -654,6 +698,31 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f2)(2.)
|
||||
|
||||
def test_allowed_effect_in_scan(self):
|
||||
def f(x):
|
||||
def body_fun(carry, x):
|
||||
effect_p.bind(effect='while')
|
||||
return carry, x
|
||||
return lax.scan(body_fun, x, jnp.arange(5))
|
||||
f(2)
|
||||
|
||||
def test_allowed_ordered_effect_in_scan(self):
|
||||
def f(x):
|
||||
def body_fun(carry, x):
|
||||
effect_p.bind(effect='while1')
|
||||
return carry, x
|
||||
return lax.scan(body_fun, x, jnp.arange(5))
|
||||
f(2)
|
||||
|
||||
def test_multiple_allowed_ordered_effect_in_scan(self):
|
||||
def f(x):
|
||||
def body_fun(carry, x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect='while2')
|
||||
return carry, x
|
||||
return lax.scan(body_fun, x, jnp.arange(5))
|
||||
f(2)
|
||||
|
||||
def test_effects_disallowed_in_scan(self):
|
||||
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user