Adds simple effect types to jaxprs

This commit is contained in:
Sharad Vikram 2022-02-28 13:36:39 -08:00
parent 902fc0c3d2
commit 0fa1eddd25
17 changed files with 383 additions and 158 deletions

View File

@ -328,7 +328,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
for x in jaxpr_unknown.outvars]
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
new_params, source_info_util.current())
new_params, jaxpr_unknown.effects, source_info_util.current())
for t in out_jaxpr_tracers: t.recipe = recipe
# zip together known and unknown outputs

View File

@ -180,7 +180,7 @@ class CustomTransposePrimitive(core.Primitive):
# TODO(frostig,mattjj): reinstate checks
def custom_transpose_typecheck(*avals, **params):
  """Placeholder typecheck rule: validates nothing and reports no effects.

  Args:
    *avals: abstract values of the equation inputs (currently ignored).
    **params: equation parameters (currently ignored).

  Returns:
    A pair ``(None, core.no_effects)``.
    NOTE(review): ``None`` presumably tells the jaxpr checker to take output
    avals from the equation's own outvars -- confirm against the caller.
  """
  del avals, params  # nothing is inspected until the checks are reinstated
  return None, core.no_effects
def custom_transpose_transpose_rule(

View File

@ -303,7 +303,8 @@ def _prune_unused_inputs(
(i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
kept_var_idx, new_invars = util.unzip2(
(i, v) for i, v in enumerate(jaxpr.invars) if v in used)
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns,
jaxpr.effects)
return new_jaxpr, set(kept_const_idx), set(kept_var_idx)

View File

@ -109,8 +109,7 @@ def _initial_style_jaxprs_with_common_consts(
prefix = util.concatenate(unused_const_vars[:i])
suffix = util.concatenate(unused_const_vars[i + 1:])
constvars = [*prefix, *jaxpr.constvars, *suffix]
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
return jaxpr.replace(constvars=constvars)
consts = util.concatenate(all_consts)
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
@ -315,6 +314,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
new_init_val, = tree_unflatten(in_tree, new_init_vals)
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
cond_jaxpr, cond_consts, body_consts, body_tree = rest
joined_effects = core.join_effects(body_jaxpr.effects, cond_jaxpr.effects)
if joined_effects:
raise NotImplementedError('Effects not supported in `while`.')
in_tree_children = in_tree.children()
assert len(in_tree_children) == 1
@ -326,8 +328,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
return tree_unflatten(body_tree, outs)
def _while_loop_abstract_eval(*args, **kwargs):
return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)
def _while_loop_abstract_eval(*args, body_jaxpr, **kwargs):
  """Abstract-eval rule for ``while``.

  Args:
    *args: input avals (unused).
    body_jaxpr: the loop body; its ``out_avals`` determine the result avals.
    **kwargs: remaining primitive parameters (unused).

  Returns:
    A pair of (the body's output avals, each passed through
    ``raise_to_shaped``) and ``core.no_effects``.
  """
  # Only the body jaxpr matters here; drop everything else explicitly.
  del args, kwargs
  result_avals = _map(raise_to_shaped, body_jaxpr.out_avals)
  return result_avals, core.no_effects
def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
body_jaxpr, cond_nconsts, body_nconsts):
@ -516,7 +519,8 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
invars_aug,
cond_jaxpr.jaxpr.outvars,
cond_jaxpr.jaxpr.eqns)
cond_jaxpr.jaxpr.eqns,
cond_jaxpr.jaxpr.effects)
cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)
out = while_p.bind(
@ -611,7 +615,7 @@ def _while_transpose_error(*_, **kwargs):
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_translation(while_p, _while_loop_translation_rule,
@ -875,6 +879,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
if joined_effects:
raise NotImplementedError('Effects not supported in `cond`.')
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees
@ -926,8 +933,8 @@ def _cond_with_per_branch_args(pred,
lambda op: false_fun(op[1]),
(true_operand, false_operand))
def _cond_abstract_eval(*args, **kwargs):
return _map(raise_to_shaped, kwargs["branches"][0].out_avals)
def _cond_abstract_eval(*args, branches, **kwargs):
  """Abstract-eval rule for ``cond``.

  Args:
    *args: input avals (unused).
    branches: the branch jaxprs; only the first one is consulted.
    **kwargs: remaining primitive parameters (unused).

  Returns:
    A pair of (the first branch's output avals, each passed through
    ``raise_to_shaped``) and ``core.no_effects``.
  """
  first_branch = branches[0]
  result_avals = _map(raise_to_shaped, first_branch.out_avals)
  return result_avals, core.no_effects
def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
linear):
@ -1121,11 +1128,13 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
linear_2 = (False,) * num_res + linear
params = dict(branches=branches_2, linear=linear_2)
if any((branch.effects for branch in branches_2)):
raise NotImplementedError('Effects not supported in `cond`.')
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
source)
core.no_effects, source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -1200,7 +1209,8 @@ def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
aug_invars = aug_res_vars + non_res_vars
jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns)
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects)
jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return jaxpr_aug
@ -1304,6 +1314,11 @@ def _cond_typecheck(*avals, branches, linear):
raise core.JaxprTypeError(
f'cond branches take input types {jaxpr0_in_avals_str}, '
f'called with operands of type {_avals_short(op_avals)}')
if any((b.effects != branches[0].effects for b in branches[1:])):
raise core.JaxprTypeError(
f'cond branches must have matching effect types: '
f'{[b.effects for b in branches]}')
return None, core.no_effects
def cond_bind(*args, branches, linear):
if config.jax_enable_checks:
@ -1316,7 +1331,7 @@ def cond_bind(*args, branches, linear):
cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
@ -1509,6 +1524,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
new_init = tree_unflatten(init_tree, new_init_flat)
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
in_flat, jaxpr, consts, out_tree, out_tree_children = rest
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `scan`.')
_check_tree_and_avals("scan carry output and input",
# Extract the subtree and avals for the first element of the return tuple
@ -1713,7 +1730,7 @@ def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
return carry_avals + ys_avals
return carry_avals + ys_avals, core.no_effects
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
linear, unroll):
@ -1884,6 +1901,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
num_consts=num_consts_2,
num_carry=num_carry, linear=tuple(linear_2),
unroll=unroll),
jaxpr_2_opt.effects,
source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -2088,6 +2106,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
raise core.JaxprTypeError(
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
f'called with sequence of type\n{_avals_short(x_avals)}')
return None, core.no_effects
def scan_bind(*args, **params):
if config.jax_enable_checks:
@ -2100,7 +2119,7 @@ scan_p = core.AxisPrimitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_abstract_eval(_scan_abstract_eval)
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval

