rocm_jax/tests/jaxpr_effects_test.py
Daniel Suo 2d1bc5c2a0 Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI.
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.

PiperOrigin-RevId: 729561359
2025-02-21 09:45:59 -08:00

1028 lines
31 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import unittest
from absl.testing import absltest
import jax
from jax import api_util
import jax.numpy as jnp
from jax import lax
from jax.experimental import pjit
from jax._src import ad_checkpoint
from jax._src import callback as cb
from jax._src import dispatch
from jax._src import config
from jax._src import core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
import numpy as np
# Module-level test setup: load absl flags into JAX's config, then request
# CPU devices. Presumably this yields 2 CPU devices so the multi-device tests
# below can run on CPU-only hosts (several of them skip when
# `jax.device_count() < 2`).
config.parse_flags_with_absl()
jtu.request_cpu_devices(2)
# `effect_p` is a test-only primitive that carries an arbitrary `effect`
# parameter and reports exactly that effect from its abstract evaluation.
# Its outputs have the same avals as its operands (multiple_results=True), so
# `x, = effect_p.bind(x, effect=e)` threads a value through unchanged.
effect_p = core.Primitive('effect')
effect_p.multiple_results = True

@effect_p.def_effectful_abstract_eval
def _(*avals, effect):
  # Output avals mirror the inputs; the only effect is the `effect` param.
  return avals, {effect}

def effect_jvp_rule(primals, tangents, effect):
  """JVP for `effect_p`: re-binds on the primals (keeping the effect) and
  passes the tangents through unchanged, so the tangent side has no effect."""
  return effect_p.bind(*primals, effect=effect), tangents

ad.primitive_jvps[effect_p] = effect_jvp_rule
class BasicEffect(effects.Effect):
  """A test effect labelled by a human-readable name.

  The name is stored on the instance and also used as the effect's repr, so
  error messages that list effects show e.g. ``[foo]``.
  """

  def __init__(self, name):
    self.name = name

  def __repr__(self):
    return self.name
# Effect types used by the tests. How JAX treats each type is decided by the
# global effect-type registries it is added to at the bottom of this block.
class OrderedEffect(BasicEffect): pass  # registered as ordered below
class UnlowerableEffect(effects.Effect): pass  # not registered as lowerable here
class WhileEffect(effects.Effect): pass  # lowerable, ordered, and allowed in control flow
class RematEffect(effects.Effect): pass  # allowed under remat/checkpoint
class InputEffect(effects.JaxprInputEffect): pass  # constructed with an operand index

# Effect instances shared across the test classes.
foo_effect = OrderedEffect("foo")
foo2_effect = OrderedEffect("foo2")
bar_effect = BasicEffect("bar")
baz_effect = UnlowerableEffect()
while_effect = WhileEffect()
while1_effect = WhileEffect()
while2_effect = WhileEffect()
log_effect = OrderedEffect("log")
unordered_log_effect = BasicEffect("unordered_log")

# Register the effect types with JAX's global effect-type sets. These mutate
# module-global JAX state for the whole test process.
effects.lowerable_effects.add_type(BasicEffect)
effects.lowerable_effects.add_type(WhileEffect)
effects.ordered_effects.add_type(OrderedEffect)
effects.ordered_effects.add_type(WhileEffect)
effects.control_flow_allowed_effects.add_type(WhileEffect)
effects.remat_allowed_effects.add_type(RematEffect)
effects.control_flow_allowed_effects.add_type(InputEffect)
def trivial_effect_lowering(ctx, *, effect):
  """Lowering rule for `effect_p` that emits no ops.

  Every incoming token is forwarded unchanged as an outgoing token, and the
  rule produces no result values.
  """
  del effect  # The same forwarding applies to every effect.
  incoming = ctx.tokens_in
  ctx.set_tokens_out(incoming)
  return []
# Install the trivial lowering as the default rule for `effect_p`. Tests that
# need a different rule call `mlir.register_lowering` again and restore it.
mlir.register_lowering(effect_p, trivial_effect_lowering)
def function_effect_lowering(ctx, *, effect):
  """Lowers `effect_p` by emitting a separate function and calling it.

  The emitted function only forwards its tokens. The caller passes in every
  token from `ctx.tokens_in`, calls the function via `func.call`, and splits
  the call results into the new tokens (first) and the value outputs (rest).
  """
  def _f(ctx):
    ctx.set_tokens_out(ctx.tokens_in)
    return []
  callee = mlir._emit_lowering_rule_as_fun(_f, ctx)
  token_effects = list(ctx.tokens_in.effects())
  token_args = [ctx.tokens_in.get(e) for e in token_effects]
  # Token results come first, followed by one IR type per output aval.
  result_types = [mlir.token_type() for _ in token_effects]
  result_types.extend(mlir.aval_to_ir_type(a) for a in ctx.avals_out)
  call = mlir.func_dialect.CallOp(
      mlir.flatten_ir_types(result_types),
      mlir.ir.FlatSymbolRefAttr.get(callee.name.value),
      mlir.flatten_ir_values(token_args))
  token_results, outs = util.split_list(call.results, [len(ctx.tokens_in)])
  ctx.set_tokens_out(mlir.TokenSet(zip(token_effects, token_results)))
  return outs
