add cond dce rule and custom-policy partial eval rule

Matthew Johnson 2022-07-19 08:53:23 -07:00
parent 9e6254e058
commit ec9f9c3c07
4 changed files with 454 additions and 53 deletions


@ -17,8 +17,9 @@ import functools
from functools import partial
import inspect
import itertools
import operator
from typing import Callable, Sequence
from typing import Callable, Sequence, List, Tuple
from jax import core
from jax import linear_util as lu
@ -36,7 +37,8 @@ from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import safe_map, extend_name_stack, split_list
from jax._src.util import (safe_map, extend_name_stack, split_list,
partition_list)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import numpy as np
@ -52,7 +54,7 @@ from jax._src.lax.control_flow.common import (
allowed_effects,
)
_map, unsafe_map = safe_map, map
map, unsafe_map = safe_map, map
# For backward compatibility with a previous switch/cond calling convention,
@ -124,7 +126,7 @@ def switch(index, branches: Sequence[Callable], *operands,
return branches[int(index)](*operands)
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(_map(_abstractify, ops))
ops_avals = tuple(map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
@ -216,7 +218,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
linear_ops, ops_tree2 = tree_flatten(linear)
if ops_tree != ops_tree2:
raise TypeError('linear tree and operand tree mismatch')
ops_avals = tuple(_map(_abstractify, ops))
ops_avals = tuple(map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
@ -285,7 +287,7 @@ def _cond_abstract_eval(*args, branches, **kwargs):
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
joined_effects = core.join_effects(*(b.effects for b in branches))
return _map(raise_to_shaped, branches[0].out_avals), joined_effects
return map(raise_to_shaped, branches[0].out_avals), joined_effects
def _bcast_select(pred, on_true, on_false):
if np.ndim(pred) != np.ndim(on_true):
@ -407,7 +409,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
branches_unknown, all_res_avals, res_avals_per_branch)
assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
for j in branches_known[1:])
in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
@ -419,7 +421,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
index_tracer = trace.instantiate_const(tracers[0])
ops_tracers = [trace.instantiate_const(t)
for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
res_tracers = _map(trace.new_instantiated_const, res)
res_tracers = map(trace.new_instantiated_const, res)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in branches_unknown[0].out_avals]
linear_unknown = ([False] * num_res +
@ -429,10 +431,84 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
core.no_effects, source)
core.join_effects(*(j.effects for j in branches_unknown)), source)
for t in out_tracers: t.recipe = eqn
return util.merge_lists(out_uks, out_consts, out_tracers)
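Editor's note: this partial-eval rule is what jax.linearize exercises on cond, splitting each branch into a known (primal) part and a staged part that consumes residuals. A minimal sketch using only public API (the printed jaxpr and the residual layout are illustrative, not guaranteed):

import jax
import jax.numpy as jnp
from jax import lax

f = lambda x: lax.cond(x > 0, jnp.sin, jnp.cos, x)
y, f_lin = jax.linearize(f, 1.0)   # partially evaluates f: known primal now, staged tangent later
print(jax.make_jaxpr(f_lin)(1.0))  # residuals from the primal pass are closed over by f_lin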
# TODO(mattjj): de-duplicate with _cond_partial_eval
def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
index_uk, *ops_uk = unks_in
assert not index_uk # only possible with old-style remat
branches = eqn.params['branches']
# First, compute output unknowns (unks_out), where an output of the cond is
# unknown if it would be unknown on any of the branches.
unks_out: List[bool] = [False] * len(eqn.outvars)
for jaxpr in branches:
_, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
unks_out = map(operator.or_, unks_out, unks_out_)
# Next, use the computed output unknowns to build a known jaxpr and a staged
# jaxpr for each branch.
branches_known_ : List[core.ClosedJaxpr] = []
branches_staged_: List[core.ClosedJaxpr] = []
branch_res_avals: List[List[core.AbstractValue]] = []
for jaxpr in branches:
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
ensure_out_unknowns=unks_out, ensure_out_inst=True,
saveable=saveable)
branches_known_.append( core.ClosedJaxpr(jaxpr_known, jaxpr.consts))
branches_staged_.append(core.ClosedJaxpr(jaxpr_staged, jaxpr.consts))
branch_res_avals.append(branches_staged_[-1].in_avals[:num_res])
# Residuals may differ across branches, so we merge them, then use the merged
# residuals to join the outputs of all branches to the same type.
all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
num_res = len(all_res_avals)
num_known_outs = len(unks_out) - sum(unks_out)
branches_known = _join_cond_outputs(
branches_known_, all_res_avals, res_avals_per_branch, num_known_outs)
branches_staged = _join_cond_pe_staged_jaxpr_inputs(
branches_staged_, all_res_avals, res_avals_per_branch)
assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
for j in branches_known[1:])
# Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
# passing in_inst argument to partial_eval_jaxpr_custom above).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
inst_in = [True] * len(inst_in)
# Create residual variables.
newvar = core.gensym()
res_binders = map(newvar, all_res_avals)
# Build the known eqn.
ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
linear_known = [l for l, uk in zip(eqn.params['linear'], ops_uk) if not uk]
params_known = dict(branches=branches_known, linear=tuple(linear_known))
effects_known = core.join_effects(*(b.effects for b in branches_known))
eqn_known = pe.new_jaxpr_eqn(
ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
effects_known, eqn.source_info)
# Build the staged eqn.
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
linear_staged = [False] * len(res_binders) + list(eqn.params['linear'])
params_staged = dict(branches=branches_staged, linear=tuple(linear_staged))
effects_staged = core.join_effects(*(b.effects for b in branches_staged))
eqn_staged = pe.new_jaxpr_eqn(
[eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
cond_p, params_staged, effects_staged, eqn.source_info)
new_vars = [*new_inst, *res_binders]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars
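The rule above is what lets a custom checkpoint policy apply through cond; before this commit the registration below pointed at partial_eval_jaxpr_custom_rule_not_implemented. A hedged usage sketch mirroring the tests added in this commit (checkpoint here is jax.ad_checkpoint.checkpoint, the new_checkpoint of the tests):

import jax
import jax.numpy as jnp
from jax import lax
from jax.ad_checkpoint import checkpoint

save_cos = lambda prim, *_, **__: str(prim) == 'cos'  # save only cosines
f = checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x),
               policy=save_cos)
print(jax.grad(f)(3.0))  # values approved by the policy are saved; the rest is rematerialized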
# When partially evaluating conditionals, each branch produces residuals
# depending on the computation carried out by the branch, and a corresponding
# staged jaxpr that accepts those residuals as its first few inputs. The
@ -462,7 +538,7 @@ def _merge_branch_residuals(branch_res_avals):
def enumerate_equal(xs):
counts = {v: itertools.count() for v in set(xs)}
return [(x, next(counts[x])) for x in xs]
branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
branch_res_tagged_avals = map(enumerate_equal, branch_res_avals)
all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
indices = {v: i for i, v in enumerate(all_tagged_avals)}
branch_indices = [
@ -480,13 +556,13 @@ def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
def f_aug(*args):
outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
aug_residuals = map(ad_util.zeros_like_aval, all_res_avals)
aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
return outs + list(aug_residuals)
return _make_closed_jaxpr(f_aug, jaxpr.in_avals)
return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
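To make the merged-residual convention concrete, here is a self-contained sketch of the tagging scheme in _merge_branch_residuals, using strings as stand-ins for avals (standalone Python, not the JAX internals themselves):

import itertools

def merge_residuals(branch_res):
  # Tag repeated avals with occurrence counts so duplicates within one branch
  # stay distinct, then take the ordered union of tags across branches.
  def enumerate_equal(xs):
    counts = {v: itertools.count() for v in set(xs)}
    return [(x, next(counts[x])) for x in xs]
  tagged = [enumerate_equal(res) for res in branch_res]
  index_of, merged = {}, []
  for t in itertools.chain.from_iterable(tagged):
    if t not in index_of:
      index_of[t] = len(merged)
      merged.append(t)
  branch_indices = [[index_of[t] for t in ts] for ts in tagged]
  return [aval for aval, _ in merged], branch_indices

avals, idxs = merge_residuals([['f32[3]', 'i32[]'], ['f32[3]', 'f32[3]']])
assert avals == ['f32[3]', 'i32[]', 'f32[3]'] and idxs == [[0, 1], [0, 2]]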
# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
@ -494,7 +570,7 @@ def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
res_aval_indices_per_jaxpr):
newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
all_res_vars = _map(newvar, all_res_avals)
all_res_vars = map(newvar, all_res_avals)
def augment_jaxpr(jaxpr, res_indices):
num_res = len(res_indices)
@ -509,15 +585,51 @@ def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return jaxpr_aug
return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())
def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
) -> Tuple[List[bool], core.JaxprEqn]:
if not config.after_neurips:
return [True] * len(eqn.invars), eqn
closed_branches = eqn.params['branches']
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]
# First, compute which inputs are used in any branch (not including `index`).
used_inputs: List[bool] = [False] * (len(eqn.invars) - 1) # -1 for index
for jaxpr in branches:
_, used_inputs_ = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
used_inputs = map(operator.or_, used_inputs, used_inputs_)
# Next, compute DCEd branches, instantiating according to used_inputs.
dce_branches_ = [pe.dce_jaxpr(jaxpr, used_outputs, instantiate=used_inputs)[0]
for jaxpr in branches]
dce_branches = [core.ClosedJaxpr(jaxpr, closed_jaxpr.consts)
for closed_jaxpr, jaxpr in zip(closed_branches, dce_branches_)]
# Finally, update parameters and form the new eqn.
dce_linear = [l for l, used in zip(eqn.params['linear'], used_inputs) if used]
new_params = dict(eqn.params, branches=tuple(dce_branches),
linear=tuple(dce_linear))
new_effects = core.join_effects(*(b.effects for b in dce_branches))
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_effects, eqn.source_info)
assert all(len(new_eqn.invars ) == 1 + len(jaxpr.in_avals )
for jaxpr in new_params['branches'])
assert all(len(new_eqn.outvars) == len(jaxpr.out_avals)
for jaxpr in new_params['branches'])
return [True, *used_inputs], new_eqn
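A hedged sketch of exercising the new rule directly through the internal entry point pe.dce_jaxpr (internal API, subject to change), mirroring the DCETest cases added below: keeping only the first output lets the rule prune x2 and one sin from each branch.

import jax
import jax.numpy as jnp
from jax import lax
from jax.interpreters import partial_eval as pe

def f(x1, x2):
  return lax.cond(x1 > 0,
                  lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
                  lambda x1, x2: (jnp.sin(x1), jnp.sin(x1)),
                  x1, x2)

jaxpr = jax.make_jaxpr(f)(1., 2.).jaxpr
dced_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, [True, False])
# used_inputs == [True, False]: x2 is no longer an input of the pruned jaxpr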
def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = _map(raise_to_shaped, primal_avals)
primal_avals = map(raise_to_shaped, primal_avals)
@lu.wrap_init
def transposed(*args):
@ -526,13 +638,13 @@ def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
cts_in = ad.backward_pass(
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
return map(ad.instantiate_zeros_aval, primal_avals, cts_in)
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
index, *ops = args
in_avals = _map(raise_to_shaped, branches[0].in_avals)
in_avals = map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)
branches_trans = tuple(
@ -544,12 +656,12 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear):
for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))
res = ops[:num_res]
cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
cts = map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
linear_trans = (False,) * num_res + (True,) * len(cts)
out = cond_p.bind(
index, *res, *cts, branches=branches_trans, linear=linear_trans)
assert all(_map(core.typecheck, lin_in_avals, out))
assert all(map(core.typecheck, lin_in_avals, out))
out_iter = iter(out)
out = [next(out_iter) if l else None for l in linear]
@ -589,11 +701,11 @@ def _cond_typecheck(*in_atoms, branches, linear):
raise core.JaxprTypeError(
f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
f'branch {i+1} outputs {len(jaxpr.out_avals)}')
if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
if not all(map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching input types: '
f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
if not all(map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching output types: '
f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')
@ -607,7 +719,7 @@ def _cond_typecheck(*in_atoms, branches, linear):
if index_aval.dtype != np.int32:
raise core.JaxprTypeError(
f'cond called with index of type {index_aval.dtype} instead of int32')
if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
if not all(map(core.typecompat, jaxpr0.in_avals, op_avals)):
raise core.JaxprTypeError(
f'cond branches take input types {jaxpr0_in_avals_str}, '
f'called with operands of type {_avals_short(op_avals)}')
@ -620,7 +732,7 @@ def _cond_typecheck(*in_atoms, branches, linear):
def cond_bind(*args, branches, linear):
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
avals = map(core.get_aval, args)
in_atoms = [core.Var(0, '', a) for a in avals] # dummies
_cond_typecheck(*in_atoms, branches=branches, linear=linear)
for jaxpr in branches:
@ -638,8 +750,8 @@ pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
def _cond_lowering(ctx, index, *args, branches, linear):
del linear # Unused.
@ -650,7 +762,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
tokens_in = ctx.tokens_in.subset(ordered_effects)
output_token_types = [mlir.token_type() for _ in ordered_effects]
output_types = [
*output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)]
*output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
flat_output_types = util.flatten(output_types)
# mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
@ -666,13 +778,13 @@ def _cond_lowering(ctx, index, *args, branches, linear):
name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
out_vals, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr.jaxpr, tokens_in,
_map(mlir.ir_constants, jaxpr.consts),
*_map(mlir.wrap_singleton_ir_values, args))
map(mlir.ir_constants, jaxpr.consts),
*map(mlir.wrap_singleton_ir_values, args))
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
out_vals = [*out_tokens, *out_vals]
mhlo.ReturnOp(util.flatten(out_vals))
tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types))
tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
return outputs
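For reference, the resulting case-based lowering is visible from the public API (a sketch; the exact IR text depends on the jax/jaxlib version):

import jax
import jax.numpy as jnp
from jax import lax

f = lambda x: lax.cond(x > 0, jnp.sin, jnp.cos, x)
print(jax.jit(f).lower(1.0).compiler_ir(dialect='mhlo'))  # contains an mhlo case op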


@ -800,18 +800,17 @@ def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
else:
assert False, "Fixpoint not reached"
core.check_jaxpr(jaxpr.jaxpr)
if config.jax_enable_checks: core.check_jaxpr(jaxpr.jaxpr)
new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
new_params = dict(eqn.params, num_consts=sum(used_consts),
num_carry=sum(used_carry_in), linear=tuple(new_linear),
jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts))
new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs)
if used],
[v for v, used in zip(eqn.outvars, used_outputs)
if used],
eqn.primitive, new_params, eqn.effects,
eqn.source_info)
# TODO(mattjj,sharadmv): don't assume effects are never DCE'd?
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, eqn.effects, eqn.source_info)
assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals )
assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
return used_inputs, new_eqn
@ -822,7 +821,8 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_ys = len(jaxpr.out_avals) - num_carry
# Fixpoint (currently trivial on 'inst_in')
# Fixpoint (trivial on 'inst_in', since we might as well make all inputs
# available as DCE can subsequently prune any unused ones)
const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
for _ in range(1 + len(carry_uk)):
unks_in = const_uk + carry_uk + xs_uk
@ -831,7 +831,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
jaxpr.jaxpr, in_unknowns=unks_in, in_inst=[True] * len(unks_in),
ensure_out_unknowns=carry_uk + [False] * num_ys,
ensure_out_inst=True, saveable=saveable)
carry_uk_out , ys_uk = split_list(unks_out, [num_carry])
carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
if carry_uk_out == carry_uk:
break
else:
@ -841,13 +841,14 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)
# Ensure residuals are all moved to the back.
# Move all residual binders to the back of jaxpr_staged so they're extensive.
# TODO(mattjj): make jaxpr_staged only take instantiated inputs
res_avals = jaxpr_staged.in_avals[:num_res]
jaxpr_staged = pe.move_binders_to_back(
jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))
# Instantiate all inputs (b/c jaxpr_staged takes all inputs).
# Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
# passing in_inst argument to partial_eval_jaxpr_custom above).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
inst_in = [True] * len(inst_in)
@ -907,6 +908,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
eqn.source_info)
# Create the staged eqn.
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
[False] * len(extensive_res))


@ -4546,7 +4546,7 @@ class RematTest(jtu.JaxTestCase):
# Two sine calls in the backward pass because while we don't save sines
# within the (rematted) body function, we can save the scan carry, which
# effectively saves one sine. Three cosines for the Jacoian coefficients.
# effectively saves one sine. Three cosines for the Jacobian coefficients.
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' cos '), 3)
# Six calls to dot_general in the backward pass because we save the primal
@ -4805,7 +4805,7 @@ class RematTest(jtu.JaxTestCase):
('', api.remat),
('_new', new_checkpoint),
])
def test_const_in_jvp(self, remat):
def test_const_in_jvp_scan(self, remat):
@api.custom_jvp
def f(x):
return x * np.arange(3.)
@ -4980,6 +4980,233 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos '), 2)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_remat_of_cond(self, remat):
true_fn = lambda c: (jnp.sin(c), jnp.sin(c))
false_fn = lambda c: (jnp.sin(c), jnp.sin(c))
f = lambda x: lax.cond(x > 0., true_fn, false_fn, x)
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
self.assertNotIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
true_fn = lambda c: jnp.sin(jnp.sin(c))
false_fn = lambda c: c
f = lambda x: lax.cond(x > 0., true_fn, false_fn, x)
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_new', new_checkpoint),
])
def test_const_in_jvp_cond(self, remat):
@api.custom_jvp
def f(x):
return x * np.arange(3.)
@f.defjvp
def f_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return f(x), xdot * np.arange(3.)
@remat
def g(x):
y = jax.lax.cond(x.sum() > 0, f, lambda x: x, x)
return y.sum()
jax.grad(g)(jnp.arange(3.)) # doesn't crash
def test_remat_checkpoint_dots_inside_cond(self):
x = jnp.ones((5,))
def f(W):
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
return x
return lax.cond(x.sum() > 0, f, lambda x: x, x)
_, f_vjp = api.vjp(f, jnp.ones((5, 5)))
jaxpr_text = str(f_vjp.args[0].func.args[1])
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' cos '), 3)
# Five calls to dot_general in the backward pass because we have two for
# each forward-pass dot, except for the first which only has one (as we are
# differentiating with respect to only W and not x).
self.assertEqual(jaxpr_text.count(' dot_'), 5)
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
modes=['fwd', 'rev'])
def test_remat_checkpoint_dots_outside_cond(self):
# See the test test_remat_checkpoint_dots_inside_cond above. For cond the
# two remat placements behave essentially identically, whereas for scan the
# placement changes what is saved (because of the carry).
x = jnp.ones((5,))
@partial(new_checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def f(W):
def f(x):
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
return x
return lax.cond(x.sum() > 0, f, lambda x: x, x)
_, f_vjp = api.vjp(f, jnp.ones((5, 5)))
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' cos '), 3)
self.assertEqual(jaxpr_text.count(' dot_'), 5)
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
modes=['fwd', 'rev'])
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_cond_policy(self):
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
f = new_checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x),
policy=save_cos)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_cond_funky_custom_jvp(self):
def cond_apply(f, x):
return lax.cond(x.sum() > -jnp.inf, f, lambda x: x, x)
@api.custom_jvp
def sin(x):
return jnp.sin(x)
def sin_jvp(primals, tangents):
x, = primals
xdot, = tangents
y, c = jax.jit(lambda: (jnp.sin(x), jnp.cos(x)))()
ydot = c * xdot
return y, ydot
sin.defjvp(sin_jvp)
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
f = new_checkpoint(partial(cond_apply, sin), policy=save_cos)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.everything_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_cond_funky_custom_jvp2(self):
# Like the above test but instead of using jit inside custom_jvp, use cond.
def cond_apply(f, x):
return lax.cond(True, f, lambda x: x, x)
@api.custom_jvp
def sin(x):
return jnp.sin(x)
def sin_jvp(primals, tangents):
x, = primals
xdot, = tangents
y, c = cond_apply(lambda xs: (jnp.sin(xs[0]), jnp.cos(xs[1])), (x, x))
ydot = c * xdot
return y, ydot
sin.defjvp(sin_jvp)
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
f = new_checkpoint(partial(cond_apply, sin), policy=save_cos)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.everything_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)
class JaxprTest(jtu.JaxTestCase):
@ -5367,6 +5594,56 @@ class DCETest(jtu.JaxTestCase):
result2 = core.eval_jaxpr(jaxpr_pruned, consts, *pruned_args)
self.assertAllClose(result1, result2)
def test_dce_jaxpr_cond_trivial(self):
x = jnp.array(1., dtype='float32')
# start with 7 eqns, use both outputs so nothing can be pruned
def f(x1, x2):
return lax.cond(x1 > 0,
lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
x1, x2)
jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
self.assert_dce_result(jaxpr, [True, True], [True, True], 7)
# use neither output so everything can be pruned
self.assert_dce_result(jaxpr, [False, False], [False, False], 0)
def test_dce_jaxpr_cond_nontrivial(self):
x = jnp.array(1., dtype='float32')
# start with 7 eqns; don't use one output, so an eqn can be trimmed on each
# side and x2 _can_ be pruned
def f(x1, x2):
return lax.cond(x1 > 0,
lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
lambda x1, x2: (jnp.sin(x1), jnp.sin(x1)),
x1, x2)
jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
self.assert_dce_result(jaxpr, [True, False], [True, False], 5)
# start with 7 eqns; don't use one output, so an eqn can be trimmed on each
# side, but x2 _can't_ be pruned b/c of a swap
def f(x1, x2):
return lax.cond(x1 > 0,
lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
lambda x1, x2: (jnp.sin(x2), jnp.sin(x1)),
x1, x2)
jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
self.assert_dce_result(jaxpr, [True, False], [True, True], 5)
# start with 7 eqns, only use x1 on one side and x2 on the other, so we
# can't prune any inputs or eqns
def f(x1, x2):
return lax.cond(x1 > 0,
lambda x1, x2: (jnp.sin(x1), jnp.sin(x1)),
lambda x1, x2: (jnp.sin(x2), jnp.sin(x2)),
x1, x2)
jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
self.assert_dce_result(jaxpr, [True, True], [True, True], 7)
# use only one output, so we can prune eqns but not inputs
self.assert_dce_result(jaxpr, [True, False], [True, True], 5)
class CustomJVPTest(jtu.JaxTestCase):


@ -59,6 +59,22 @@ def cond_via_switch(pred, true_fun, false_fun, op, *args):
index = lax.convert_element_type(pred, np.int32)
return lax.switch(index, [false_fun, true_fun], op)
def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args):
if args:
# Old-style calling convention: cond(pred, true_op, true_fun, false_op,
# false_fun); rebind the positional parameters to those roles.
true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0]
op = (false_op, true_op)
false_fun = lambda op: _false_fun(op[0])
true_fun = lambda op: _true_fun(op[1])
index = lax.convert_element_type(pred, np.int32)
fn = lambda index, op: lax.switch(index, [false_fun, true_fun], op)
return new_checkpoint(fn)(index, op)
COND_IMPLS = [
(lax.cond, 'cond'),
(cond_via_switch, 'switch'),
(cond_with_new_checkpoint, 'new_checkpoint'),
]
# We wanted to try all scan tests with the scan partial evaluation rule that
# happens under ad_checkpoint.checkpoint, so we make a scan wrapper which
@ -73,12 +89,6 @@ def scan_with_new_checkpoint2(f, *args, **kwargs):
def scan_with_for(f, *args, **kwargs):
return for_loop.scan(f, *args, **kwargs)
COND_IMPLS = [
(lax.cond, 'cond'),
(cond_via_switch, 'switch'),
]
SCAN_IMPLS = [
(lax.scan, 'unroll1'),
(partial(lax.scan, unroll=2), 'unroll2'),
@ -786,7 +796,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(cf(1, x), branch(x))
def testIssue1379(self):
def fun(pred):
return lax.cond(pred, lambda x: (True, x), lambda x: (False, x), pred)
@ -1092,7 +1101,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return 2. * x
def fun(x):
return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)
x = 3.14
ans = jax.jvp(fun, (x,), (x,))
@ -1222,7 +1231,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return 2. * x
def fun(x):
return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)
x = 3.14
ans = jax.grad(fun)(x)
@ -1240,6 +1249,8 @@ class LaxControlFlowTest(jtu.JaxTestCase):
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondGrad4(self, cond):
if cond is cond_with_new_checkpoint and 'tpu' in jtu.device_under_test():
raise unittest.SkipTest("tpu bug") # TODO(parkers): tpu bug ehibited here
def fun_ref(x, y):
if x < 3:
return 2. * jnp.sin(y)
@ -1250,7 +1261,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def fun(x, y):
return cond(
x < 3,
(), lambda _: 2. * jnp.sin(y),
None, lambda _: 2. * jnp.sin(y),
x, lambda x: 2. * x)
y = 5.8
@ -1761,7 +1772,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
f"got {tree_util.tree_structure(a)} and {tree_util.tree_structure((1, 2))}.")):
lax.scan(lambda c, x: (0, x), (1, 2), a)
@parameterized.named_parameters(
{"testcase_name": f"_{scan_name}",
"scan": scan_impl}