View File

@ -2723,7 +2723,7 @@ def _broadcast_in_dim_staging_rule(
out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
invars = [trace.getvar(x), *(trace.getvar(d) for d in dyn_shape)]
eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
broadcast_in_dim_p, params, source_info)
broadcast_in_dim_p, params, core.no_effects, source_info)
trace.frame.eqns.append(eqn)
return out_tracer

View File

@ -1642,7 +1642,7 @@ def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
if not len(set(axis_name)) == len(axis_name): raise ValueError
pos_aval = lax.dot_general_p.abstract_eval(
x, y, dimension_numbers=[pos_contract, pos_batch],
precision=precision, preferred_element_type=None)
precision=precision, preferred_element_type=None)[0]
common_named_shape = core.join_named_shapes(x.named_shape, y.named_shape)
named_shape = {name: size
for name, size in common_named_shape.items()

View File

@ -54,14 +54,21 @@ map, unsafe_map = safe_map, map
# -------------------- jaxprs --------------------
Effect = Any
Effects = Set[Effect]
no_effects: Effects = set()
class Jaxpr:
constvars: List[Var]
invars: List[Var]
outvars: List[Atom]
eqns: List[JaxprEqn]
effects: Effects
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn]):
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -70,11 +77,14 @@ class Jaxpr:
the inputs to the Jaxpr.
outvars: list of output variables.
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
"""
self.constvars = list(constvars)
self.invars = list(invars)
self.outvars = list(outvars)
self.eqns = list(eqns)
self.effects = effects
def __str__(self):
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))
@ -92,6 +102,18 @@ class Jaxpr:
def _repr_pretty_(self, p, cycle):
return p.text(self.pretty_print(use_color=True))
def replace(self, *, constvars=None, invars=None, outvars=None, eqns=None,
            effects=None):
  """Build a new ``Jaxpr`` that differs from this one in the given fields.

  Every argument is keyword-only. An argument left as ``None`` means "keep
  the value currently stored on ``self``"; any other value (including an
  empty container) replaces it.

  Returns:
    A freshly constructed ``Jaxpr``; ``self`` is not modified.
  """
  requested = {
      'constvars': constvars,
      'invars': invars,
      'outvars': outvars,
      'eqns': eqns,
      'effects': effects,
  }
  # Fall back to the current attribute wherever no override was supplied.
  fields = {name: getattr(self, name) if value is None else value
            for name, value in requested.items()}
  return Jaxpr(**fields)
def join_effects(*effects: Effects) -> Effects:
  """Return the union of any number of effect sets.

  Args:
    *effects: zero or more effect sets to combine.

  Returns:
    A newly allocated set holding every effect that appears in any argument.
    The result is empty when no arguments are given.
  """
  # Always build the result from a fresh empty set. Handing back the
  # module-level `no_effects` object for the empty case would alias a shared
  # mutable set, so a caller mutating the result would corrupt it globally.
  return set().union(*effects)
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
for val in params.values():
@ -137,9 +159,18 @@ class ClosedJaxpr:
def eqns(self):
return self.jaxpr.eqns
@property
def effects(self) -> Effects:
  """The effects recorded on the wrapped ``jaxpr`` (returned as-is, not copied)."""
  wrapped = self.jaxpr
  return wrapped.effects
def map_jaxpr(self, f):
  """Return a new ``ClosedJaxpr`` holding ``f(self.jaxpr)`` and the same consts.

  Args:
    f: callable applied once to the wrapped jaxpr.
  """
  transformed = f(self.jaxpr)
  return ClosedJaxpr(transformed, self.consts)
def replace(self, *, jaxpr=None, consts=None):
  """Build a new ``ClosedJaxpr`` with ``jaxpr`` and/or ``consts`` swapped out.

  Both arguments are keyword-only; one left as ``None`` keeps the value
  currently stored on ``self``.
  """
  new_jaxpr = jaxpr if jaxpr is not None else self.jaxpr
  new_consts = consts if consts is not None else self.consts
  return ClosedJaxpr(new_jaxpr, new_consts)
def __str__(self): return str(self.jaxpr)
def __repr__(self): return repr(self.jaxpr)
@ -164,16 +195,20 @@ class JaxprEqn(NamedTuple):
outvars: List[Var]
primitive: Primitive
params: Dict[str, Any]
effects: Effects
source_info: source_info_util.SourceInfo
def __repr__(self):
return str(pp_eqn(self, JaxprPpContext(), JaxprPpSettings())).rstrip()
def new_jaxpr_eqn(invars, outvars, primitive, params, source_info=None):
def replace(self, *args, **kwargs):
  """Return a copy of this equation with the named fields replaced.

  All arguments are forwarded unchanged to the tuple's ``_replace`` method.
  """
  updated = self._replace(*args, **kwargs)
  return updated
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
if primitive.call_primitive:
assert len(outvars) == len(params["call_jaxpr"].outvars)
source_info = source_info or source_info_util.new_source_info()
return JaxprEqn(invars, outvars, primitive, params, source_info)
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
@total_ordering
@ -296,9 +331,13 @@ class Primitive:
return impl
def def_abstract_eval(self, abstract_eval):
  """Register an effect-unaware abstract-eval rule on this primitive.

  The rule is wrapped with ``_effect_free_abstract_eval`` before being stored
  on ``self.abstract_eval``.

  Returns:
    The original, unwrapped ``abstract_eval`` callable.
  """
  wrapped_rule = _effect_free_abstract_eval(abstract_eval)
  self.abstract_eval = wrapped_rule  # type: ignore[assignment]
  return abstract_eval
def def_effectful_abstract_eval(self, effectful_abstract_eval):
  """Register an abstract-eval rule that is stored without any wrapping.

  The callable is assigned directly to ``self.abstract_eval``.

  Returns:
    The same callable that was passed in.
  """
  rule = effectful_abstract_eval
  self.abstract_eval = rule  # type: ignore[assignment]
  return rule
def def_custom_bind(self, bind):
  """Install ``bind`` as this primitive's ``bind`` attribute and return it."""
  custom_bind = bind
  self.bind = custom_bind
  return custom_bind