# `callback_p` is a test-only primitive that calls a Python `callback` on its
# operands and reports a single `effect`.
callback_p = core.Primitive('callback')
callback_p.multiple_results = True

@callback_p.def_impl
def _(*args, callback, out_avals, effect):
  # Eager implementation: run the callback and return no outputs (the tests
  # only use `out_avals=[]`).
  del out_avals, effect
  callback(*args)
  return []

@callback_p.def_effectful_abstract_eval
def _(*avals, callback, out_avals, effect):
  # Output avals are whatever the caller declared; the effect is `effect`.
  del avals, callback
  return out_avals, {effect}
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
  """Lowers `callback_p` to a side-effecting Python callback.

  An input token is threaded through the callback only when `effect` is an
  ordered effect. If the callback produces an output token, it replaces the
  token for `effect` in the outgoing token set.
  """
  del out_avals  # Output avals come from `ctx.avals_out` instead.
  wants_token = effects.ordered_effects.contains(effect)
  token_in = ctx.tokens_in.get(effect) if wants_token else None
  out_op, token_out, _ = cb.emit_python_callback(
      ctx, callback, token_in, list(args), list(ctx.avals_in),
      list(ctx.avals_out), has_side_effect=True)
  if token_out:
    replacement = mlir.TokenSet({effect: token_out})
    ctx.set_tokens_out(ctx.tokens_in.update_tokens(replacement))
  return out_op
mlir.register_lowering(callback_p, callback_effect_lowering)
class JaxprEffectsTest(jtu.JaxTestCase):
  """Checks how effects are recorded on jaxpr equations and on whole jaxprs."""

  def test_trivial_jaxpr_has_no_effects(self):
    def f(x):
      return x + 1.
    jaxpr = jax.make_jaxpr(f)(2.)
    self.assertEqual(core.no_effects, jaxpr.effects)

  def test_effectful_primitive_in_jaxpr_creates_effects(self):
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    jaxpr = jax.make_jaxpr(f)(2.)
    # The bind is traced first, so it is equation 0.
    self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects)
    self.assertEqual({foo_effect}, jaxpr.effects)

  def test_different_effects_in_jaxpr(self):
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=bar_effect)
      return x + 1.
    jaxpr = jax.make_jaxpr(f)(2.)
    # Each equation carries only its own effect; the jaxpr carries the union.
    self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects)
    self.assertEqual({bar_effect}, jaxpr.jaxpr.eqns[1].effects)
    self.assertEqual({foo_effect, bar_effect}, jaxpr.effects)

  def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self):
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=bar_effect)
      return x + 1.
    jaxpr = jax.make_jaxpr(f)(2.).jaxpr
    # Edit jaxpr to make its type wrong: drop `bar_effect` from the jaxpr-level
    # effects so equation 1's effect is no longer a subset of them.
    jaxpr = jaxpr.replace(effects={foo_effect})
    with self.assertRaisesRegex(core.JaxprTypeError,
                                'Equation effect not present in jaxpr effects.'):
      core.check_jaxpr(jaxpr)
