Add tests for higher order primitives

This commit is contained in:
Sharad Vikram 2022-04-12 13:32:43 -07:00
parent 6ba9fb699d
commit 4392b07022
9 changed files with 172 additions and 29 deletions

View File

@ -276,10 +276,12 @@ def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy):
del prevent_cse, differentiated, policy # Unused.
return core.eval_jaxpr(jaxpr, (), *args)
@remat_p.def_abstract_eval
@remat_p.def_effectful_abstract_eval
def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
del args, prevent_cse, differentiated, policy # Unused.
return [v.aval for v in jaxpr.outvars]
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `remat`.')
return [v.aval for v in jaxpr.outvars], jaxpr.effects
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
assert not jaxpr.constvars

View File

@ -329,12 +329,14 @@ def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):
def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
del args, params
return fun_jaxpr.out_avals
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_jvp`.')
return fun_jaxpr.out_avals, fun_jaxpr.effects
custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
custom_jvp_call_jaxpr_p.def_effectful_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p
mlir.register_lowering(custom_jvp_call_jaxpr_p, mlir.lower_fun(
@ -694,12 +696,14 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
return core.jaxpr_as_fun(fun_jaxpr)(*args)
def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
return fun_jaxpr.out_avals
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_vjp`.')
return fun_jaxpr.out_avals, fun_jaxpr.effects
custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr')
custom_vjp_call_jaxpr_p.multiple_results = True
custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p
mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(

View File

@ -206,6 +206,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
"for jit in {elapsed_time} sec"):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"), which_explicit)
if jaxpr.effects:
raise NotImplementedError('Lowering jaxprs with effects not supported.')
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
# TODO(mattjj): handle argument pruning w/ dynamic shapes

View File

@ -314,23 +314,25 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
new_init_val, = tree_unflatten(in_tree, new_init_vals)
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
cond_jaxpr, cond_consts, body_consts, body_tree = rest
joined_effects = core.join_effects(body_jaxpr.effects, cond_jaxpr.effects)
if joined_effects:
raise NotImplementedError('Effects not supported in `while`.')
in_tree_children = in_tree.children()
assert len(in_tree_children) == 1
_check_tree_and_avals("body_fun output and input",
body_tree, body_jaxpr.out_avals,
in_tree_children[0], init_avals)
if cond_jaxpr.effects or body_jaxpr.effects:
raise NotImplementedError('Effects not supported in `while`.')
outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
return tree_unflatten(body_tree, outs)
def _while_loop_abstract_eval(*args, body_jaxpr, **kwargs):
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
del args, kwargs
return _map(raise_to_shaped, body_jaxpr.out_avals), core.no_effects
if cond_jaxpr.effects or body_jaxpr.effects:
raise NotImplementedError('Effects not supported in `while_loop`.')
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
body_jaxpr, cond_nconsts, body_nconsts):
@ -796,11 +798,12 @@ def switch(index, branches: Sequence[Callable], *operands,
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
out_trees[0], jaxprs[0].out_avals,
out_tree, jaxpr.out_avals)
if any(b.effects for b in jaxprs):
raise NotImplementedError('Effects not supported in `switch`.')
linear = (False,) * (len(consts) + len(ops))
out = cond_p.bind(
@ -879,15 +882,14 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
if joined_effects:
raise NotImplementedError('Effects not supported in `cond`.')
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees
_check_tree_and_avals("true_fun and false_fun output",
out_tree, true_jaxpr.out_avals,
false_out_tree, false_jaxpr.out_avals)
if any(b.effects for b in jaxprs):
raise NotImplementedError('Effects not supported in `cond`.')
index = lax.convert_element_type(pred, np.int32)
@ -934,7 +936,10 @@ def _cond_with_per_branch_args(pred,
(true_operand, false_operand))
def _cond_abstract_eval(*args, branches, **kwargs):
return _map(raise_to_shaped, branches[0].out_avals), core.no_effects
if any(b.effects for b in branches):
raise NotImplementedError('Effects not supported in `cond`.')
joined_effects = core.join_effects(*(b.effects for b in branches))
return _map(raise_to_shaped, branches[0].out_avals), joined_effects
def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
linear):
@ -1128,8 +1133,6 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
linear_2 = (False,) * num_res + linear
params = dict(branches=branches_2, linear=linear_2)
if any((branch.effects for branch in branches_2)):
raise NotImplementedError('Effects not supported in `cond`.')
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
@ -1282,6 +1285,8 @@ def _cond_typecheck(*avals, branches, linear):
jaxpr0 = branches[0]
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
if any(b.effects for b in branches):
raise NotImplementedError('Effects not supported in `cond`.')
for i, jaxpr in enumerate(branches[1:]):
if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
@ -1318,7 +1323,8 @@ def _cond_typecheck(*avals, branches, linear):
raise core.JaxprTypeError(
f'cond branches must have matching effect types: '
f'{[b.effects for b in branches]}')
return None, core.no_effects
joined_effects = core.join_effects(*(b.effects for b in branches))
return None, joined_effects
def cond_bind(*args, branches, linear):
if config.jax_enable_checks:
@ -1524,13 +1530,13 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
new_init = tree_unflatten(init_tree, new_init_flat)
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
in_flat, jaxpr, consts, out_tree, out_tree_children = rest
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `scan`.')
_check_tree_and_avals("scan carry output and input",
# Extract the subtree and avals for the first element of the return tuple
out_tree_children[0], carry_avals_out,
init_tree, carry_avals)
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `scan`.')
out = scan_p.bind(*consts, *in_flat,
reverse=reverse, length=length, jaxpr=jaxpr,
@ -1723,6 +1729,8 @@ def _prepend_dim_to_aval(sz, aval):
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `scan`.')
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
return carry_avals + ys_avals, core.no_effects

View File

@ -2260,10 +2260,10 @@ def _check_jaxpr(
else:
out_avals, effects = check_eqn(prim, in_avals, eqn.params)
if eqn.effects != effects:
print(eqn.effects, effects)
raise JaxprTypeError("Inferred effects do not match equation effects.")
if not eqn.effects.issubset(jaxpr.effects):
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects.")
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects. "
f"Equation effects: {eqn.effects}. Jaxpr effects: {jaxpr.effects}")
map(write, eqn.outvars, out_avals)
except JaxprTypeError as e:
ctx, settings = ctx_factory()

View File

@ -1051,6 +1051,8 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
with core.extend_axis_env_nd(global_axis_sizes.items()):
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, mapped_in_avals)
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `xmap`.')
out_axes = params['out_axes_thunk']()
if params['spmd_out_axes_thunk'] is not None:
spmd_out_axes = params['spmd_out_axes_thunk']()