@ -315,6 +354,11 @@ class Primitive:
return [], params
def _effect_free_abstract_eval(abstract_eval):
  """Adapt a rule that returns only avals into one that also returns effects.

  Args:
    abstract_eval: callable returning output aval(s).

  Returns:
    A callable with the same arguments that returns the pair
    ``(abstract_eval(*args, **kwargs), no_effects)``.
  """
  def abstract_eval_(*args, **kwargs):
    out_avals = abstract_eval(*args, **kwargs)
    return out_avals, no_effects
  return abstract_eval_
# -------------------- lifting --------------------
# TODO(mattjj): replace this approach with a primitive-keyed table of rules
@ -2018,7 +2062,7 @@ def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: Dict[Var, Var
e.eqn = eqn
raise
params = subst_axis_names(eqn.primitive, eqn.params, subst)
return new_jaxpr_eqn(invars, outvars, eqn.primitive, params, eqn.source_info)
return eqn.replace(invars=invars, outvars=outvars, params=params)
def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
consts = None
@ -2030,7 +2074,7 @@ def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns)
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
if consts is not None:
return ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr
@ -2172,15 +2216,20 @@ def _check_jaxpr(
if any(isinstance(ina, ConcreteArray) for ina in in_avals):
raise JaxprTypeError("Equation given ConcreteArray type inputs")
if prim in custom_typechecks:
out_avals = custom_typechecks[prim](*in_avals, **eqn.params)
out_avals, effects = custom_typechecks[prim](*in_avals, **eqn.params)
if out_avals is None:
out_avals = [v.aval for v in eqn.outvars]
elif prim.call_primitive:
out_avals = check_call(ctx_factory, prim, in_avals, eqn.params)
out_avals, effects = check_call(ctx_factory, prim, in_avals, eqn.params)
elif prim.map_primitive:
out_avals = check_map(ctx_factory, prim, in_avals, eqn.params)
out_avals, effects = check_map(ctx_factory, prim, in_avals, eqn.params)
else:
out_avals = check_eqn(prim, in_avals, eqn.params)
out_avals, effects = check_eqn(prim, in_avals, eqn.params)
if eqn.effects != effects:
print(eqn.effects, effects)
raise JaxprTypeError("Inferred effects do not match equation effects.")
if not eqn.effects.issubset(jaxpr.effects):
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects.")
map(write, eqn.outvars, out_avals)
except JaxprTypeError as e:
ctx, settings = ctx_factory()
@ -2196,10 +2245,10 @@ def check_eqn(prim, in_avals, params):
for jaxpr in jaxprs_in_params(params):
check_jaxpr(jaxpr)
out_avals = prim.abstract_eval(*in_avals, **params)
out_avals, effects = prim.abstract_eval(*in_avals, **params)
if not prim.multiple_results:
out_avals = [out_avals]
return out_avals
return out_avals, effects
def check_call(ctx_factory, prim, in_avals, params):
if "call_jaxpr" not in params:
@ -2221,12 +2270,14 @@ def check_call(ctx_factory, prim, in_avals, params):
_check_jaxpr(ctx_factory, call_jaxpr, in_avals)
out_avals = [v.aval for v in call_jaxpr.outvars]
return out_avals
return out_avals, call_jaxpr.effects
def check_map(ctx_factory, prim, in_avals, params):
if "call_jaxpr" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
if call_jaxpr.effects:
raise JaxprTypeError(f"Map primitive {prim} mapping an effectful function")
if "axis_size" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_size' parameter")
axis_size = params["axis_size"]
@ -2257,7 +2308,7 @@ def check_map(ctx_factory, prim, in_avals, params):
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval
for aval, out_axis in zip(mapped_out_avals, out_axes)]
return out_avals
return out_avals, call_jaxpr.effects
# ------------------- Jaxpr printed representation -------------------

View File

@ -619,8 +619,7 @@ def ignore_errors_jaxpr(jaxpr, error):
new_vars = core.gensym([jaxpr])
new_invars = (new_vars(err_aval), new_vars(code_aval),
new_vars(payload_aval), *jaxpr.invars)
new_jaxpr = core.Jaxpr(jaxpr.constvars, new_invars,
jaxpr.outvars, jaxpr.eqns)
new_jaxpr = jaxpr.replace(invars=new_invars)
return core.ClosedJaxpr(new_jaxpr, consts)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,

View File