class HigherOrderPrimitiveTest(jtu.JaxTestCase):
  """Checks whether effects inside higher-order primitives reach the outer jaxpr
  or are rejected."""

  def test_core_call_primitive_inherits_effects(self):
    def f(x):
      def f_(x):
        effect_p.bind(effect=foo_effect)
        effect_p.bind(effect=bar_effect)
        return [x]
      dbg = api_util.debug_info("test", f_, (2.,), {})
      return core.call(
          lu.wrap_init(f_, debug_info=dbg), x)[0]
    jaxpr = jax.make_jaxpr(f)(2.)
    self.assertIn(foo_effect, jaxpr.jaxpr.effects)
    self.assertIn(bar_effect, jaxpr.jaxpr.effects)

  def test_jit_primitive_inherits_effects(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=bar_effect)
      return x
    # NOTE(review): the result of this first trace is unused; it looks
    # redundant with the next line.
    jax.make_jaxpr(f)(2.)
    jaxpr = jax.make_jaxpr(f)(2.)
    self.assertIn(foo_effect, jaxpr.jaxpr.effects)
    self.assertIn(bar_effect, jaxpr.jaxpr.effects)

  def test_remat_call_primitive_inherits_effects(self):
    @jax.checkpoint
    def f(x):
      x, = effect_p.bind(x, effect=foo_effect)
      x, = effect_p.bind(x, effect=bar_effect)
      return x
    # Tracing the checkpointed function on its own must succeed...
    jax.make_jaxpr(f)(2.)
    # ...but linearizing through it rejects these (non-remat-allowed) effects.
    with self.assertRaisesRegex(NotImplementedError, "Effects not supported"):
      jax.make_jaxpr(lambda x: jax.linearize(f, x)[1](x))(2.)

  def test_new_remat_allows_certain_effects(self):
    # RematEffect is registered in `remat_allowed_effects` at module level.
    remat_effect = RematEffect()
    @ad_checkpoint.checkpoint
    def f(x):
      x, = effect_p.bind(x, effect=remat_effect)
      return x
    jaxpr = jax.make_jaxpr(f)(2.)
    self.assertSetEqual(jaxpr.effects, {remat_effect})

  def test_custom_jvp_primitive_inherits_effects(self):
    @jax.custom_jvp
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=bar_effect)
      return x
    f.defjvp(lambda x, t: (x, t))
    # Effects inside a custom_jvp function are rejected at trace time.
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(f)(2.)

  def test_custom_vjp_primitive_inherits_effects(self):
    @jax.custom_vjp
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=bar_effect)
      return x
    f.defvjp(
        fwd=lambda x: (x, ()),
        bwd=lambda _, g: g)
    # Effects inside a custom_vjp function are rejected at trace time.
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(f)(2.)

  def test_pmap_inherits_effects(self):
    @jax.pmap
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=bar_effect)
      return x
    # foo_effect is an OrderedEffect, which map primitives reject.
    with self.assertRaisesRegex(
        ValueError,
        r"Ordered effects not supported for map primitives: \[.*\]"):
      jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))

  def test_pjit_inherits_effects(self):
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=bar_effect)
      return x
    # Shard the single input/output axis over all devices.
    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['x'])
    spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
    f = pjit.pjit(f, in_shardings=spec, out_shardings=spec)
    with mesh:
      jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count()))
      self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
