mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Adds simple effect types to jaxprs
This commit is contained in:
parent
902fc0c3d2
commit
0fa1eddd25
@ -328,7 +328,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
for x in jaxpr_unknown.outvars]
|
||||
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
|
||||
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
|
||||
new_params, source_info_util.current())
|
||||
new_params, jaxpr_unknown.effects, source_info_util.current())
|
||||
for t in out_jaxpr_tracers: t.recipe = recipe
|
||||
|
||||
# zip together known and unknown outputs
|
||||
|
@ -180,7 +180,7 @@ class CustomTransposePrimitive(core.Primitive):
|
||||
|
||||
# TODO(frostig,mattjj): reinstate checks
|
||||
def custom_transpose_typecheck(*avals, **params):
|
||||
pass
|
||||
return None, core.no_effects
|
||||
|
||||
|
||||
def custom_transpose_transpose_rule(
|
||||
|
@ -303,7 +303,8 @@ def _prune_unused_inputs(
|
||||
(i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
|
||||
kept_var_idx, new_invars = util.unzip2(
|
||||
(i, v) for i, v in enumerate(jaxpr.invars) if v in used)
|
||||
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
|
||||
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns,
|
||||
jaxpr.effects)
|
||||
return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
|
||||
|
||||
|
||||
|
@ -109,8 +109,7 @@ def _initial_style_jaxprs_with_common_consts(
|
||||
prefix = util.concatenate(unused_const_vars[:i])
|
||||
suffix = util.concatenate(unused_const_vars[i + 1:])
|
||||
constvars = [*prefix, *jaxpr.constvars, *suffix]
|
||||
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
return jaxpr.replace(constvars=constvars)
|
||||
|
||||
consts = util.concatenate(all_consts)
|
||||
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
|
||||
@ -315,6 +314,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
new_init_val, = tree_unflatten(in_tree, new_init_vals)
|
||||
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
|
||||
cond_jaxpr, cond_consts, body_consts, body_tree = rest
|
||||
joined_effects = core.join_effects(body_jaxpr.effects, cond_jaxpr.effects)
|
||||
if joined_effects:
|
||||
raise NotImplementedError('Effects not supported in `while`.')
|
||||
|
||||
in_tree_children = in_tree.children()
|
||||
assert len(in_tree_children) == 1
|
||||
@ -326,8 +328,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
|
||||
return tree_unflatten(body_tree, outs)
|
||||
|
||||
def _while_loop_abstract_eval(*args, **kwargs):
|
||||
return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)
|
||||
def _while_loop_abstract_eval(*args, body_jaxpr, **kwargs):
|
||||
del args, kwargs
|
||||
return _map(raise_to_shaped, body_jaxpr.out_avals), core.no_effects
|
||||
|
||||
def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
|
||||
body_jaxpr, cond_nconsts, body_nconsts):
|
||||
@ -516,7 +519,8 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
|
||||
cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
|
||||
invars_aug,
|
||||
cond_jaxpr.jaxpr.outvars,
|
||||
cond_jaxpr.jaxpr.eqns)
|
||||
cond_jaxpr.jaxpr.eqns,
|
||||
cond_jaxpr.jaxpr.effects)
|
||||
cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)
|
||||
|
||||
out = while_p.bind(
|
||||
@ -611,7 +615,7 @@ def _while_transpose_error(*_, **kwargs):
|
||||
while_p = core.AxisPrimitive('while')
|
||||
while_p.multiple_results = True
|
||||
while_p.def_impl(partial(xla.apply_primitive, while_p))
|
||||
while_p.def_abstract_eval(_while_loop_abstract_eval)
|
||||
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
|
||||
ad.primitive_jvps[while_p] = _while_loop_jvp
|
||||
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
|
||||
xla.register_translation(while_p, _while_loop_translation_rule,
|
||||
@ -875,6 +879,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
|
||||
if joined_effects:
|
||||
raise NotImplementedError('Effects not supported in `cond`.')
|
||||
true_jaxpr, false_jaxpr = jaxprs
|
||||
out_tree, false_out_tree = out_trees
|
||||
|
||||
@ -926,8 +933,8 @@ def _cond_with_per_branch_args(pred,
|
||||
lambda op: false_fun(op[1]),
|
||||
(true_operand, false_operand))
|
||||
|
||||
def _cond_abstract_eval(*args, **kwargs):
|
||||
return _map(raise_to_shaped, kwargs["branches"][0].out_avals)
|
||||
def _cond_abstract_eval(*args, branches, **kwargs):
|
||||
return _map(raise_to_shaped, branches[0].out_avals), core.no_effects
|
||||
|
||||
def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
|
||||
linear):
|
||||
@ -1121,11 +1128,13 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
|
||||
|
||||
linear_2 = (False,) * num_res + linear
|
||||
params = dict(branches=branches_2, linear=linear_2)
|
||||
if any((branch.effects for branch in branches_2)):
|
||||
raise NotImplementedError('Effects not supported in `cond`.')
|
||||
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = pe.new_eqn_recipe(
|
||||
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
|
||||
source)
|
||||
core.no_effects, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
@ -1200,7 +1209,8 @@ def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
|
||||
aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
|
||||
aug_invars = aug_res_vars + non_res_vars
|
||||
jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
|
||||
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns)
|
||||
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
|
||||
jaxpr.jaxpr.effects)
|
||||
jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
|
||||
return jaxpr_aug
|
||||
|
||||
@ -1304,6 +1314,11 @@ def _cond_typecheck(*avals, branches, linear):
|
||||
raise core.JaxprTypeError(
|
||||
f'cond branches take input types {jaxpr0_in_avals_str}, '
|
||||
f'called with operands of type {_avals_short(op_avals)}')
|
||||
if any((b.effects != branches[0].effects for b in branches[1:])):
|
||||
raise core.JaxprTypeError(
|
||||
f'cond branches must have matching effect types: '
|
||||
f'{[b.effects for b in branches]}')
|
||||
return None, core.no_effects
|
||||
|
||||
def cond_bind(*args, branches, linear):
|
||||
if config.jax_enable_checks:
|
||||
@ -1316,7 +1331,7 @@ def cond_bind(*args, branches, linear):
|
||||
cond_p = core.AxisPrimitive('cond')
|
||||
cond_p.multiple_results = True
|
||||
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
|
||||
cond_p.def_abstract_eval(_cond_abstract_eval)
|
||||
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
|
||||
cond_p.def_custom_bind(cond_bind)
|
||||
ad.primitive_jvps[cond_p] = _cond_jvp
|
||||
ad.reducing_transposes[cond_p] = _cond_transpose
|
||||
@ -1509,6 +1524,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
new_init = tree_unflatten(init_tree, new_init_flat)
|
||||
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
|
||||
in_flat, jaxpr, consts, out_tree, out_tree_children = rest
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported in `scan`.')
|
||||
|
||||
_check_tree_and_avals("scan carry output and input",
|
||||
# Extract the subtree and avals for the first element of the return tuple
|
||||
@ -1713,7 +1730,7 @@ def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
|
||||
linear, unroll):
|
||||
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
|
||||
return carry_avals + ys_avals
|
||||
return carry_avals + ys_avals, core.no_effects
|
||||
|
||||
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
|
||||
linear, unroll):
|
||||
@ -1884,6 +1901,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
num_consts=num_consts_2,
|
||||
num_carry=num_carry, linear=tuple(linear_2),
|
||||
unroll=unroll),
|
||||
jaxpr_2_opt.effects,
|
||||
source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
@ -2088,6 +2106,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
|
||||
raise core.JaxprTypeError(
|
||||
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
|
||||
f'called with sequence of type\n{_avals_short(x_avals)}')
|
||||
return None, core.no_effects
|
||||
|
||||
def scan_bind(*args, **params):
|
||||
if config.jax_enable_checks:
|
||||
@ -2100,7 +2119,7 @@ scan_p = core.AxisPrimitive("scan")
|
||||
scan_p.multiple_results = True
|
||||
scan_p.def_custom_bind(scan_bind)
|
||||
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
|
||||
scan_p.def_abstract_eval(_scan_abstract_eval)
|
||||
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
|
||||
ad.primitive_jvps[scan_p] = _scan_jvp
|
||||
ad.reducing_transposes[scan_p] = _scan_transpose
|
||||
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
|
||||
|
@ -2723,7 +2723,7 @@ def _broadcast_in_dim_staging_rule(
|
||||
out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
|
||||
invars = [trace.getvar(x), *(trace.getvar(d) for d in dyn_shape)]
|
||||
eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
|
||||
broadcast_in_dim_p, params, source_info)
|
||||
broadcast_in_dim_p, params, core.no_effects, source_info)
|
||||
trace.frame.eqns.append(eqn)
|
||||
|
||||
return out_tracer
|
||||
|
@ -1642,7 +1642,7 @@ def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
|
||||
if not len(set(axis_name)) == len(axis_name): raise ValueError
|
||||
pos_aval = lax.dot_general_p.abstract_eval(
|
||||
x, y, dimension_numbers=[pos_contract, pos_batch],
|
||||
precision=precision, preferred_element_type=None)
|
||||
precision=precision, preferred_element_type=None)[0]
|
||||
common_named_shape = core.join_named_shapes(x.named_shape, y.named_shape)
|
||||
named_shape = {name: size
|
||||
for name, size in common_named_shape.items()
|
||||
|
79
jax/core.py
79
jax/core.py
@ -54,14 +54,21 @@ map, unsafe_map = safe_map, map
|
||||
|
||||
# -------------------- jaxprs --------------------
|
||||
|
||||
Effect = Any
|
||||
Effects = Set[Effect]
|
||||
|
||||
no_effects: Effects = set()
|
||||
|
||||
class Jaxpr:
|
||||
constvars: List[Var]
|
||||
invars: List[Var]
|
||||
outvars: List[Atom]
|
||||
eqns: List[JaxprEqn]
|
||||
effects: Effects
|
||||
|
||||
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn]):
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
|
||||
effects: Effects = no_effects):
|
||||
"""
|
||||
Args:
|
||||
constvars: list of variables introduced for constants. Array constants are
|
||||
@ -70,11 +77,14 @@ class Jaxpr:
|
||||
the inputs to the Jaxpr.
|
||||
outvars: list of output variables.
|
||||
eqns: list of equations.
|
||||
effects: set of effects. The effects on a jaxpr are a superset of the
|
||||
union of the effects for each equation.
|
||||
"""
|
||||
self.constvars = list(constvars)
|
||||
self.invars = list(invars)
|
||||
self.outvars = list(outvars)
|
||||
self.eqns = list(eqns)
|
||||
self.effects = effects
|
||||
|
||||
def __str__(self):
|
||||
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))
|
||||
@ -92,6 +102,18 @@ class Jaxpr:
|
||||
def _repr_pretty_(self, p, cycle):
|
||||
return p.text(self.pretty_print(use_color=True))
|
||||
|
||||
def replace(self, *, constvars=None, invars=None, outvars=None, eqns=None,
|
||||
effects=None):
|
||||
constvars = self.constvars if constvars is None else constvars
|
||||
invars = self.invars if invars is None else invars
|
||||
outvars = self.outvars if outvars is None else outvars
|
||||
eqns = self.eqns if eqns is None else eqns
|
||||
effects = self.effects if effects is None else effects
|
||||
return Jaxpr(constvars=constvars, invars=invars, outvars=outvars, eqns=eqns,
|
||||
effects=effects)
|
||||
|
||||
def join_effects(*effects: Effect) -> Effects:
|
||||
return set.union(*effects) if effects else no_effects
|
||||
|
||||
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
|
||||
for val in params.values():
|
||||
@ -137,9 +159,18 @@ class ClosedJaxpr:
|
||||
def eqns(self):
|
||||
return self.jaxpr.eqns
|
||||
|
||||
@property
|
||||
def effects(self) -> Effects:
|
||||
return self.jaxpr.effects
|
||||
|
||||
def map_jaxpr(self, f):
|
||||
return ClosedJaxpr(f(self.jaxpr), self.consts)
|
||||
|
||||
def replace(self, *, jaxpr=None, consts=None):
|
||||
jaxpr = self.jaxpr if jaxpr is None else jaxpr
|
||||
consts = self.consts if consts is None else consts
|
||||
return ClosedJaxpr(jaxpr, consts)
|
||||
|
||||
def __str__(self): return str(self.jaxpr)
|
||||
def __repr__(self): return repr(self.jaxpr)
|
||||
|
||||
@ -164,16 +195,20 @@ class JaxprEqn(NamedTuple):
|
||||
outvars: List[Var]
|
||||
primitive: Primitive
|
||||
params: Dict[str, Any]
|
||||
effects: Effects
|
||||
source_info: source_info_util.SourceInfo
|
||||
|
||||
def __repr__(self):
|
||||
return str(pp_eqn(self, JaxprPpContext(), JaxprPpSettings())).rstrip()
|
||||
|
||||
def new_jaxpr_eqn(invars, outvars, primitive, params, source_info=None):
|
||||
def replace(self, *args, **kwargs):
|
||||
return self._replace(*args, **kwargs)
|
||||
|
||||
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
|
||||
if primitive.call_primitive:
|
||||
assert len(outvars) == len(params["call_jaxpr"].outvars)
|
||||
source_info = source_info or source_info_util.new_source_info()
|
||||
return JaxprEqn(invars, outvars, primitive, params, source_info)
|
||||
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
|
||||
|
||||
|
||||
@total_ordering
|
||||
@ -296,9 +331,13 @@ class Primitive:
|
||||
return impl
|
||||
|
||||
def def_abstract_eval(self, abstract_eval):
|
||||
self.abstract_eval = abstract_eval
|
||||
self.abstract_eval = _effect_free_abstract_eval(abstract_eval) # type: ignore[assignment]
|
||||
return abstract_eval
|
||||
|
||||
def def_effectful_abstract_eval(self, effectful_abstract_eval):
|
||||
self.abstract_eval = effectful_abstract_eval # type: ignore[assignment]
|
||||
return effectful_abstract_eval
|
||||
|
||||
def def_custom_bind(self, bind):
|
||||
self.bind = bind
|
||||
return bind
|
||||
@ -315,6 +354,11 @@ class Primitive:
|
||||
return [], params
|
||||
|
||||
|
||||
def _effect_free_abstract_eval(abstract_eval):
|
||||
def abstract_eval_(*args, **kwargs):
|
||||
return abstract_eval(*args, **kwargs), no_effects
|
||||
return abstract_eval_
|
||||
|
||||
# -------------------- lifting --------------------
|
||||
|
||||
# TODO(mattjj): replace this approach with a primitive-keyed table of rules
|
||||
@ -2018,7 +2062,7 @@ def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: Dict[Var, Var
|
||||
e.eqn = eqn
|
||||
raise
|
||||
params = subst_axis_names(eqn.primitive, eqn.params, subst)
|
||||
return new_jaxpr_eqn(invars, outvars, eqn.primitive, params, eqn.source_info)
|
||||
return eqn.replace(invars=invars, outvars=outvars, params=params)
|
||||
|
||||
def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
|
||||
consts = None
|
||||
@ -2030,7 +2074,7 @@ def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst
|
||||
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]
|
||||
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]
|
||||
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]
|
||||
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns)
|
||||
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
|
||||
if consts is not None:
|
||||
return ClosedJaxpr(new_jaxpr, consts)
|
||||
return new_jaxpr
|
||||
@ -2172,15 +2216,20 @@ def _check_jaxpr(
|
||||
if any(isinstance(ina, ConcreteArray) for ina in in_avals):
|
||||
raise JaxprTypeError("Equation given ConcreteArray type inputs")
|
||||
if prim in custom_typechecks:
|
||||
out_avals = custom_typechecks[prim](*in_avals, **eqn.params)
|
||||
out_avals, effects = custom_typechecks[prim](*in_avals, **eqn.params)
|
||||
if out_avals is None:
|
||||
out_avals = [v.aval for v in eqn.outvars]
|
||||
elif prim.call_primitive:
|
||||
out_avals = check_call(ctx_factory, prim, in_avals, eqn.params)
|
||||
out_avals, effects = check_call(ctx_factory, prim, in_avals, eqn.params)
|
||||
elif prim.map_primitive:
|
||||
out_avals = check_map(ctx_factory, prim, in_avals, eqn.params)
|
||||
out_avals, effects = check_map(ctx_factory, prim, in_avals, eqn.params)
|
||||
else:
|
||||
out_avals = check_eqn(prim, in_avals, eqn.params)
|
||||
out_avals, effects = check_eqn(prim, in_avals, eqn.params)
|
||||
if eqn.effects != effects:
|
||||
print(eqn.effects, effects)
|
||||
raise JaxprTypeError("Inferred effects do not match equation effects.")
|
||||
if not eqn.effects.issubset(jaxpr.effects):
|
||||
raise JaxprTypeError("Equation effects are not subset of Jaxpr effects.")
|
||||
map(write, eqn.outvars, out_avals)
|
||||
except JaxprTypeError as e:
|
||||
ctx, settings = ctx_factory()
|
||||
@ -2196,10 +2245,10 @@ def check_eqn(prim, in_avals, params):
|
||||
for jaxpr in jaxprs_in_params(params):
|
||||
check_jaxpr(jaxpr)
|
||||
|
||||
out_avals = prim.abstract_eval(*in_avals, **params)
|
||||
out_avals, effects = prim.abstract_eval(*in_avals, **params)
|
||||
if not prim.multiple_results:
|
||||
out_avals = [out_avals]
|
||||
return out_avals
|
||||
return out_avals, effects
|
||||
|
||||
def check_call(ctx_factory, prim, in_avals, params):
|
||||
if "call_jaxpr" not in params:
|
||||
@ -2221,12 +2270,14 @@ def check_call(ctx_factory, prim, in_avals, params):
|
||||
_check_jaxpr(ctx_factory, call_jaxpr, in_avals)
|
||||
|
||||
out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
return out_avals
|
||||
return out_avals, call_jaxpr.effects
|
||||
|
||||
def check_map(ctx_factory, prim, in_avals, params):
|
||||
if "call_jaxpr" not in params:
|
||||
raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
|
||||
call_jaxpr = params["call_jaxpr"]
|
||||
if call_jaxpr.effects:
|
||||
raise JaxprTypeError(f"Map primitive {prim} mapping an effectful function")
|
||||
if "axis_size" not in params:
|
||||
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_size' parameter")
|
||||
axis_size = params["axis_size"]
|
||||
@ -2257,7 +2308,7 @@ def check_map(ctx_factory, prim, in_avals, params):
|
||||
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval
|
||||
for aval, out_axis in zip(mapped_out_avals, out_axes)]
|
||||
return out_avals
|
||||
return out_avals, call_jaxpr.effects
|
||||
|
||||
|
||||
# ------------------- Jaxpr printed representation -------------------
|
||||
|
@ -619,8 +619,7 @@ def ignore_errors_jaxpr(jaxpr, error):
|
||||
new_vars = core.gensym([jaxpr])
|
||||
new_invars = (new_vars(err_aval), new_vars(code_aval),
|
||||
new_vars(payload_aval), *jaxpr.invars)
|
||||
new_jaxpr = core.Jaxpr(jaxpr.constvars, new_invars,
|
||||
jaxpr.outvars, jaxpr.eqns)
|
||||
new_jaxpr = jaxpr.replace(invars=new_invars)
|
||||
return core.ClosedJaxpr(new_jaxpr, consts)
|
||||
|
||||
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
|
@ -1441,10 +1441,10 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
|
||||
# We need tokens but none is given in input; make one depending on all invars
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(jaxpr.invars, [last_token_var],
|
||||
lax.create_token_p, {}, source_info_util.current()))
|
||||
lax.create_token_p, {}, core.no_effects, source_info_util.current()))
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(jaxpr.invars, [last_itoken_var],
|
||||
lax.create_token_p, {}, source_info_util.current()))
|
||||
lax.create_token_p, {}, core.no_effects, source_info_util.current()))
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if not core.primitive_uses_outfeed(eqn.primitive, eqn.params):
|
||||
@ -1458,7 +1458,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
|
||||
last_itoken_var = output_itoken_var
|
||||
|
||||
outvars = jaxpr.outvars + ([last_token_var, last_itoken_var] if has_output_token else [])
|
||||
new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
|
||||
new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr.effects)
|
||||
return new_jaxpr
|
||||
|
||||
|
||||
@ -1476,11 +1476,9 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
"""
|
||||
if eqn.primitive is outside_call_p:
|
||||
assert "has_token" not in eqn.params
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
|
||||
dict(eqn.params, has_token=True),
|
||||
eqn.source_info))
|
||||
eqns.append(eqn.replace(invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(eqn.params, has_token=True)))
|
||||
elif eqn.primitive is lax.while_p:
|
||||
cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
|
||||
eqn.params,
|
||||
@ -1492,28 +1490,26 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
return
|
||||
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True),
|
||||
cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True,
|
||||
False)), eqn.source_info))
|
||||
cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True, False))))
|
||||
elif eqn.primitive is lax.cond_p:
|
||||
branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
|
||||
index, *operands = eqn.invars
|
||||
new_invars = [index, *operands, input_token_var, input_itoken_var]
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
new_invars, eqn.outvars + [output_token_var, output_itoken_var],
|
||||
eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=new_invars, outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
branches=tuple(
|
||||
_rewrite_closed_jaxpr(jaxpr, True, True)
|
||||
for jaxpr in branches),
|
||||
linear=(*linear, False, False)), eqn.source_info))
|
||||
linear=(*linear, False, False))))
|
||||
elif eqn.primitive is lax.scan_p:
|
||||
num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
|
||||
eqn.params,
|
||||
@ -1537,56 +1533,51 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
new_jaxpr_outvars[num_carry:-2])
|
||||
new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
new_invars,
|
||||
eqn.replace(
|
||||
invars=new_invars,
|
||||
# Output token is at the end of carry result
|
||||
eqn.outvars[0:num_carry] + [output_token_var, output_itoken_var] +
|
||||
eqn.outvars[num_carry:],
|
||||
eqn.primitive,
|
||||
dict(
|
||||
outvars=(eqn.outvars[0:num_carry] + [output_token_var, output_itoken_var] +
|
||||
eqn.outvars[num_carry:]),
|
||||
params=dict(
|
||||
eqn.params,
|
||||
jaxpr=new_jaxpr,
|
||||
num_carry=num_carry + 2,
|
||||
linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:]),
|
||||
eqn.source_info))
|
||||
linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:])))
|
||||
elif eqn.primitive is xla.xla_call_p:
|
||||
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
|
||||
donated_invars=eqn.params["donated_invars"] + (False, False)),
|
||||
eqn.source_info))
|
||||
donated_invars=eqn.params["donated_invars"] + (False, False))))
|
||||
elif eqn.primitive is pxla.xla_pmap_p:
|
||||
# We broadcast the input token into an array of tokens
|
||||
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var],
|
||||
eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
|
||||
donated_invars=eqn.params["donated_invars"] + (False, False),
|
||||
# Sharding/unsharding of tokens in pmap_translation are special
|
||||
# cased to just pass-through the token
|
||||
in_axes=eqn.params["in_axes"] + (None, None),
|
||||
out_axes=eqn.params["out_axes"] + (0, 0)),
|
||||
eqn.source_info))
|
||||
out_axes=eqn.params["out_axes"] + (0, 0))))
|
||||
elif eqn.primitive is pe.remat_call_p:
|
||||
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
|
||||
), eqn.source_info))
|
||||
)))
|
||||
elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
|
||||
fun_jaxpr = eqn.params["fun_jaxpr"]
|
||||
|
||||
@ -1594,15 +1585,14 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
assert False, "Should not be reached"
|
||||
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
|
||||
jvp_jaxpr_thunk=unreachable_thunk
|
||||
),
|
||||
eqn.source_info))
|
||||
)))
|
||||
elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
|
||||
fun_jaxpr = eqn.params["fun_jaxpr"]
|
||||
new_invars = [*eqn.invars, input_token_var, input_itoken_var]
|
||||
@ -1611,11 +1601,10 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
assert False, "Should not be reached"
|
||||
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
new_invars,
|
||||
eqn.outvars + [output_token_var, output_itoken_var],
|
||||
eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=new_invars,
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
|
||||
fwd_jaxpr_thunk=unreachable_thunk,
|
||||
@ -1623,25 +1612,24 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
# should not be needed because this rewrite is just before
|
||||
# compilation to XLA, which does not use those parameters.
|
||||
bwd="illegal param",
|
||||
out_trees="illegal param"),
|
||||
eqn.source_info))
|
||||
out_trees="illegal param")))
|
||||
elif eqn.primitive is core.named_call_p:
|
||||
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
|
||||
), eqn.source_info))
|
||||
)))
|
||||
elif eqn.primitive is pjit.pjit_p:
|
||||
jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"])
|
||||
eqns.append(
|
||||
core.new_jaxpr_eqn(
|
||||
eqn.invars + [input_token_var, input_itoken_var],
|
||||
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
|
||||
dict(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True),
|
||||
donated_invars=eqn.params["donated_invars"] + (False, False),
|
||||
@ -1649,7 +1637,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
(pjit.REPLICATED, pjit.REPLICATED)),
|
||||
out_axis_resources=(eqn.params["out_axis_resources"] +
|
||||
(pjit.REPLICATED, pjit.REPLICATED)),
|
||||
), eqn.source_info))
|
||||
)))
|
||||
else:
|
||||
raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
|
||||
|
||||
@ -1678,6 +1666,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
name="cond_before",
|
||||
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
|
||||
inline=False),
|
||||
transformed_cond_jaxpr.jaxpr.effects,
|
||||
eqn.source_info))
|
||||
# Make a new cond "lambda pred, carry, token, itoken: pred"
|
||||
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
|
||||
@ -1686,7 +1675,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
[mk_new_var(input_token_var.aval),
|
||||
mk_new_var(input_itoken_var.aval)])
|
||||
new_cond_jaxpr = core.ClosedJaxpr(
|
||||
core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
|
||||
core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], [], set()), [])
|
||||
# Make a new body:
|
||||
# "lambda cond_constvars, body_constvars, pred, carry, token, itoken:
|
||||
# carry2, token2, itoken2 = rewrite(BODY)(body_constvars, carry, token, itoken)
|
||||
@ -1723,6 +1712,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
name="body",
|
||||
donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
|
||||
inline=False),
|
||||
transformed_body_jaxpr.effects,
|
||||
eqn.source_info),
|
||||
core.new_jaxpr_eqn(
|
||||
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
|
||||
@ -1732,14 +1722,16 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
name="cond_body",
|
||||
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
|
||||
inline=False),
|
||||
transformed_cond_jaxpr.effects,
|
||||
eqn.source_info)
|
||||
]
|
||||
effects = core.join_effects(*(eqn.effects for eqn in new_body_eqns))
|
||||
new_body_jaxpr = core.ClosedJaxpr(
|
||||
core.Jaxpr([], (new_body_invars_cond_constvars +
|
||||
new_body_invars_body_constvars + [new_body_invars_pred] +
|
||||
new_body_invars_carry + [new_body_invars_token, new_body_invars_itoken]),
|
||||
([new_body_pred2] + new_body_carry2 + [new_body_token3, new_body_itoken3]),
|
||||
new_body_eqns), [])
|
||||
new_body_eqns, effects), [])
|
||||
|
||||
pred_out = mk_new_var(cond_jaxpr.out_avals[0])
|
||||
eqns.append(
|
||||
@ -1752,7 +1744,9 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
cond_jaxpr=new_cond_jaxpr,
|
||||
cond_nconsts=0,
|
||||
body_jaxpr=new_body_jaxpr,
|
||||
body_nconsts=cond_nconsts + body_nconsts), eqn.source_info))
|
||||
body_nconsts=cond_nconsts + body_nconsts),
|
||||
new_body_jaxpr.effects,
|
||||
eqn.source_info))
|
||||
|
||||
|
||||
# We need an identity primitive to simplify rewriting
|
||||
|
@ -809,7 +809,7 @@ class TensorFlowTrace(core.Trace):
|
||||
# abstract evaluation rules can properly track polymorphic shapes.
|
||||
# Unfortunately under op-by-op execution this is a rare occasion where we
|
||||
# need abstract evaluation.
|
||||
out_aval = primitive.abstract_eval(*args_avals, **params)
|
||||
out_aval, _ = primitive.abstract_eval(*args_avals, **params)
|
||||
args_tf: Sequence[TfVal] = [t.val for t in tracers]
|
||||
def invoke_impl() -> TfVal:
|
||||
if impl_needs_avals:
|
||||
|
@ -987,7 +987,7 @@ def _typecheck_xmap(
|
||||
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
|
||||
for a, a_out_axes in zip(mapped_out_avals, out_axes)]
|
||||
return out_avals
|
||||
return out_avals, call_jaxpr.effects
|
||||
core.custom_typechecks[xmap_p] = _typecheck_xmap
|
||||
|
||||
|
||||
@ -1087,7 +1087,7 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
del new_params['out_axes_thunk']
|
||||
del new_params['spmd_out_axes_thunk']
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
|
||||
new_params, source_info)
|
||||
new_params, call_jaxpr.effects, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore
|
||||
@ -1273,7 +1273,7 @@ def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
|
||||
|
||||
eqn = new_eqn_recipe((*const_tracers, *unknown_tracers_in),
|
||||
unknown_tracers_out,
|
||||
primitive, new_params, source_info_util.current())
|
||||
primitive, new_params, jaxpr.effects, source_info_util.current())
|
||||
for t in unknown_tracers_out: t.recipe = eqn
|
||||
return pe._zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
|
||||
pe.JaxprTrace.process_xmap = _jaxpr_trace_process_xmap
|
||||
@ -2079,13 +2079,14 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
|
||||
for eqn in jaxpr.eqns:
|
||||
new_jaxpr_params = core.traverse_jaxpr_params(rec, eqn.params)
|
||||
tmp_outvars = [gen_fresh_name(v.aval) for v in eqn.outvars]
|
||||
new_eqns.append(core.JaxprEqn(eqn.invars, tmp_outvars, eqn.primitive,
|
||||
dict(eqn.params, **new_jaxpr_params), eqn.source_info))
|
||||
new_eqns.append(eqn.replace(
|
||||
outvars=tmp_outvars, params=dict(eqn.params, **new_jaxpr_params)))
|
||||
for outvar, tmpvar in zip(eqn.outvars, tmp_outvars):
|
||||
new_eqns.append(core.JaxprEqn([tmpvar], [outvar], sharding_constraint_p,
|
||||
dict(resource_env=resource_env, axis_resources=ParsedPartitionSpec((), ())),
|
||||
set(),
|
||||
eqn.source_info))
|
||||
return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
|
||||
return jaxpr.replace(eqns=new_eqns)
|
||||
|
||||
def _flatten_axes(what, tree, axes, tupled_args):
|
||||
try:
|
||||
|
@ -864,6 +864,7 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
unknown_tracers_out,
|
||||
pjit_p,
|
||||
unknown_params,
|
||||
unknown_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
for t in unknown_tracers_out: t.recipe = eqn
|
||||
return pe._zip_knowns(known_tracers_out, unknown_tracers_out, unknown_outs)
|
||||
|
@ -707,7 +707,8 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
|
||||
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
|
||||
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
|
||||
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
|
||||
new_invars, new_outvars, jaxpr.jaxpr.eqns)
|
||||
new_invars, new_outvars, jaxpr.jaxpr.eqns,
|
||||
jaxpr.jaxpr.effects)
|
||||
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
|
||||
|
||||
def _perm(primal_counts, tangent_counts, lst):
|
||||
|
@ -492,7 +492,8 @@ class MaskTrace(Trace):
|
||||
if masking_rule is None:
|
||||
raise NotImplementedError(
|
||||
f'Masking rule for {primitive} not implemented yet.')
|
||||
out_aval = primitive.abstract_eval(*(t.aval for t in tracers), **params)
|
||||
# Ignore effects for now.
|
||||
out_aval, _ = primitive.abstract_eval(*(t.aval for t in tracers), **params)
|
||||
vals, polymorphic_shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
|
||||
logical_shapes = map(shape_as_value, polymorphic_shapes)
|
||||
# TODO(mattjj): generalize mask rule signature
|
||||
|
@ -167,19 +167,19 @@ class JaxprTrace(Trace):
|
||||
return primitive.bind(*consts, **params)
|
||||
tracers = map(self.instantiate_const, tracers)
|
||||
avals = [t.aval for t in tracers]
|
||||
out_aval = primitive.abstract_eval(*avals, **params)
|
||||
out_aval, effects = primitive.abstract_eval(*avals, **params)
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
if primitive.multiple_results:
|
||||
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
|
||||
for aval in out_aval]
|
||||
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
|
||||
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
else:
|
||||
out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
|
||||
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
|
||||
params, source)
|
||||
params, effects, source)
|
||||
return out_tracer
|
||||
|
||||
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
|
||||
@ -222,7 +222,7 @@ class JaxprTrace(Trace):
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
|
||||
out_tracers, primitive, staged_params, source)
|
||||
out_tracers, primitive, staged_params, jaxpr.effects, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
|
||||
@ -288,6 +288,7 @@ class JaxprTrace(Trace):
|
||||
for a in out_avals]
|
||||
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
|
||||
out_tracers, primitive, staged_params,
|
||||
jaxpr.effects,
|
||||
source_info_util.current())
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
|
||||
@ -313,7 +314,8 @@ class JaxprTrace(Trace):
|
||||
new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, source)
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
|
||||
jaxpr.effects, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
|
||||
@ -352,7 +354,7 @@ class JaxprTrace(Trace):
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
|
||||
primitive, staged_params, source)
|
||||
primitive, staged_params, jaxpr.effects, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
|
||||
@ -412,6 +414,7 @@ class JaxprTrace(Trace):
|
||||
dict(fun_jaxpr=closed_jaxpr,
|
||||
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
|
||||
num_consts=len(consts) + len(env)),
|
||||
jaxpr.effects,
|
||||
source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
@ -436,7 +439,7 @@ class JaxprTrace(Trace):
|
||||
in_tracers = map(self.instantiate_const, tracers)
|
||||
new_params = dict(params, call=call)
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
|
||||
source_info_util.current())
|
||||
core.no_effects, source_info_util.current())
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
@ -473,6 +476,7 @@ class JaxprTrace(Trace):
|
||||
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
|
||||
num_consts=len(consts) + len(env),
|
||||
bwd=bwd, out_trees=out_trees),
|
||||
jaxpr.effects,
|
||||
source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
@ -636,12 +640,14 @@ class JaxprEqnRecipe(NamedTuple):
|
||||
outvars: 'Sequence[ref[JaxprTracer]]'
|
||||
primitive: Primitive
|
||||
params: Dict[str, Any]
|
||||
effects: core.Effects
|
||||
source_info: source_info_util.SourceInfo
|
||||
|
||||
def new_eqn_recipe(invars: Sequence[JaxprTracer],
|
||||
outvars: Sequence[JaxprTracer],
|
||||
primitive: Primitive,
|
||||
params: Dict[str, Any],
|
||||
effects: core.Effects,
|
||||
source_info: source_info_util.SourceInfo
|
||||
) -> JaxprEqnRecipe:
|
||||
"""Constructs a new JaxEqnRecipe.
|
||||
@ -665,17 +671,17 @@ def new_eqn_recipe(invars: Sequence[JaxprTracer],
|
||||
assert ("donated_invars" in params and
|
||||
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
|
||||
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
|
||||
params, source_info)
|
||||
params, effects, source_info)
|
||||
|
||||
|
||||
def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
|
||||
recipe: JaxprEqnRecipe) -> core.JaxprEqn:
|
||||
_, in_tracers, out_tracer_refs, primitive, params, source_info = recipe
|
||||
_, in_tracers, out_tracer_refs, primitive, params, effects, source_info = recipe
|
||||
out_tracers = [t_ref() for t_ref in out_tracer_refs]
|
||||
invars = [getvar(t) for t in in_tracers]
|
||||
outvars = [core.DropVar(core.abstract_unit) if t is None
|
||||
else cast(Var, getvar(t)) for t in out_tracers]
|
||||
return new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
|
||||
return new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
|
||||
|
||||
def tracers_to_jaxpr(
|
||||
in_tracers: Sequence[JaxprTracer],
|
||||
@ -738,7 +744,9 @@ def tracers_to_jaxpr(
|
||||
env_vars, env_vals = unzip2(env.items())
|
||||
const_vars, const_vals = unzip2(consts.items())
|
||||
# The env_vars are pre-pended to the invars
|
||||
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
|
||||
effects = core.join_effects(*(eqn.effects for eqn in eqns))
|
||||
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers),
|
||||
eqns, effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
return jaxpr, const_vals, env_vals
|
||||
|
||||
@ -748,7 +756,8 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
lifted_jaxpr = Jaxpr(constvars=(),
|
||||
invars=jaxpr.constvars + jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
|
||||
effects=jaxpr.effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
|
||||
return lifted_jaxpr
|
||||
|
||||
@ -756,7 +765,8 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
|
||||
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
|
||||
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
|
||||
effects=jaxpr.effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
|
||||
return converted_jaxpr
|
||||
|
||||
@ -913,7 +923,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
for x in jaxpr_unknown.outvars]
|
||||
new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True)
|
||||
recipe = new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_call_p,
|
||||
new_params, source_info_util.current())
|
||||
new_params, jaxpr_unknown.effects, source_info_util.current())
|
||||
for t in out_jaxpr_tracers: t.recipe = recipe
|
||||
return _zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
|
||||
else:
|
||||
@ -960,7 +970,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
|
||||
eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p,
|
||||
new_params, source_info_util.current())
|
||||
new_params, new_jaxpr.effects, source_info_util.current())
|
||||
for t in unknown_output_tracers: t.recipe = eqn
|
||||
return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
|
||||
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
||||
@ -1009,8 +1019,7 @@ def _partial_eval_jaxpr_custom(
|
||||
map(write, unks_out, inst_out, eqn.outvars)
|
||||
elif any(unks_in):
|
||||
inputs = map(ensure_instantiated, inst_in, eqn.invars)
|
||||
staged_eqns.append(new_jaxpr_eqn(inputs, eqn.outvars, eqn.primitive,
|
||||
eqn.params, eqn.source_info))
|
||||
staged_eqns.append(eqn.replace(invars=inputs))
|
||||
map(partial(write, True, True), eqn.outvars)
|
||||
else:
|
||||
known_eqns.append(eqn)
|
||||
@ -1018,8 +1027,7 @@ def _partial_eval_jaxpr_custom(
|
||||
map(partial(write, False, False), eqn.outvars)
|
||||
else:
|
||||
inputs = map(ensure_instantiated, inst_in, eqn.invars)
|
||||
staged_eqns.append(new_jaxpr_eqn(inputs, eqn.outvars, eqn.primitive,
|
||||
eqn.params, eqn.source_info))
|
||||
staged_eqns.append(eqn.replace(invars=inputs))
|
||||
map(partial(write, False, True), eqn.outvars)
|
||||
out_unknowns, out_inst = unzip2(map(read, jaxpr.outvars))
|
||||
assert all(type(v) is Var for v in residuals), residuals
|
||||
@ -1027,11 +1035,15 @@ def _partial_eval_jaxpr_custom(
|
||||
ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
|
||||
outs_known_, _ = partition_list(out_unknowns, jaxpr.outvars)
|
||||
outs_known = [x for x in outs_known_ if x.aval is not abstract_unit]
|
||||
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns)
|
||||
known_effects = core.join_effects(*(eqn.effects for eqn in known_eqns))
|
||||
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns,
|
||||
known_effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
|
||||
|
||||
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
|
||||
jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged, staged_eqns)
|
||||
staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns))
|
||||
jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged,
|
||||
staged_eqns, staged_effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
|
||||
|
||||
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
|
||||
@ -1087,9 +1099,10 @@ def call_partial_eval_custom_rule(
|
||||
params_known, params_staged = params_updater(
|
||||
unks_in, kept_outs_known, kept_outs_staged, num_res, params_known, params_staged)
|
||||
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
|
||||
eqn.primitive, params_known, eqn.source_info)
|
||||
eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
|
||||
eqn_staged = new_jaxpr_eqn([*residuals, *eqn.invars], out_binders_staged,
|
||||
eqn.primitive, params_staged, eqn.source_info)
|
||||
eqn.primitive, params_staged,
|
||||
jaxpr_staged.effects, eqn.source_info)
|
||||
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
|
||||
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
||||
if type(x) is Var and not inst]
|
||||
@ -1142,7 +1155,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: List[bool]
|
||||
new_jaxpr = Jaxpr((),
|
||||
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
|
||||
[v for v, b in zip(jaxpr.outvars, used_outputs) if b],
|
||||
new_eqns[::-1])
|
||||
new_eqns[::-1], jaxpr.effects)
|
||||
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
|
||||
|
||||
return new_jaxpr, used_inputs
|
||||
@ -1160,7 +1173,7 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
|
||||
new_params = update_params(new_params, used_inputs, 0)
|
||||
new_eqn = new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, eqn.source_info)
|
||||
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
dce_rules[core.call_p] = dce_jaxpr_call_rule
|
||||
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
|
||||
@ -1189,14 +1202,15 @@ def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False)
|
||||
new_eqns.append(eqn)
|
||||
needed_vars.update(v for v in eqn.invars if type(v) is not Literal)
|
||||
new_eqns = new_eqns[::-1]
|
||||
return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns)
|
||||
return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns,
|
||||
jaxpr.effects)
|
||||
|
||||
@weakref_lru_cache
|
||||
def _drop_vars(jaxpr: Jaxpr, drop_ins: Tuple[bool, ...], drop_outs: Tuple[bool, ...]):
|
||||
return Jaxpr(jaxpr.constvars,
|
||||
[v for v, d in zip(jaxpr.invars, drop_ins) if not d],
|
||||
[v for v, d in zip(jaxpr.outvars, drop_outs) if not d],
|
||||
jaxpr.eqns)
|
||||
jaxpr.eqns, jaxpr.effects)
|
||||
|
||||
|
||||
def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
|
||||
@ -1215,7 +1229,8 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) ->
|
||||
assert len(closed_jaxpr.in_avals) == len(to_move)
|
||||
new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
|
||||
new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars,
|
||||
closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns)
|
||||
closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns,
|
||||
closed_jaxpr.jaxpr.effects)
|
||||
new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
|
||||
return new_closed_jaxpr
|
||||
|
||||
@ -1286,6 +1301,7 @@ class JaxprStackFrame:
|
||||
tracers: List[DynamicJaxprTracer] # hold onto strong refs for all tracers
|
||||
eqns: List[JaxprEqn]
|
||||
invars: List[Var]
|
||||
effects: core.Effects
|
||||
|
||||
def __init__(self):
|
||||
self.gensym = core.gensym()
|
||||
@ -1295,13 +1311,18 @@ class JaxprStackFrame:
|
||||
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
|
||||
self.eqns = [] # cleared when we pop frame from main
|
||||
self.invars = []
|
||||
self.effects = set()
|
||||
|
||||
def add_eqn(self, eqn: core.JaxprEqn):
|
||||
self.eqns.append(eqn)
|
||||
self.effects |= eqn.effects
|
||||
|
||||
def to_jaxpr(self, out_tracers):
|
||||
# It's not necessary, but we keep the tracer-to-var mapping injective:
|
||||
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
|
||||
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
|
||||
constvars, constvals = unzip2(self.constvar_to_val.items())
|
||||
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns)
|
||||
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, self.effects)
|
||||
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
|
||||
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
||||
return jaxpr, constvals
|
||||
@ -1335,8 +1356,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
|
||||
new_eqns = []
|
||||
for eqn in jaxpr.eqns:
|
||||
# always apply invar substitutions
|
||||
eqn = JaxprEqn([var_subs.get(v, v) for v in eqn.invars], eqn.outvars,
|
||||
eqn.primitive, eqn.params, eqn.source_info)
|
||||
eqn = eqn.replace(invars=[var_subs.get(v, v) for v in eqn.invars])
|
||||
# if any inputs are constants and we have a constant-folding rule, apply it
|
||||
if eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars):
|
||||
consts_in = [consts.get(v) for v in eqn.invars]
|
||||
@ -1357,7 +1377,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
|
||||
new_eqns.append(eqn)
|
||||
new_constvars, new_constvals = unzip2(consts.items())
|
||||
new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars]
|
||||
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns)
|
||||
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, jaxpr.effects)
|
||||
return new_jaxpr, new_constvals
|
||||
|
||||
ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
|
||||
@ -1402,10 +1422,10 @@ def _inline_literals(jaxpr, constvals):
|
||||
invars = [lit(v) or var(v) for v in eqn.invars]
|
||||
outvars = [var(v) if v in used else core.DropVar(v.aval)
|
||||
for v in eqn.outvars]
|
||||
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
|
||||
eqn.source_info))
|
||||
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
|
||||
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
|
||||
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
|
||||
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
|
||||
jaxpr.effects)
|
||||
return new_jaxpr, new_constvals
|
||||
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
@ -1485,14 +1505,14 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
def default_process_primitive(self, primitive, tracers, params):
|
||||
avals = [t.aval for t in tracers]
|
||||
out_avals = primitive.abstract_eval(*avals, **params)
|
||||
out_avals, effects = primitive.abstract_eval(*avals, **params)
|
||||
out_avals = [out_avals] if not primitive.multiple_results else out_avals
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
||||
invars = map(self.getvar, tracers)
|
||||
outvars = map(self.makevar, out_tracers)
|
||||
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers if primitive.multiple_results else out_tracers.pop()
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
@ -1520,7 +1540,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
new_params = update_params(new_params, [True] * len(tracers),
|
||||
len(consts) + len(dim_tracers))
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
|
||||
call_primitive, new_params, source_info)
|
||||
call_primitive, new_params,
|
||||
new_params['call_jaxpr'].effects, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
|
||||
@ -1554,7 +1575,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
if update_params:
|
||||
new_params = update_params(new_params, [True] * len(tracers), len(consts))
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
|
||||
new_params, source_info)
|
||||
new_params, new_params['call_jaxpr'].effects, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
|
||||
@ -1577,6 +1598,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
dict(fun_jaxpr=closed_fun_jaxpr,
|
||||
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
|
||||
num_consts=len(consts)),
|
||||
fun_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
@ -1601,6 +1623,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
|
||||
num_consts=len(consts),
|
||||
bwd=bwd, out_trees=out_trees),
|
||||
fun_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
@ -1640,8 +1663,9 @@ class DynamicJaxprTrace(core.Trace):
|
||||
transpose_jaxpr_thunk=transpose_jaxpr_thunk,
|
||||
out_types=out_types, res_tree=res_tree,
|
||||
lin_tree=lin_tree, out_tree=out_tree),
|
||||
closed_call_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
self.frame.eqns.append(eqn)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
|
||||
|
||||
|
133
tests/jaxpr_effects_test.py
Normal file
133
tests/jaxpr_effects_test.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
effect_p = core.Primitive('effect')
|
||||
effect_p.multiple_results = True
|
||||
|
||||
@effect_p.def_effectful_abstract_eval
|
||||
def _(*, effect):
|
||||
return [], {effect}
|
||||
|
||||
|
||||
class JaxprEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_trivial_jaxpr_has_no_effects(self):
|
||||
def f(x):
|
||||
return x + 1.
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertEqual(core.no_effects, jaxpr.effects)
|
||||
|
||||
def test_effectful_primitive_in_jaxpr_creates_effects(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
return x + 1.
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects)
|
||||
self.assertEqual({'foo'}, jaxpr.effects)
|
||||
|
||||
def test_different_effects_in_jaxpr(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x + 1.
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects)
|
||||
self.assertEqual({'bar'}, jaxpr.jaxpr.eqns[1].effects)
|
||||
self.assertEqual({'foo', 'bar'}, jaxpr.effects)
|
||||
|
||||
def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
return x + 1.
|
||||
jaxpr = jax.make_jaxpr(f)(2.).jaxpr
|
||||
|
||||
# Edit jaxpr to make its type wrong
|
||||
jaxpr = jaxpr.replace(effects={'foo'})
|
||||
|
||||
with self.assertRaisesRegex(core.JaxprTypeError,
|
||||
'Equation effects are not subset of Jaxpr effects.'):
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_effects_disallowed_in_cond(self):
|
||||
def f1(x):
|
||||
def true_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
return x
|
||||
def false_fun(x):
|
||||
return x
|
||||
return lax.cond(True, true_fun, false_fun, x)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f1)(2.)
|
||||
|
||||
def f2(x):
|
||||
def true_fun(x):
|
||||
return x
|
||||
def false_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
return x
|
||||
return lax.cond(True, true_fun, false_fun, x)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f2)(2.)
|
||||
|
||||
def test_effects_disallowed_in_while(self):
|
||||
def f1(x):
|
||||
def cond_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
return False
|
||||
def body_fun(x):
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f1)(2.)
|
||||
|
||||
def f2(x):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f2)(2.)
|
||||
|
||||
def test_effects_disallowed_in_scan(self):
|
||||
|
||||
def f(x):
|
||||
def body(carry, x):
|
||||
effect_p.bind(effect='foo')
|
||||
return carry, x
|
||||
return lax.scan(body, x, jnp.arange(4))
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
jax.make_jaxpr(f)(2.)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -2888,20 +2888,20 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_abstract_eval(self):
|
||||
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
|
||||
out = lax.sin_p.abstract_eval(aval1)
|
||||
out, _ = lax.sin_p.abstract_eval(aval1)
|
||||
self.assertEqual(out, aval1)
|
||||
|
||||
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
|
||||
aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
|
||||
expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
|
||||
out = lax.add_p.abstract_eval(aval1, aval2)
|
||||
out, _ = lax.add_p.abstract_eval(aval1, aval2)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
def test_abstract_eval_collective(self):
|
||||
with core.extend_axis_env('i', 10, None):
|
||||
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
|
||||
expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
|
||||
out, = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
|
||||
(out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user