@ -1441,10 +1441,10 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
# We need tokens but none is given in input; make one depending on all invars
eqns.append(
core.new_jaxpr_eqn(jaxpr.invars, [last_token_var],
lax.create_token_p, {}, source_info_util.current()))
lax.create_token_p, {}, core.no_effects, source_info_util.current()))
eqns.append(
core.new_jaxpr_eqn(jaxpr.invars, [last_itoken_var],
lax.create_token_p, {}, source_info_util.current()))
lax.create_token_p, {}, core.no_effects, source_info_util.current()))
for eqn in jaxpr.eqns:
if not core.primitive_uses_outfeed(eqn.primitive, eqn.params):
@ -1458,7 +1458,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
last_itoken_var = output_itoken_var
outvars = jaxpr.outvars + ([last_token_var, last_itoken_var] if has_output_token else [])
new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr.effects)
return new_jaxpr
@ -1476,11 +1476,9 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
"""
if eqn.primitive is outside_call_p:
assert "has_token" not in eqn.params
eqns.append(
core.new_jaxpr_eqn(eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(eqn.params, has_token=True),
eqn.source_info))
eqns.append(eqn.replace(invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(eqn.params, has_token=True)))
elif eqn.primitive is lax.while_p:
cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
eqn.params,
@ -1492,28 +1490,26 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
return
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True),
cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True,
False)), eqn.source_info))
cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True, False))))
elif eqn.primitive is lax.cond_p:
branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
index, *operands = eqn.invars
new_invars = [index, *operands, input_token_var, input_itoken_var]
eqns.append(
core.new_jaxpr_eqn(
new_invars, eqn.outvars + [output_token_var, output_itoken_var],
eqn.primitive,
dict(
eqn.replace(
invars=new_invars, outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
branches=tuple(
_rewrite_closed_jaxpr(jaxpr, True, True)
for jaxpr in branches),
linear=(*linear, False, False)), eqn.source_info))
linear=(*linear, False, False))))
elif eqn.primitive is lax.scan_p:
num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
eqn.params,
@ -1537,56 +1533,51 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
new_jaxpr_outvars[num_carry:-2])
new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
eqns.append(
core.new_jaxpr_eqn(
new_invars,
eqn.replace(
invars=new_invars,
# Output token is at the end of carry result
eqn.outvars[0:num_carry] + [output_token_var, output_itoken_var] +
eqn.outvars[num_carry:],
eqn.primitive,
dict(
outvars=(eqn.outvars[0:num_carry] + [output_token_var, output_itoken_var] +
eqn.outvars[num_carry:]),
params=dict(
eqn.params,
jaxpr=new_jaxpr,
num_carry=num_carry + 2,
linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:]),
eqn.source_info))
linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:])))
elif eqn.primitive is xla.xla_call_p:
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False)),
eqn.source_info))
donated_invars=eqn.params["donated_invars"] + (False, False))))
elif eqn.primitive is pxla.xla_pmap_p:
# We broadcast the input token into an array of tokens
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var],
eqn.primitive,
dict(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False),
# Sharding/unsharding of tokens in pmap_translation are special
# cased to just pass-through the token
in_axes=eqn.params["in_axes"] + (None, None),
out_axes=eqn.params["out_axes"] + (0, 0)),
eqn.source_info))
out_axes=eqn.params["out_axes"] + (0, 0))))
elif eqn.primitive is pe.remat_call_p:
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
), eqn.source_info))
)))
elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
fun_jaxpr = eqn.params["fun_jaxpr"]
@ -1594,15 +1585,14 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
assert False, "Should not be reached"
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
jvp_jaxpr_thunk=unreachable_thunk
),
eqn.source_info))
)))
elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
fun_jaxpr = eqn.params["fun_jaxpr"]
new_invars = [*eqn.invars, input_token_var, input_itoken_var]
@ -1611,11 +1601,10 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
assert False, "Should not be reached"
eqns.append(
core.new_jaxpr_eqn(
new_invars,
eqn.outvars + [output_token_var, output_itoken_var],
eqn.primitive,
dict(
eqn.replace(
invars=new_invars,
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
fwd_jaxpr_thunk=unreachable_thunk,
@ -1623,25 +1612,24 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
# should not be needed because this rewrite is just before
# compilation to XLA, which does not use those parameters.
bwd="illegal param",
out_trees="illegal param"),
eqn.source_info))
out_trees="illegal param")))
elif eqn.primitive is core.named_call_p:
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
), eqn.source_info))
)))
elif eqn.primitive is pjit.pjit_p:
jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"])
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False),
@ -1649,7 +1637,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
(pjit.REPLICATED, pjit.REPLICATED)),
out_axis_resources=(eqn.params["out_axis_resources"] +
(pjit.REPLICATED, pjit.REPLICATED)),
), eqn.source_info))
)))
else:
raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
@ -1678,6 +1666,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
name="cond_before",
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
inline=False),
transformed_cond_jaxpr.jaxpr.effects,
eqn.source_info))
# Make a new cond "lambda pred, carry, token, itoken: pred"
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
@ -1686,7 +1675,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
[mk_new_var(input_token_var.aval),
mk_new_var(input_itoken_var.aval)])
new_cond_jaxpr = core.ClosedJaxpr(
core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], [], set()), [])
# Make a new body:
# "lambda cond_constvars, body_constvars, pred, carry, token, itoken:
# carry2, token2, itoken2 = rewrite(BODY)(body_constvars, carry, token, itoken)
@ -1723,6 +1712,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
name="body",
donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
inline=False),
transformed_body_jaxpr.effects,
eqn.source_info),
core.new_jaxpr_eqn(
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
@ -1732,14 +1722,16 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
name="cond_body",
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
inline=False),
transformed_cond_jaxpr.effects,
eqn.source_info)
]
effects = core.join_effects(*(eqn.effects for eqn in new_body_eqns))
new_body_jaxpr = core.ClosedJaxpr(
core.Jaxpr([], (new_body_invars_cond_constvars +
new_body_invars_body_constvars + [new_body_invars_pred] +
new_body_invars_carry + [new_body_invars_token, new_body_invars_itoken]),
([new_body_pred2] + new_body_carry2 + [new_body_token3, new_body_itoken3]),
new_body_eqns), [])
new_body_eqns, effects), [])
pred_out = mk_new_var(cond_jaxpr.out_avals[0])
eqns.append(
@ -1752,7 +1744,9 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
cond_jaxpr=new_cond_jaxpr,
cond_nconsts=0,
body_jaxpr=new_body_jaxpr,
body_nconsts=cond_nconsts + body_nconsts), eqn.source_info))
body_nconsts=cond_nconsts + body_nconsts),
new_body_jaxpr.effects,
eqn.source_info))
# We need an identity primitive to simplify rewriting

View File

@ -809,7 +809,7 @@ class TensorFlowTrace(core.Trace):
# abstract evaluation rules can properly track polymorphic shapes.
# Unfortunately under op-by-op execution this is a rare occasion where we
# need abstract evaluation.
out_aval = primitive.abstract_eval(*args_avals, **params)
out_aval, _ = primitive.abstract_eval(*args_avals, **params)
args_tf: Sequence[TfVal] = [t.val for t in tracers]
def invoke_impl() -> TfVal:
if impl_needs_avals:

View File

@ -987,7 +987,7 @@ def _typecheck_xmap(
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
for a, a_out_axes in zip(mapped_out_avals, out_axes)]
return out_avals
return out_avals, call_jaxpr.effects
core.custom_typechecks[xmap_p] = _typecheck_xmap
@ -1087,7 +1087,7 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
del new_params['out_axes_thunk']
del new_params['spmd_out_axes_thunk']
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
new_params, source_info)
new_params, call_jaxpr.effects, source_info)
self.frame.eqns.append(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore
@ -1273,7 +1273,7 @@ def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
eqn = new_eqn_recipe((*const_tracers, *unknown_tracers_in),
unknown_tracers_out,
primitive, new_params, source_info_util.current())
primitive, new_params, jaxpr.effects, source_info_util.current())
for t in unknown_tracers_out: t.recipe = eqn
return pe._zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
pe.JaxprTrace.process_xmap = _jaxpr_trace_process_xmap
@ -2079,13 +2079,14 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
for eqn in jaxpr.eqns:
new_jaxpr_params = core.traverse_jaxpr_params(rec, eqn.params)
tmp_outvars = [gen_fresh_name(v.aval) for v in eqn.outvars]
new_eqns.append(core.JaxprEqn(eqn.invars, tmp_outvars, eqn.primitive,
dict(eqn.params, **new_jaxpr_params), eqn.source_info))
new_eqns.append(eqn.replace(
outvars=tmp_outvars, params=dict(eqn.params, **new_jaxpr_params)))
for outvar, tmpvar in zip(eqn.outvars, tmp_outvars):
new_eqns.append(core.JaxprEqn([tmpvar], [outvar], sharding_constraint_p,
dict(resource_env=resource_env, axis_resources=ParsedPartitionSpec((), ())),
set(),
eqn.source_info))
return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
return jaxpr.replace(eqns=new_eqns)
def _flatten_axes(what, tree, axes, tupled_args):
try:

View File

@ -864,6 +864,7 @@ def _pjit_partial_eval(trace, *in_tracers,
unknown_tracers_out,
pjit_p,
unknown_params,
unknown_jaxpr.effects,
source_info_util.current())
for t in unknown_tracers_out: t.recipe = eqn
return pe._zip_knowns(known_tracers_out, unknown_tracers_out, unknown_outs)

View File

@ -707,7 +707,8 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
new_invars, new_outvars, jaxpr.jaxpr.eqns)
new_invars, new_outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects)
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
def _perm(primal_counts, tangent_counts, lst):

View File

@ -492,7 +492,8 @@ class MaskTrace(Trace):
if masking_rule is None:
raise NotImplementedError(
f'Masking rule for {primitive} not implemented yet.')
out_aval = primitive.abstract_eval(*(t.aval for t in tracers), **params)
# Ignore effects for now.
out_aval, _ = primitive.abstract_eval(*(t.aval for t in tracers), **params)
vals, polymorphic_shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
logical_shapes = map(shape_as_value, polymorphic_shapes)
# TODO(mattjj): generalize mask rule signature

View File

@ -167,19 +167,19 @@ class JaxprTrace(Trace):
return primitive.bind(*consts, **params)
tracers = map(self.instantiate_const, tracers)
avals = [t.aval for t in tracers]
out_aval = primitive.abstract_eval(*avals, **params)
out_aval, effects = primitive.abstract_eval(*avals, **params)
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
if primitive.multiple_results:
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
for aval in out_aval]
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, source)
for t in out_tracers: t.recipe = eqn
return out_tracers
else:
out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
params, source)
params, effects, source)
return out_tracer
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
@ -222,7 +222,7 @@ class JaxprTrace(Trace):
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params, source)
out_tracers, primitive, staged_params, jaxpr.effects, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -288,6 +288,7 @@ class JaxprTrace(Trace):
for a in out_avals]
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params,
jaxpr.effects,
source_info_util.current())
for t in out_tracers: t.recipe = eqn
@ -313,7 +314,8 @@ class JaxprTrace(Trace):
new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, source)
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
jaxpr.effects, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -352,7 +354,7 @@ class JaxprTrace(Trace):
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
primitive, staged_params, source)
primitive, staged_params, jaxpr.effects, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -412,6 +414,7 @@ class JaxprTrace(Trace):
dict(fun_jaxpr=closed_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts) + len(env)),
jaxpr.effects,
source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -436,7 +439,7 @@ class JaxprTrace(Trace):
in_tracers = map(self.instantiate_const, tracers)
new_params = dict(params, call=call)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
source_info_util.current())
core.no_effects, source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -473,6 +476,7 @@ class JaxprTrace(Trace):
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(consts) + len(env),
bwd=bwd, out_trees=out_trees),
jaxpr.effects,
source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -636,12 +640,14 @@ class JaxprEqnRecipe(NamedTuple):
outvars: 'Sequence[ref[JaxprTracer]]'
primitive: Primitive
params: Dict[str, Any]
effects: core.Effects
source_info: source_info_util.SourceInfo
def new_eqn_recipe(invars: Sequence[JaxprTracer],
outvars: Sequence[JaxprTracer],
primitive: Primitive,
params: Dict[str, Any],
effects: core.Effects,
source_info: source_info_util.SourceInfo
) -> JaxprEqnRecipe:
"""Constructs a new JaxEqnRecipe.
@ -665,17 +671,17 @@ def new_eqn_recipe(invars: Sequence[JaxprTracer],
assert ("donated_invars" in params and
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
params, source_info)
params, effects, source_info)
def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  """Turns a `JaxprEqnRecipe` into a `core.JaxprEqn`.

  Args:
    getvar: maps a tracer to the atom that stands for it in the jaxpr.
    recipe: the recipe to convert. It is unpacked as seven fields; the first
      one is ignored here.

  Returns:
    The equation built by `new_jaxpr_eqn`, carrying the recipe's effects.
  """
  # The recipe carries `effects` between `params` and `source_info`, so the
  # unpacking must name all seven fields.
  _, in_tracers, out_tracer_refs, primitive, params, effects, source_info = recipe
  # Output tracers are stored as refs; calling a ref may yield None
  # (presumably a dead weak reference -- TODO confirm against `ref`).
  out_tracers = [t_ref() for t_ref in out_tracer_refs]
  invars = [getvar(t) for t in in_tracers]
  # Outputs whose ref yields None are bound to a DropVar.
  outvars = [core.DropVar(core.abstract_unit) if t is None
             else cast(Var, getvar(t)) for t in out_tracers]
  return new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