View File

@ -665,9 +665,11 @@ def _pjit_lower(
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
out_positional_semantics, **_):
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `pjit`.')
return global_to_local(out_positional_semantics, resource_env.physical_mesh,
jaxpr.out_avals, out_axis_resources)
pjit_p.def_abstract_eval(_pjit_abstract_eval)
jaxpr.out_avals, out_axis_resources), jaxpr.effects
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,

View File

@ -1533,6 +1533,8 @@ class DynamicJaxprTrace(core.Trace):
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
if jaxpr.effects:
raise NotImplementedError('Effects not supported for call primitives.')
tracers = [*im_tracers, *explicit_tracers]
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *tracers)
@ -1552,7 +1554,7 @@ class DynamicJaxprTrace(core.Trace):
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
call_primitive, new_params,
new_params['call_jaxpr'].effects, source_info)
self.frame.eqns.append(eqn)
self.frame.add_eqn(eqn)
return out_tracers
def post_process_call(self, call_primitive, out_tracers, params):
@ -1568,6 +1570,8 @@ class DynamicJaxprTrace(core.Trace):
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals)
if jaxpr.effects:
raise NotImplementedError('Effects not supported for map primitives.')
out_axes = params['out_axes_thunk']()
out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a)
if out_axis is not None else a
@ -1586,7 +1590,7 @@ class DynamicJaxprTrace(core.Trace):
new_params = update_params(new_params, [True] * len(tracers), len(consts))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
new_params, new_params['call_jaxpr'].effects, source_info)
self.frame.eqns.append(eqn)
self.frame.add_eqn(eqn)
return out_tracers
def post_process_map(self, map_primitive, out_tracers, params):
@ -1596,6 +1600,8 @@ class DynamicJaxprTrace(core.Trace):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_jvp`.')
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)
jvp_jaxpr_thunk = _memoize(
@ -1610,7 +1616,7 @@ class DynamicJaxprTrace(core.Trace):
num_consts=len(consts)),
fun_jaxpr.effects,
source_info_util.current())
self.frame.eqns.append(eqn)
self.frame.add_eqn(eqn)
return out_tracers
def post_process_custom_jvp_call(self, out_tracers, _):
@ -1620,6 +1626,8 @@ class DynamicJaxprTrace(core.Trace):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_vjp`.')
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)
fwd_jaxpr_thunk = _memoize(
@ -1635,7 +1643,7 @@ class DynamicJaxprTrace(core.Trace):
bwd=bwd, out_trees=out_trees),
fun_jaxpr.effects,
source_info_util.current())
self.frame.eqns.append(eqn)
self.frame.add_eqn(eqn)
return out_tracers
def post_process_custom_vjp_call(self, out_tracers, _):

View File

@ -11,14 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax import ad_checkpoint
from jax import core
from jax import lax
from jax import linear_util as lu
from jax.experimental import maps
from jax.experimental import pjit
from jax.config import config
from jax._src import test_util as jtu
import numpy as np
config.parse_flags_with_absl()
@ -70,6 +77,114 @@ class JaxprEffectsTest(jtu.JaxTestCase):
'Equation effects are not subset of Jaxpr effects.'):
core.check_jaxpr(jaxpr)
class HigherOrderPrimitiveTest(jtu.JaxTestCase):
def test_core_call_primitive_inherits_effects(self):
def f(x):
@lu.wrap_init
def f_(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return [x]
return core.call(f_, x)[0]
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
def test_xla_call_primitive_inherits_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{flavor}", flavor=flavor)
for flavor in ["old", "new"]))
def test_remat_call_primitive_inherits_effects(self, flavor):
remat = jax.remat if flavor == "old" else ad_checkpoint.checkpoint
@remat
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
def test_custom_jvp_primitive_inherits_effects(self):
@jax.custom_jvp
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
f.defjvp(lambda x, t: (x, t))
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
def test_custom_vjp_primitive_inherits_effects(self):
@jax.custom_vjp
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
f.defvjp(
fwd=lambda x: (x, ()),
bwd=lambda _, g: g)
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
def test_pmap_inherits_effects(self):
@jax.pmap
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
def test_xmap_inherits_effects(self):
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
def test_pjit_inherits_effects(self):
if jax.default_backend() not in {'gpu', 'tpu'}:
raise unittest.SkipTest("pjit only supports GPU and TPU backends")
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'),
out_axis_resources=pjit.PartitionSpec('x'))
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
with maps.Mesh(np.array(jax.devices()), ['x']):
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_cannot_lower_jaxpr_with_effects_in_hop(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
return x + 1.
with self.assertRaisesRegex(NotImplementedError, 'Lowering jaxprs with '
'effects not supported'):
f(2.)
class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_effects_disallowed_in_cond(self):