mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Add tests for higher order primitives
This commit is contained in:
parent
6ba9fb699d
commit
4392b07022
@ -276,10 +276,12 @@ def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy):
|
||||
del prevent_cse, differentiated, policy # Unused.
|
||||
return core.eval_jaxpr(jaxpr, (), *args)
|
||||
|
||||
@remat_p.def_abstract_eval
|
||||
@remat_p.def_effectful_abstract_eval
|
||||
def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
|
||||
del args, prevent_cse, differentiated, policy # Unused.
|
||||
return [v.aval for v in jaxpr.outvars]
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `remat`.')
|
||||
return [v.aval for v in jaxpr.outvars], jaxpr.effects
|
||||
|
||||
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
|
||||
assert not jaxpr.constvars
|
||||
|
@ -329,12 +329,14 @@ def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):
|
||||
|
||||
def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
|
||||
del args, params
|
||||
return fun_jaxpr.out_avals
|
||||
if fun_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `custom_jvp`.')
|
||||
return fun_jaxpr.out_avals, fun_jaxpr.effects
|
||||
|
||||
custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
|
||||
custom_jvp_call_jaxpr_p.multiple_results = True
|
||||
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
|
||||
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
|
||||
custom_jvp_call_jaxpr_p.def_effectful_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
|
||||
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p
|
||||
|
||||
mlir.register_lowering(custom_jvp_call_jaxpr_p, mlir.lower_fun(
|
||||
@ -694,12 +696,14 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
|
||||
return core.jaxpr_as_fun(fun_jaxpr)(*args)
|
||||
|
||||
def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
|
||||
return fun_jaxpr.out_avals
|
||||
if fun_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `custom_vjp`.')
|
||||
return fun_jaxpr.out_avals, fun_jaxpr.effects
|
||||
|
||||
custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr')
|
||||
custom_vjp_call_jaxpr_p.multiple_results = True
|
||||
custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
|
||||
custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
|
||||
custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
|
||||
CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p
|
||||
|
||||
mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
|
||||
|
@ -206,6 +206,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
"for jit in {elapsed_time} sec"):
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun, abstract_args, pe.debug_info_final(fun, "jit"), which_explicit)
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Lowering jaxprs with effects not supported.')
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
raise UnexpectedTracerError("Encountered an unexpected tracer.")
|
||||
# TODO(mattjj): handle argument pruning w/ dynamic shapes
|
||||
|
@ -314,23 +314,25 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
new_init_val, = tree_unflatten(in_tree, new_init_vals)
|
||||
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
|
||||
cond_jaxpr, cond_consts, body_consts, body_tree = rest
|
||||
joined_effects = core.join_effects(body_jaxpr.effects, cond_jaxpr.effects)
|
||||
if joined_effects:
|
||||
raise NotImplementedError('Effects not supported in `while`.')
|
||||
|
||||
in_tree_children = in_tree.children()
|
||||
assert len(in_tree_children) == 1
|
||||
_check_tree_and_avals("body_fun output and input",
|
||||
body_tree, body_jaxpr.out_avals,
|
||||
in_tree_children[0], init_avals)
|
||||
if cond_jaxpr.effects or body_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `while`.')
|
||||
outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
|
||||
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
|
||||
return tree_unflatten(body_tree, outs)
|
||||
|
||||
def _while_loop_abstract_eval(*args, body_jaxpr, **kwargs):
|
||||
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
|
||||
del args, kwargs
|
||||
return _map(raise_to_shaped, body_jaxpr.out_avals), core.no_effects
|
||||
if cond_jaxpr.effects or body_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `while_loop`.')
|
||||
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
|
||||
return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
|
||||
|
||||
def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
|
||||
body_jaxpr, cond_nconsts, body_nconsts):
|
||||
@ -796,11 +798,12 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
branches, ops_tree, ops_avals, primitive_name='switch')
|
||||
|
||||
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
|
||||
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
|
||||
out_trees[0], jaxprs[0].out_avals,
|
||||
out_tree, jaxpr.out_avals)
|
||||
if any(b.effects for b in jaxprs):
|
||||
raise NotImplementedError('Effects not supported in `switch`.')
|
||||
|
||||
linear = (False,) * (len(consts) + len(ops))
|
||||
out = cond_p.bind(
|
||||
@ -879,15 +882,14 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
|
||||
if joined_effects:
|
||||
raise NotImplementedError('Effects not supported in `cond`.')
|
||||
true_jaxpr, false_jaxpr = jaxprs
|
||||
out_tree, false_out_tree = out_trees
|
||||
|
||||
_check_tree_and_avals("true_fun and false_fun output",
|
||||
out_tree, true_jaxpr.out_avals,
|
||||
false_out_tree, false_jaxpr.out_avals)
|
||||
if any(b.effects for b in jaxprs):
|
||||
raise NotImplementedError('Effects not supported in `cond`.')
|
||||
|
||||
index = lax.convert_element_type(pred, np.int32)
|
||||
|
||||
@ -934,7 +936,10 @@ def _cond_with_per_branch_args(pred,
|
||||
(true_operand, false_operand))
|
||||
|
||||
def _cond_abstract_eval(*args, branches, **kwargs):
|
||||
return _map(raise_to_shaped, branches[0].out_avals), core.no_effects
|
||||
if any(b.effects for b in branches):
|
||||
raise NotImplementedError('Effects not supported in `cond`.')
|
||||
joined_effects = core.join_effects(*(b.effects for b in branches))
|
||||
return _map(raise_to_shaped, branches[0].out_avals), joined_effects
|
||||
|
||||
def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
|
||||
linear):
|
||||
@ -1128,8 +1133,6 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
|
||||
|
||||
linear_2 = (False,) * num_res + linear
|
||||
params = dict(branches=branches_2, linear=linear_2)
|
||||
if any((branch.effects for branch in branches_2)):
|
||||
raise NotImplementedError('Effects not supported in `cond`.')
|
||||
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = pe.new_eqn_recipe(
|
||||
@ -1282,6 +1285,8 @@ def _cond_typecheck(*avals, branches, linear):
|
||||
jaxpr0 = branches[0]
|
||||
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
|
||||
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
|
||||
if any(b.effects for b in branches):
|
||||
raise NotImplementedError('Effects not supported in `cond`.')
|
||||
|
||||
for i, jaxpr in enumerate(branches[1:]):
|
||||
if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
|
||||
@ -1318,7 +1323,8 @@ def _cond_typecheck(*avals, branches, linear):
|
||||
raise core.JaxprTypeError(
|
||||
f'cond branches must have matching effect types: '
|
||||
f'{[b.effects for b in branches]}')
|
||||
return None, core.no_effects
|
||||
joined_effects = core.join_effects(*(b.effects for b in branches))
|
||||
return None, joined_effects
|
||||
|
||||
def cond_bind(*args, branches, linear):
|
||||
if config.jax_enable_checks:
|
||||
@ -1524,13 +1530,13 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
new_init = tree_unflatten(init_tree, new_init_flat)
|
||||
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
|
||||
in_flat, jaxpr, consts, out_tree, out_tree_children = rest
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `scan`.')
|
||||
|
||||
_check_tree_and_avals("scan carry output and input",
|
||||
# Extract the subtree and avals for the first element of the return tuple
|
||||
out_tree_children[0], carry_avals_out,
|
||||
init_tree, carry_avals)
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `scan`.')
|
||||
|
||||
out = scan_p.bind(*consts, *in_flat,
|
||||
reverse=reverse, length=length, jaxpr=jaxpr,
|
||||
@ -1723,6 +1729,8 @@ def _prepend_dim_to_aval(sz, aval):
|
||||
|
||||
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
|
||||
linear, unroll):
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `scan`.')
|
||||
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
|
||||
return carry_avals + ys_avals, core.no_effects
|
||||
|
@ -2260,10 +2260,10 @@ def _check_jaxpr(
|
||||
else:
|
||||
out_avals, effects = check_eqn(prim, in_avals, eqn.params)
|
||||
if eqn.effects != effects:
|
||||
print(eqn.effects, effects)
|
||||
raise JaxprTypeError("Inferred effects do not match equation effects.")
|
||||
if not eqn.effects.issubset(jaxpr.effects):
|
||||
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects.")
|
||||
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects. "
|
||||
f"Equation effects: {eqn.effects}. Jaxpr effects: {jaxpr.effects}")
|
||||
map(write, eqn.outvars, out_avals)
|
||||
except JaxprTypeError as e:
|
||||
ctx, settings = ctx_factory()
|
||||
|
@ -1051,6 +1051,8 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
with core.extend_axis_env_nd(global_axis_sizes.items()):
|
||||
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, mapped_in_avals)
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `xmap`.')
|
||||
out_axes = params['out_axes_thunk']()
|
||||
if params['spmd_out_axes_thunk'] is not None:
|
||||
spmd_out_axes = params['spmd_out_axes_thunk']()
|
||||
|
@ -665,9 +665,11 @@ def _pjit_lower(
|
||||
|
||||
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
|
||||
out_positional_semantics, **_):
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `pjit`.')
|
||||
return global_to_local(out_positional_semantics, resource_env.physical_mesh,
|
||||
jaxpr.out_avals, out_axis_resources)
|
||||
pjit_p.def_abstract_eval(_pjit_abstract_eval)
|
||||
jaxpr.out_avals, out_axis_resources), jaxpr.effects
|
||||
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
|
||||
|
||||
|
||||
def _pjit_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
|
||||
|
@ -1533,6 +1533,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
with core.new_sublevel():
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, in_avals, keep_inputs=keep_inputs)
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported for call primitives.')
|
||||
tracers = [*im_tracers, *explicit_tracers]
|
||||
if params.get('inline', False):
|
||||
return core.eval_jaxpr(jaxpr, consts, *tracers)
|
||||
@ -1552,7 +1554,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
|
||||
call_primitive, new_params,
|
||||
new_params['call_jaxpr'].effects, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
@ -1568,6 +1570,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
with core.new_sublevel():
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, reduced_in_avals)
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported for map primitives.')
|
||||
out_axes = params['out_axes_thunk']()
|
||||
out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a)
|
||||
if out_axis is not None else a
|
||||
@ -1586,7 +1590,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
new_params = update_params(new_params, [True] * len(tracers), len(consts))
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
|
||||
new_params, new_params['call_jaxpr'].effects, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
|
||||
def post_process_map(self, map_primitive, out_tracers, params):
|
||||
@ -1596,6 +1600,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
if fun_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `custom_jvp`.')
|
||||
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
||||
main_ = ref(self.main)
|
||||
jvp_jaxpr_thunk = _memoize(
|
||||
@ -1610,7 +1616,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
num_consts=len(consts)),
|
||||
fun_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
self.frame.eqns.append(eqn)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
|
||||
def post_process_custom_jvp_call(self, out_tracers, _):
|
||||
@ -1620,6 +1626,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
if fun_jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `custom_vjp`.')
|
||||
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
||||
main_ = ref(self.main)
|
||||
fwd_jaxpr_thunk = _memoize(
|
||||
@ -1635,7 +1643,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
bwd=bwd, out_trees=out_trees),
|
||||
fun_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
self.frame.eqns.append(eqn)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
|
||||
def post_process_custom_vjp_call(self, out_tracers, _):
|
||||
|
@ -11,14 +11,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import ad_checkpoint
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.config import config
|
||||
from jax._src import test_util as jtu
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -70,6 +77,114 @@ class JaxprEffectsTest(jtu.JaxTestCase):
|
||||
'Equation effects are not subset of Jaxpr effects.'):
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
|
||||
def test_core_call_primitive_inherits_effects(self):
|
||||
|
||||
def f(x):
|
||||
@lu.wrap_init
|
||||
def f_(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return [x]
|
||||
return core.call(f_, x)[0]
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(2.)
|
||||
|
||||
def test_xla_call_primitive_inherits_effects(self):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(2.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{flavor}", flavor=flavor)
|
||||
for flavor in ["old", "new"]))
|
||||
def test_remat_call_primitive_inherits_effects(self, flavor):
|
||||
remat = jax.remat if flavor == "old" else ad_checkpoint.checkpoint
|
||||
|
||||
@remat
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(2.)
|
||||
|
||||
def test_custom_jvp_primitive_inherits_effects(self):
|
||||
|
||||
@jax.custom_jvp
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x
|
||||
f.defjvp(lambda x, t: (x, t))
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(2.)
|
||||
|
||||
def test_custom_vjp_primitive_inherits_effects(self):
|
||||
|
||||
@jax.custom_vjp
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x
|
||||
f.defvjp(
|
||||
fwd=lambda x: (x, ()),
|
||||
bwd=lambda _, g: g)
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(2.)
|
||||
|
||||
def test_pmap_inherits_effects(self):
|
||||
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
def test_xmap_inherits_effects(self):
|
||||
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x
|
||||
f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
def test_pjit_inherits_effects(self):
|
||||
if jax.default_backend() not in {'gpu', 'tpu'}:
|
||||
raise unittest.SkipTest("pjit only supports GPU and TPU backends")
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x
|
||||
f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'),
|
||||
out_axis_resources=pjit.PartitionSpec('x'))
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
with maps.Mesh(np.array(jax.devices()), ['x']):
|
||||
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
|
||||
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
def test_cannot_lower_jaxpr_with_effects_in_hop(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
return x + 1.
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Lowering jaxprs with '
|
||||
'effects not supported'):
|
||||
f(2.)
|
||||
|
||||
|
||||
class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_effects_disallowed_in_cond(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user