@ -738,7 +744,9 @@ def tracers_to_jaxpr(
env_vars, env_vals = unzip2(env.items())
const_vars, const_vals = unzip2(consts.items())
# The env_vars are pre-pended to the invars
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
effects = core.join_effects(*(eqn.effects for eqn in eqns))
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers),
eqns, effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
return jaxpr, const_vals, env_vals
@ -748,7 +756,8 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
config.jax_enable_checks and core.check_jaxpr(jaxpr)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@ -756,7 +765,8 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
config.jax_enable_checks and core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr
@ -913,7 +923,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
for x in jaxpr_unknown.outvars]
new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True)
recipe = new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_call_p,
new_params, source_info_util.current())
new_params, jaxpr_unknown.effects, source_info_util.current())
for t in out_jaxpr_tracers: t.recipe = recipe
return _zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
else:
@ -960,7 +970,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
const_tracers = map(trace.new_instantiated_const, consts)
in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p,
new_params, source_info_util.current())
new_params, new_jaxpr.effects, source_info_util.current())
for t in unknown_output_tracers: t.recipe = eqn
return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
@ -1009,8 +1019,7 @@ def _partial_eval_jaxpr_custom(
map(write, unks_out, inst_out, eqn.outvars)
elif any(unks_in):
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(new_jaxpr_eqn(inputs, eqn.outvars, eqn.primitive,
eqn.params, eqn.source_info))
staged_eqns.append(eqn.replace(invars=inputs))
map(partial(write, True, True), eqn.outvars)
else:
known_eqns.append(eqn)
@ -1018,8 +1027,7 @@ def _partial_eval_jaxpr_custom(
map(partial(write, False, False), eqn.outvars)
else:
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(new_jaxpr_eqn(inputs, eqn.outvars, eqn.primitive,
eqn.params, eqn.source_info))
staged_eqns.append(eqn.replace(invars=inputs))
map(partial(write, False, True), eqn.outvars)
out_unknowns, out_inst = unzip2(map(read, jaxpr.outvars))
assert all(type(v) is Var for v in residuals), residuals
@ -1027,11 +1035,15 @@ def _partial_eval_jaxpr_custom(
ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
outs_known_, _ = partition_list(out_unknowns, jaxpr.outvars)
outs_known = [x for x in outs_known_ if x.aval is not abstract_unit]
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns)
known_effects = core.join_effects(*(eqn.effects for eqn in known_eqns))
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns,
known_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged, staged_eqns)
staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns))
jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged,
staged_eqns, staged_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
@ -1087,9 +1099,10 @@ def call_partial_eval_custom_rule(
params_known, params_staged = params_updater(
unks_in, kept_outs_known, kept_outs_staged, num_res, params_known, params_staged)
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, eqn.source_info)
eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
eqn_staged = new_jaxpr_eqn([*residuals, *eqn.invars], out_binders_staged,
eqn.primitive, params_staged, eqn.source_info)
eqn.primitive, params_staged,
jaxpr_staged.effects, eqn.source_info)
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is Var and not inst]
@ -1142,7 +1155,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: List[bool]
new_jaxpr = Jaxpr((),
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
[v for v, b in zip(jaxpr.outvars, used_outputs) if b],
new_eqns[::-1])
new_eqns[::-1], jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
return new_jaxpr, used_inputs
@ -1160,7 +1173,7 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
new_params = update_params(new_params, used_inputs, 0)
new_eqn = new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, eqn.source_info)
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
@ -1189,14 +1202,15 @@ def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False)
new_eqns.append(eqn)
needed_vars.update(v for v in eqn.invars if type(v) is not Literal)
new_eqns = new_eqns[::-1]
return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns)
return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns,
jaxpr.effects)
@weakref_lru_cache
def _drop_vars(jaxpr: Jaxpr, drop_ins: Tuple[bool, ...], drop_outs: Tuple[bool, ...]):
  """Returns a copy of `jaxpr` without the flagged input and output binders.

  Args:
    jaxpr: the jaxpr to prune.
    drop_ins: one flag per entry of `jaxpr.invars`; True means drop it.
    drop_outs: one flag per entry of `jaxpr.outvars`; True means drop it.

  Returns:
    A new Jaxpr with the same constvars, equations and effects as `jaxpr`,
    keeping only the invars/outvars whose flag is False.
  """
  kept_invars = [v for v, d in zip(jaxpr.invars, drop_ins) if not d]
  kept_outvars = [v for v, d in zip(jaxpr.outvars, drop_outs) if not d]
  # Equations are left untouched, so the effect set is forwarded unchanged.
  return Jaxpr(jaxpr.constvars, kept_invars, kept_outvars,
               jaxpr.eqns, jaxpr.effects)