@jtu.thread_unsafe_test_class()  # because of mlir.register_lowering calls
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
  """Checks how effects are lowered to MLIR and how runtime tokens are
  threaded through compiled computations."""

  def setUp(self):
    super().setUp()
    self.enter_context(config.enable_x64(False))
    # Save the module-level lowering so tearDown can restore it; tests in this
    # class replace it via `mlir.register_lowering`.
    self._old_lowering = mlir._lowerings[effect_p]
    # Default rule for this class: assert that an input token is supplied for
    # ordered effects only, then forward the tokens unchanged.
    def _effect_lowering(ctx, *, effect):
      if effects.ordered_effects.contains(effect):
        expected_effects = [effect]
      else:
        expected_effects = []
      self.assertListEqual(list(ctx.tokens_in.effects()), expected_effects)
      ctx.set_tokens_out(ctx.tokens_in)
      return []
    mlir.register_lowering(effect_p, _effect_lowering)
    # Wait for pending effects and reset runtime tokens so each test starts
    # from a clean token state.
    jax.effects_barrier()
    dispatch.runtime_tokens.clear()

  def tearDown(self):
    super().tearDown()
    dispatch.runtime_tokens.clear()
    mlir.register_lowering(effect_p, self._old_lowering)

  def test_can_lower_lowerable_effect(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    f.lower(2.)

  def test_cannot_lower_unlowerable_effect(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=baz_effect)
      return x + 1.
    with self.assertRaisesRegex(ValueError, "Cannot lower jaxpr with effects"):
      f.lower(2.)

  def test_should_not_pass_tokens_into_unordered_effect(self):
    def effect_lowering(ctx, *, effect):
      # bar_effect is unordered, so no input token may be supplied.
      self.assertEmpty(ctx.tokens_in)
      return []
    mlir.register_lowering(effect_p, effect_lowering)
    @jax.jit
    def f(x):
      effect_p.bind(effect=bar_effect)
      return x + 1.
    f.lower(2.)

  def test_lowering_that_doesnt_set_tokens_should_cause_error(self):
    def bad_effect_lowering(ctx, *, effect):
      # Doesn't call `ctx.set_tokens_out`!
      return []
    mlir.register_lowering(effect_p, bad_effect_lowering)
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` needs to '
                                'set `tokens_out`'):
      f.lower(2.)

  def test_lowering_that_sets_wrong_tokens_should_cause_error(self):
    def bad_effect_lowering(ctx, *, effect):
      # Sets a token keyed by `bar` instead of by `foo_effect`.
      ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get(foo_effect)))
      return []
    mlir.register_lowering(effect_p, bad_effect_lowering)
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` returns '
                                'incorrect set of output token.'):
      f.lower(2.)

  def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self):
    mlir.register_lowering(effect_p, function_effect_lowering)
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    module = f.lower(2.).compiler_ir()
    main = module.body.operations[0]
    # The first op in main is the `func.call` emitted by the lowering rule.
    call_op = main.body.blocks[0].operations[0]
    self.assertEqual(call_op.operation.name, 'func.call')
    self.assertEqual(str(call_op.attributes['callee']), '@effect')
    # The callee is the module's second op; it takes and returns the token.
    func = module.body.operations[1]
    self.assertEqual(func.name.value, "effect")
    self.assertIn('hlo.token', str(func.type.inputs[0]))
    self.assertIn('hlo.token', str(func.type.results[0]))

  def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self):
    mlir.register_lowering(effect_p, function_effect_lowering)
    @jax.jit
    def f(x):
      effect_p.bind(effect=bar_effect)
      return x + 1.
    module = f.lower(2.).compiler_ir()
    main = module.body.operations[0]
    first_op = main.body.blocks[0].operations[0]
    self.assertEqual(first_op.operation.name, "func.call")
    self.assertEqual(str(first_op.attributes["callee"]), "@effect")
    # Unordered effect: the call and the callee carry no tokens at all.
    self.assertLen(list(first_op.operands), 0)
    func = module.body.operations[1]
    self.assertEqual(func.name.value, "effect")
    self.assertLen(list(func.type.inputs), 0)
    self.assertLen(list(func.type.results), 0)

  def test_lowered_jaxpr_without_ordered_effects_takes_no_dummy_inputs(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=bar_effect)
      return x + 1.
    module = f.lower(1.).compiler_ir()
    # Only the f32 argument; no token argument is added.
    input_types = module.body.operations[0].type.inputs
    self.assertLen(list(input_types), 1)
    self.assertEqual(str(input_types[0]), 'tensor<f32>')
    # No token output either: the only result is the f32 value.
    result_types = module.body.operations[0].type.results
    self.assertLen(list(result_types), 1)
    self.assertEqual(str(result_types[0]), 'tensor<f32>')

  def test_lowered_jaxpr_with_ordered_effects_takes_token_inputs(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    module = f.lower(1.).compiler_ir()
    input_types = module.body.operations[0].type.inputs
    token_type = '!stablehlo.token'
    # First argument should be a token
    self.assertLen(list(input_types), 2)
    self.assertEqual(str(input_types[0]), token_type)
    # First output should be a token
    result_types = module.body.operations[0].type.results
    self.assertLen(list(result_types), 2)
    self.assertEqual(str(result_types[0]), token_type)

  def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_tokens(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=foo2_effect)
      return x + 1.
    module = f.lower(1.).compiler_ir()
    input_types = module.body.operations[0].type.inputs
    token_type = '!stablehlo.token'
    # First two arguments should be token values
    self.assertLen(list(input_types), 3)
    self.assertEqual(str(input_types[0]), token_type)
    self.assertEqual(str(input_types[1]), token_type)
    # First two outputs should be token values
    result_types = module.body.operations[0].type.results
    self.assertLen(list(result_types), 3)
    self.assertEqual(str(result_types[0]), token_type)
    self.assertEqual(str(result_types[1]), token_type)

  def test_can_lower_and_run_jaxpr_with_ordered_effects(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    self.assertEqual(f(2.), 3.)

  def test_can_lower_and_run_jaxpr_with_unordered_effects(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=bar_effect)
      return x + 1.
    self.assertEqual(f(2.), 3.)

  # NOTE(review): despite its name, this test checks that jit-of-pmap with an
  # unordered effect does NOT crash.
  def test_cant_jit_and_pmap_function_with_unordered_effects(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")
    @jax.jit
    @jax.pmap
    def f(x):
      effect_p.bind(effect=bar_effect)
      return x + 1
    with jtu.ignore_warning():
      f(jnp.arange(jax.device_count()))  # doesn't crash

  def test_cant_jit_and_pmap_function_with_ordered_effects(self):
    @jax.jit
    @jax.pmap
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1
    with self.assertRaisesRegex(
        ValueError,
        r"Ordered effects not supported for map primitives: \[foo\]"):
      f(jnp.arange(jax.device_count()))

  def test_runtime_tokens_should_update_after_running_effectful_function(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    self.assertNotIn(foo_effect, dispatch.runtime_tokens.current_tokens)
    f(2.)
    prev_token = dispatch.runtime_tokens.current_tokens[foo_effect]
    f(2.)
    curr_token = dispatch.runtime_tokens.current_tokens[foo_effect]
    # Each run replaces the stored token with a new object.
    self.assertIsNot(prev_token, curr_token)

  def test_can_lower_multiple_effects(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect=foo_effect)
      effect_p.bind(effect=foo2_effect)
      return x + 1.
    @jax.jit
    def g(x):
      effect_p.bind(effect=foo_effect)
      return x + 1.
    self.assertNotIn(foo_effect, dispatch.runtime_tokens.current_tokens)
    self.assertNotIn(foo2_effect, dispatch.runtime_tokens.current_tokens)
    f(2.)
    foo_token = dispatch.runtime_tokens.current_tokens[foo_effect]
    foo2_token = dispatch.runtime_tokens.current_tokens[foo2_effect]
    f(2.)
    # f uses both effects, so both tokens are replaced.
    self.assertIsNot(foo_token, dispatch.runtime_tokens.current_tokens[foo_effect])
    self.assertIsNot(foo2_token, dispatch.runtime_tokens.current_tokens[foo2_effect])
    foo_token = dispatch.runtime_tokens.current_tokens[foo_effect]
    foo2_token = dispatch.runtime_tokens.current_tokens[foo2_effect]
    g(2.)
    # g only uses foo_effect, so only that token changes.
    self.assertIsNot(foo_token, dispatch.runtime_tokens.current_tokens[foo_effect])
    self.assertIs(foo2_token, dispatch.runtime_tokens.current_tokens[foo2_effect])
class EffectOrderingTest(jtu.JaxTestCase):
  """Runs `callback_p` under jit and checks the ordering of its side effects."""

  def test_can_execute_python_callback(self):
    log = []
    def log_value(x):
      log.append(x)
      return ()
    @jax.jit
    def f(x):
      return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
    f(2.)
    # Wait for pending effectful computations before inspecting `log`.
    jax.effects_barrier()
    self.assertListEqual(log, [2.])
    f(3.)
    jax.effects_barrier()
    self.assertListEqual(log, [2., 3.])

  # TODO(b/307211483): Investigate failure
  @jtu.skip_on_devices("tpu")
  def test_ordered_effect_remains_ordered_across_multiple_devices(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")
    log = []
    def log_value(x):
      log.append(x)
      return ()
    @jax.jit
    def f(x):
      # Expensive computation
      x = x.dot(x)
      x = jnp.log(x.sum())
      return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
    @jax.jit
    def g(x):
      return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
    x = jax.device_put(jnp.ones((500, 500)), jax.devices()[0])
    y = jax.device_put(3., jax.devices()[1])
    for _ in range(3):
      f(x)
      g(y)
    jax.effects_barrier()
    # f logs log(sum(ones @ ones)) = log(500 * 500 * 500) = log(1.25e8).
    # log_effect is ordered, so the slower f on device 0 must still log
    # before the g that was dispatched after it on device 1.
    f_, g_ = float(jnp.log(1.25e8)), 3.
    expected_log = [f_, g_, f_, g_, f_, g_]
    self.assertListEqual(log, expected_log)

  def test_different_threads_get_different_tokens(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")
    tokens = []
    # `Thread.join` does not propagate exceptions raised in the worker, so
    # record them and re-raise in the main thread. Otherwise a failure inside
    # a thread would only surface as a confusing unpacking error below.
    errors = []
    def _noop(_):
      return ()
    def f(x):
      # Runs in a thread.
      try:
        res = jax.jit(
            lambda x: callback_p.bind(
                x, callback=_noop, effect=log_effect, out_avals=[])
        )(x)
        tokens.append(dispatch.runtime_tokens.current_tokens[log_effect])
        return res
      except Exception as e:  # pylint: disable=broad-except
        errors.append(e)
        return None
    t1 = threading.Thread(target=lambda: f(2.))
    t2 = threading.Thread(target=lambda: f(3.))
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    if errors:
      raise errors[0]
    self.assertLen(tokens, 2)
    token1, token2 = tokens
    self.assertIsNot(token1, token2)
class ParallelEffectsTest(jtu.JaxTestCase):
  """Checks which effects `jax.pmap` accepts or rejects."""

  def test_cannot_pmap_unlowerable_effect(self):
    def f(x):
      # abc is not lowerable
      effect_p.bind(effect='abc')
      return x
    with self.assertRaisesRegex(
        ValueError, "Cannot lower jaxpr with effects: {'abc'}"):
      jax.pmap(f)(jnp.arange(jax.local_device_count()))

  def test_cannot_pmap_ordered_effect(self):
    def f(x):
      # foo is lowerable and ordered
      effect_p.bind(effect=foo_effect)
      return x
    with self.assertRaisesRegex(
        ValueError, "Ordered effects not supported in `pmap`."):
      jax.pmap(f)(jnp.arange(jax.local_device_count()))

  def test_can_pmap_unordered_effect(self):
    def f(x):
      # bar is lowerable and unordered
      effect_p.bind(effect=bar_effect)
      return x
    jax.pmap(f)(jnp.arange(jax.local_device_count()))

  def test_can_pmap_unordered_callback(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")
    # A set, because unordered callbacks from the two shards may run in any
    # order.
    log = set()
    def log_value(x):
      log.add(int(x))
      return ()
    @jax.pmap
    def f(x):
      callback_p.bind(
          x, callback=log_value, effect=unordered_log_effect, out_avals=[])
      return x + 1
    f(jnp.arange(2)).block_until_ready()
    jax.effects_barrier()
    self.assertSetEqual({0, 1}, log)
class ControlFlowEffectsTest(jtu.JaxTestCase):
  """Checks which effects are allowed inside cond, while_loop and scan.

  `WhileEffect` (while_effect, while1_effect, while2_effect) is registered in
  `control_flow_allowed_effects` at module level; `foo_effect` is not.
  """

  def test_effects_disallowed_in_cond(self):
    def f1(x):
      def true_fun(x):
        effect_p.bind(effect=foo_effect)
        return x
      def false_fun(x):
        return x
      return lax.cond(True, true_fun, false_fun, x)
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(f1)(2.)

  def test_allowed_effect_in_cond(self):
    def f(x):
      def true_fun(x):
        effect_p.bind(effect=while_effect)
        return x
      def false_fun(x):
        effect_p.bind(effect=while_effect)
        return x
      return lax.cond(x, true_fun, false_fun, x)
    f(2)

  def test_allowed_effect_in_cond_jvp(self):
    def f(x):
      def true_fun(x):
        effect_p.bind(effect=while_effect)
        return x
      def false_fun(x):
        effect_p.bind(effect=while_effect)
        return x
      return lax.cond(True, true_fun, false_fun, x)
    # test primal side gets effect
    primal_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(f, x)[0])(2.)
    self.assertEqual(primal_jaxpr.effects, {while_effect})
    # and tangent side does not
    _, f_lin = jax.linearize(f, 2.)
    # NOTE(review): reaches into the internals of the linearized function to
    # get its jaxpr; this may need updating if `jax.linearize` changes.
    lin_jaxpr = f_lin.func.fun.args[0]
    self.assertEqual(lin_jaxpr.effects, set())

  def test_allowed_effect_in_cond_jvp2(self):
    @jax.custom_jvp
    def print_tangents(x):
      return x
    @print_tangents.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      t, = tangents
      # TODO(mattjj,sharadmv): don't require data dependence for jax.linearize!
      # effect_p.bind(t, effect=while_effect)
      t, = effect_p.bind(t, effect=while_effect)  # data dep only on tangents
      return x, t
    def f(x):
      def true_fun(x):
        return print_tangents(x)
      def false_fun(x):
        return print_tangents(x)
      return lax.cond(True, true_fun, false_fun, x)
    # test primal side does not get effect
    primal_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(f, x)[0])(2.)
    self.assertEqual(primal_jaxpr.effects, set())
    # and tangent side does
    _, f_lin = jax.linearize(f, 2.)
    lin_jaxpr = f_lin.func.fun.args[0]
    self.assertEqual(lin_jaxpr.effects, {while_effect})

  def test_allowed_ordered_effect_in_cond(self):
    def f(x):
      def true_fun(x):
        effect_p.bind(effect=while1_effect)
        return x
      def false_fun(x):
        effect_p.bind(effect=while1_effect)
        return x
      return lax.cond(x, true_fun, false_fun, x)
    f(2)

  def test_multiple_allowed_ordered_effect_in_cond(self):
    def f(x):
      def true_fun(x):
        effect_p.bind(effect=while1_effect)
        effect_p.bind(effect=while2_effect)
        return x
      def false_fun(x):
        effect_p.bind(effect=while1_effect)
        effect_p.bind(effect=while2_effect)
        return x
      return lax.cond(x, true_fun, false_fun, x)
    f(2)

    # A disallowed effect in just one branch is still rejected.
    def f2(x):
      def true_fun(x):
        return x
      def false_fun(x):
        effect_p.bind(effect=foo_effect)
        return x
      return lax.cond(True, true_fun, false_fun, x)
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(f2)(2.)

  def test_allowed_effect_in_while_body(self):
    def f(x):
      def cond_fun(x):
        return False
      def body_fun(x):
        effect_p.bind(effect=while_effect)
        return x
      return lax.while_loop(cond_fun, body_fun, x)
    f(2)

  # NOTE(review): despite its name, this test puts the effect in a
  # while_loop's `cond_fun`, not in a `lax.cond` branch.
  def test_allowed_effect_in_cond_body(self):
    def f(x):
      def cond_fun(x):
        effect_p.bind(effect=while_effect)
        return False
      def body_fun(x):
        return x
      return lax.while_loop(cond_fun, body_fun, x)
    f(2)

  def test_allowed_ordered_effect_in_while_body(self):
    def f(x):
      def cond_fun(x):
        return False
      def body_fun(x):
        effect_p.bind(effect=while1_effect)
        return x
      return lax.while_loop(cond_fun, body_fun, x)
    f(2)

  def test_multiple_allowed_ordered_effect_in_while_body(self):
    def f(x):
      def cond_fun(x):
        return False
      def body_fun(x):
        effect_p.bind(effect=while1_effect)
        effect_p.bind(effect=while2_effect)
        return x
      return lax.while_loop(cond_fun, body_fun, x)
    f(2)

  def test_effects_disallowed_in_while(self):
    # Disallowed in the loop condition...
    def f1(x):
      def cond_fun(x):
        effect_p.bind(effect=foo_effect)
        return False
      def body_fun(x):
        return x
      return lax.while_loop(cond_fun, body_fun, x)
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(f1)(2.)

    # ...and in the loop body.
    def f2(x):
      def cond_fun(x):
        return False
      def body_fun(x):
        effect_p.bind(effect=foo_effect)
        return x
      return lax.while_loop(cond_fun, body_fun, x)
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(f2)(2.)

  def test_allowed_effect_in_scan(self):
    def f(x):
      def body_fun(carry, x):
        effect_p.bind(effect=while_effect)
        return carry, x
      return lax.scan(body_fun, x, jnp.arange(5))
    f(2)

  def test_allowed_ordered_effect_in_scan(self):
    def f(x):
      def body_fun(carry, x):
        effect_p.bind(effect=while1_effect)
        return carry, x
      return lax.scan(body_fun, x, jnp.arange(5))
    f(2)

  def test_multiple_allowed_ordered_effect_in_scan(self):
    def f(x):
      def body_fun(carry, x):
        effect_p.bind(effect=while1_effect)
        effect_p.bind(effect=while2_effect)
        return carry, x
      return lax.scan(body_fun, x, jnp.arange(5))
    f(2)

  def test_effects_disallowed_in_scan(self):
    def f(x):
      def body(carry, x):
        effect_p.bind(effect=foo_effect)
        return carry, x
      return lax.scan(body, x, jnp.arange(4))
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(f)(2.)
# `input_effect_p` is a test-only primitive that takes arbitrary operands,
# produces no outputs, and reports an `InputEffect` for operand number `index`.
input_effect_p = core.Primitive('input_effect')
input_effect_p.multiple_results = True
input_effect = input_effect_p.bind

def _input_effect_abstract_eval(*avals, index):
  # `index` is a position among this primitive's own operands. The tests
  # below check how JAX remaps it to the enclosing jaxpr's inputs.
  return [], {InputEffect(index)}
input_effect_p.def_effectful_abstract_eval(_input_effect_abstract_eval)
class JaxprInputEffectTest(jtu.JaxTestCase):
  """Checks that `InputEffect` indices are remapped correctly to the enclosing
  jaxpr's inputs through jit, partial eval, DCE and control flow."""

  def test_simple_jaxpr_input_effect(self):
    def f(x, y):
      input_effect(x, y, index=0)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(0), jaxpr.effects)

  def test_jaxpr_input_effect_is_tracked_by_index_properly(self):
    # Operand 0 of the primitive is `y`, which is input 1 of `f`.
    def f(x, y):
      input_effect(y, x, index=0)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(1), jaxpr.effects)
    # Operand 1 of the primitive is `x`, which is input 0 of `f`.
    def f(x, y):
      input_effect(y, x, index=1)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(0), jaxpr.effects)

  def test_jaxpr_input_effect_is_tracked_through_a_jit(self):
    @jax.jit
    def f(x, y):
      input_effect(y, x, index=0)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(1), jaxpr.effects)
    # Nested jit with swapped argument order.
    @jax.jit
    def f(x, y):
      return jax.jit(lambda a, b: input_effect(b, a, index=1))(x, y)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(0), jaxpr.effects)
    # Effect on a closed-over constant; per the assertion it is tracked as
    # input 0 of the outer jaxpr.
    x = np.array([0, 1])
    @jax.jit
    def f(y):
      return input_effect(x, y, index=0)
    jaxpr = jax.make_jaxpr(f)(0)
    self.assertIn(InputEffect(0), jaxpr.effects)

  def test_jaxpr_input_effect_is_tracked_through_partial_eval_custom(self):
    def f(_, y):
      input_effect(y, index=0)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(1), jaxpr.effects)
    # With `y` flagged True in the second argument, the effect lands in
    # `jaxpr_right`, re-indexed to 0.
    jaxpr_left, jaxpr_right, _, _, _ = pe.partial_eval_jaxpr_custom(
        jaxpr.jaxpr, [False, True], in_inst=[False, True],
        ensure_out_unknowns=[], ensure_out_inst=[],
        saveable=lambda *_, **__: True)
    self.assertEmpty(jaxpr_left.effects)
    self.assertSetEqual({InputEffect(0)}, jaxpr_right.effects)
    # Swapping the flags moves the effect to `jaxpr_left`.
    jaxpr_left, jaxpr_right, _, _, _ = pe.partial_eval_jaxpr_custom(
        jaxpr.jaxpr, [True, False], in_inst=[True, False],
        ensure_out_unknowns=[], ensure_out_inst=[],
        saveable=lambda *_, **__: True)
    self.assertEmpty(jaxpr_right.effects)
    self.assertSetEqual({InputEffect(0)}, jaxpr_left.effects)

  def test_jaxpr_input_effect_is_tracked_through_dce(self):
    def f(_, y):
      input_effect(y, index=0)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(1), jaxpr.effects)
    # After DCE the effect index becomes 0 (presumably because the unused
    # input `_` is pruned).
    jaxpr2, _ = pe.dce_jaxpr(jaxpr.jaxpr, [], instantiate=[False, False])
    self.assertIn(InputEffect(0), jaxpr2.effects)

    @jax.jit
    def f(_, y):
      input_effect(y, index=0)
    jaxpr = jax.make_jaxpr(f)(0, 1)
    self.assertIn(InputEffect(1), jaxpr.effects)
    jaxpr2, _ = pe.dce_jaxpr(jaxpr.jaxpr, [], instantiate=[False, False])
    self.assertIn(InputEffect(0), jaxpr2.effects)

    # Effect on a closed-over constant survives DCE.
    x = np.ones(2, np.int32)
    def f(_):
      input_effect(x, index=0)
    jaxpr = jax.make_jaxpr(f)(0)
    self.assertIn(InputEffect(0), jaxpr.effects)
    jaxpr3, _ = pe.dce_jaxpr(jaxpr.jaxpr, [], instantiate=[False])
    self.assertIn(InputEffect(0), jaxpr3.effects)

  def test_jaxpr_input_effect_is_tracked_through_while_loop(self):
    y = np.ones(2)
    def make_fun(index):
      def f(x):
        def body(y):
          input_effect(x, y, index=index)
          return y
        # The loop is only traced by make_jaxpr below, never executed.
        lax.while_loop(lambda _: True, body, y)
      return f
    jaxpr = jax.make_jaxpr(make_fun(0))(0)
    self.assertIn(InputEffect(1), jaxpr.effects)
    jaxpr = jax.make_jaxpr(make_fun(1))(0)
    self.assertIn(InputEffect(0), jaxpr.effects)

    # Same, with `x` also used by the loop condition.
    def f(x):
      def body(y):
        input_effect(x, y, index=1)
        return y
      lax.while_loop(lambda _: (x > 0).all(), body, y)
    jaxpr = jax.make_jaxpr(f)(0)
    self.assertIn(InputEffect(0), jaxpr.effects)

  def test_jaxpr_input_effect_is_tracked_through_scan(self):
    c = np.ones(2)
    def make_fun(index):
      def f(xs, z):
        def body(z, x):
          input_effect(x, z, c, index=index)
          return z, x
        lax.scan(body, z, xs)
      return f
    # Per the assertions: operand x (from xs) -> 1, z -> 2, constant c -> 0.
    jaxpr = jax.make_jaxpr(make_fun(0))(jnp.arange(8), 0)
    self.assertIn(InputEffect(1), jaxpr.effects)
    jaxpr = jax.make_jaxpr(make_fun(1))(jnp.arange(8), 0)
    self.assertIn(InputEffect(2), jaxpr.effects)
    jaxpr = jax.make_jaxpr(make_fun(2))(jnp.arange(8), 0)
    self.assertIn(InputEffect(0), jaxpr.effects)

  def test_jaxpr_input_effect_is_tracked_through_scan_with_dce(self):
    c = np.ones(2)
    def make_fun(index):
      def f(xs, z):
        def body(z, x):
          input_effect(x, z, c, index=index)
          return z, x
        lax.scan(body, z, xs)
      return f
    # Same index mapping as the previous test, after DCE with no outputs used.
    jaxpr = jax.make_jaxpr(make_fun(0))(jnp.arange(8), 0)
    jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, [])
    self.assertIn(InputEffect(1), jaxpr.effects)
    jaxpr = jax.make_jaxpr(make_fun(1))(jnp.arange(8), 0)
    jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, [])
    self.assertIn(InputEffect(2), jaxpr.effects)
    jaxpr = jax.make_jaxpr(make_fun(2))(jnp.arange(8), 0)
    jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, [])
    self.assertIn(InputEffect(0), jaxpr.effects)

  def test_jaxpr_input_effect_is_tracked_through_cond(self):
    c = np.ones(2)
    def make_fun(index):
      def f(x):
        def true_fun(x):
          input_effect(x, c, index=index)
          return x
        def false_fun(x):
          return x
        lax.cond(False, true_fun, false_fun, x)
      return f
    # [c, pred, x]
    # NOTE(review): the comment above is unclear. Per the assertions below, in
    # the outer jaxpr the constant `c` is tracked as input 0 and `x` as input 1.
    jaxpr = jax.make_jaxpr(make_fun(0))(0)
    self.assertIn(InputEffect(1), jaxpr.effects)
    jaxpr = jax.make_jaxpr(make_fun(1))(0)
    self.assertIn(InputEffect(0), jaxpr.effects)
# Run the tests with JAX's test loader when executed as a script.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())