def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
@ -1215,7 +1229,8 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) ->
assert len(closed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars,
closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns)
closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns,
closed_jaxpr.jaxpr.effects)
new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
return new_closed_jaxpr
@ -1286,6 +1301,7 @@ class JaxprStackFrame:
tracers: List[DynamicJaxprTracer] # hold onto strong refs for all tracers
eqns: List[JaxprEqn]
invars: List[Var]
effects: core.Effects
def __init__(self):
self.gensym = core.gensym()
@ -1295,13 +1311,18 @@ class JaxprStackFrame:
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
self.eqns = [] # cleared when we pop frame from main
self.invars = []
self.effects = set()
def add_eqn(self, eqn: core.JaxprEqn):
  """Records `eqn` in this frame and merges its effects into the frame's."""
  self.eqns.append(eqn)
  # Keep the frame's effect set in sync with the equations it holds.
  self.effects |= eqn.effects
def to_jaxpr(self, out_tracers):
  """Builds a jaxpr from this frame's recorded state.

  Args:
    out_tracers: tracers whose variables become the jaxpr's outvars.

  Returns:
    A `(jaxpr, constvals)` pair, after constant folding/forwarding and
    literal inlining have been applied.
  """
  # It's not necessary, but we keep the tracer-to-var mapping injective:
  assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
  outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
  constvars, constvals = unzip2(self.constvar_to_val.items())
  # The frame's effect set is passed through as the jaxpr's effects.
  jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, self.effects)
  jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
  jaxpr, constvals = _inline_literals(jaxpr, constvals)
  return jaxpr, constvals
@ -1335,8 +1356,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
new_eqns = []
for eqn in jaxpr.eqns:
# always apply invar substitutions
eqn = JaxprEqn([var_subs.get(v, v) for v in eqn.invars], eqn.outvars,
eqn.primitive, eqn.params, eqn.source_info)
eqn = eqn.replace(invars=[var_subs.get(v, v) for v in eqn.invars])
# if any inputs are constants and we have a constant-folding rule, apply it
if eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars):
consts_in = [consts.get(v) for v in eqn.invars]
@ -1357,7 +1377,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
new_eqns.append(eqn)
new_constvars, new_constvals = unzip2(consts.items())
new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns)
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, jaxpr.effects)
return new_jaxpr, new_constvals
ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
@ -1402,10 +1422,10 @@ def _inline_literals(jaxpr, constvals):
invars = [lit(v) or var(v) for v in eqn.invars]
outvars = [var(v) if v in used else core.DropVar(v.aval)
for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
jaxpr.effects)
return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
@ -1485,14 +1505,14 @@ class DynamicJaxprTrace(core.Trace):
def default_process_primitive(self, primitive, tracers, params):
  """Stages one application of `primitive` as an equation in the current frame.

  Args:
    primitive: the primitive being bound.
    tracers: input tracers; their avals are fed to the abstract-eval rule.
    params: the primitive's parameters.

  Returns:
    A list of output tracers when `primitive.multiple_results` is set,
    otherwise the single output tracer.
  """
  avals = [t.aval for t in tracers]
  # abstract_eval returns the output aval(s) together with the effects.
  out_avals, effects = primitive.abstract_eval(*avals, **params)
  out_avals = [out_avals] if not primitive.multiple_results else out_avals
  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
  invars = map(self.getvar, tracers)
  outvars = map(self.makevar, out_tracers)
  eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
  # Register exactly once, through add_eqn, so the frame's effect set is
  # updated along with its equation list.
  self.frame.add_eqn(eqn)
  return out_tracers if primitive.multiple_results else out_tracers.pop()
def process_call(self, call_primitive, f, tracers, params):
@ -1520,7 +1540,8 @@ class DynamicJaxprTrace(core.Trace):
new_params = update_params(new_params, [True] * len(tracers),
len(consts) + len(dim_tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
call_primitive, new_params, source_info)
call_primitive, new_params,
new_params['call_jaxpr'].effects, source_info)
self.frame.eqns.append(eqn)
return out_tracers
@ -1554,7 +1575,7 @@ class DynamicJaxprTrace(core.Trace):
if update_params:
new_params = update_params(new_params, [True] * len(tracers), len(consts))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
new_params, source_info)
new_params, new_params['call_jaxpr'].effects, source_info)
self.frame.eqns.append(eqn)
return out_tracers
@ -1577,6 +1598,7 @@ class DynamicJaxprTrace(core.Trace):
dict(fun_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts)),
fun_jaxpr.effects,
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
@ -1601,6 +1623,7 @@ class DynamicJaxprTrace(core.Trace):
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(consts),
bwd=bwd, out_trees=out_trees),
fun_jaxpr.effects,
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
@ -1640,8 +1663,9 @@ class DynamicJaxprTrace(core.Trace):
transpose_jaxpr_thunk=transpose_jaxpr_thunk,
out_types=out_types, res_tree=res_tree,
lin_tree=lin_tree, out_tree=out_tree),
closed_call_jaxpr.effects,
source_info_util.current())
self.frame.eqns.append(eqn)
self.frame.add_eqn(eqn)
return out_tracers

133
tests/jaxpr_effects_test.py Normal file
View File

@ -0,0 +1,133 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import core
from jax import lax
from jax.config import config
from jax._src import test_util as jtu
config.parse_flags_with_absl()

# Test-only primitive: it has no outputs, and its abstract-eval rule reports
# the single effect given by the `effect` parameter.
effect_p = core.Primitive('effect')
effect_p.multiple_results = True


@effect_p.def_effectful_abstract_eval
def _(*, effect):
  # Empty list of output avals, plus a one-element effect set.
  return [], {effect}
class JaxprEffectsTest(jtu.JaxTestCase):
  """Checks the effects recorded on equations and on whole jaxprs."""

  def test_trivial_jaxpr_has_no_effects(self):
    # A function that binds no effectful primitive has an empty effect set.
    def add_one(x):
      return x + 1.
    closed_jaxpr = jax.make_jaxpr(add_one)(2.)
    self.assertEqual(core.no_effects, closed_jaxpr.effects)

  def test_effectful_primitive_in_jaxpr_creates_effects(self):
    # The effect shows up on the equation and on the enclosing jaxpr.
    def bind_foo(x):
      effect_p.bind(effect='foo')
      return x + 1.
    closed_jaxpr = jax.make_jaxpr(bind_foo)(2.)
    eqns = closed_jaxpr.jaxpr.eqns
    self.assertEqual({'foo'}, eqns[0].effects)
    self.assertEqual({'foo'}, closed_jaxpr.effects)

  def test_different_effects_in_jaxpr(self):
    # Each equation keeps its own effect; the jaxpr holds both of them.
    def bind_foo_then_bar(x):
      effect_p.bind(effect='foo')
      effect_p.bind(effect='bar')
      return x + 1.
    closed_jaxpr = jax.make_jaxpr(bind_foo_then_bar)(2.)
    eqns = closed_jaxpr.jaxpr.eqns
    self.assertEqual({'foo'}, eqns[0].effects)
    self.assertEqual({'bar'}, eqns[1].effects)
    self.assertEqual({'foo', 'bar'}, closed_jaxpr.effects)

  def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self):
    def bind_foo_then_bar(x):
      effect_p.bind(effect='foo')
      effect_p.bind(effect='bar')
      return x + 1.
    well_typed = jax.make_jaxpr(bind_foo_then_bar)(2.).jaxpr
    # Edit jaxpr to make its type wrong: the 'bar' equation is no longer
    # covered by the declared effects.
    mistyped = well_typed.replace(effects={'foo'})
    with self.assertRaisesRegex(
        core.JaxprTypeError,
        'Equation effects are not subset of Jaxpr effects.'):
      core.check_jaxpr(mistyped)
class ControlFlowEffectsTest(jtu.JaxTestCase):
  """Checks that cond, while_loop and scan reject effectful bodies."""

  def test_effects_disallowed_in_cond(self):
    def effectful_branch(x):
      effect_p.bind(effect='foo')
      return x

    def pure_branch(x):
      return x

    # Effect in the true branch.
    def effect_in_true(x):
      return lax.cond(True, effectful_branch, pure_branch, x)

    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(effect_in_true)(2.)

    # Effect in the false branch.
    def effect_in_false(x):
      return lax.cond(True, pure_branch, effectful_branch, x)

    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(effect_in_false)(2.)

  def test_effects_disallowed_in_while(self):
    # Effect in the loop condition.
    def effect_in_cond(x):
      def effectful_cond(x):
        effect_p.bind(effect='foo')
        return False

      def pure_body(x):
        return x

      return lax.while_loop(effectful_cond, pure_body, x)

    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(effect_in_cond)(2.)

    # Effect in the loop body.
    def effect_in_body(x):
      def pure_cond(x):
        return False

      def effectful_body(x):
        effect_p.bind(effect='foo')
        return x

      return lax.while_loop(pure_cond, effectful_body, x)

    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(effect_in_body)(2.)

  def test_effects_disallowed_in_scan(self):
    def scan_with_effect(x):
      def effectful_step(carry, x):
        effect_p.bind(effect='foo')
        return carry, x

      return lax.scan(effectful_step, x, jnp.arange(4))

    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      jax.make_jaxpr(scan_with_effect)(2.)
# Run the suite with JAX's test loader when executed as a script.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -2888,20 +2888,20 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
def test_abstract_eval(self):
  """Checks the avals that unary and binary abstract eval return for named shapes."""
  # abstract_eval returns an (aval, effects) pair; only the aval is checked.
  aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
  out, _ = lax.sin_p.abstract_eval(aval1)
  self.assertEqual(out, aval1)

  # With two operands, the expected aval carries both inputs' named axes.
  aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
  aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
  expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
  out, _ = lax.add_p.abstract_eval(aval1, aval2)
  self.assertEqual(out, expected)
def test_abstract_eval_collective(self):
  """Checks that psum over axis 'i' yields an aval without that named axis."""
  with core.extend_axis_env('i', 10, None):
    aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
    expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
    # abstract_eval returns (avals, effects); unpack the one-element aval
    # sequence and ignore the effects.
    (out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
    self.assertEqual(out, expected)
if __name__ == '__main__':