add cond dce rule and custom-policy partial eval rule
parent 9e6254e058
commit ec9f9c3c07
@@ -17,8 +17,9 @@ import functools
 from functools import partial
 import inspect
 import itertools
 import operator

-from typing import Callable, Sequence
+from typing import Callable, Sequence, List, Tuple

 from jax import core
 from jax import linear_util as lu
@@ -36,7 +37,8 @@ from jax._src import source_info_util
 from jax._src import util
 from jax._src.lax import lax
 from jax._src.traceback_util import api_boundary
-from jax._src.util import safe_map, extend_name_stack, split_list
+from jax._src.util import (safe_map, extend_name_stack, split_list,
+                           partition_list)
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
 import numpy as np
@@ -52,7 +54,7 @@ from jax._src.lax.control_flow.common import (
     allowed_effects,
 )

-_map, unsafe_map = safe_map, map
+map, unsafe_map = safe_map, map


 # For backward compatibility with a previous switch/cond calling convention,
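
The `_map` → `map` rename above rebinds the module-level name `map` to `jax._src.util.safe_map`, shadowing the builtin within this file. A minimal sketch of the semantics being assumed of `safe_map` (an illustration, not the actual util source): an eager, length-checked map.

    def safe_map(f, *xs):
      # Like builtin map, but returns a list and insists all inputs have
      # equal length, catching silent truncation bugs at call sites.
      xs = [list(x) for x in xs]
      n = len(xs[0])
      for x in xs[1:]:
        assert len(x) == n, f'length mismatch: {[len(x) for x in xs]}'
      return [f(*args) for args in zip(*xs)]
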
@@ -124,7 +126,7 @@ def switch(index, branches: Sequence[Callable], *operands,
     return branches[int(index)](*operands)

   ops, ops_tree = tree_flatten(operands)
-  ops_avals = tuple(_map(_abstractify, ops))
+  ops_avals = tuple(map(_abstractify, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       branches, ops_tree, ops_avals, primitive_name='switch')
@@ -216,7 +218,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
     linear_ops, ops_tree2 = tree_flatten(linear)
     if ops_tree != ops_tree2:
       raise TypeError('linear tree and operand tree mismatch')
-  ops_avals = tuple(_map(_abstractify, ops))
+  ops_avals = tuple(map(_abstractify, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       (true_fun, false_fun), ops_tree, ops_avals, 'cond')
@@ -285,7 +287,7 @@ def _cond_abstract_eval(*args, branches, **kwargs):
     raise NotImplementedError(
         f'Effects not supported in `cond`: {disallowed_effects}')
   joined_effects = core.join_effects(*(b.effects for b in branches))
-  return _map(raise_to_shaped, branches[0].out_avals), joined_effects
+  return map(raise_to_shaped, branches[0].out_avals), joined_effects

 def _bcast_select(pred, on_true, on_false):
   if np.ndim(pred) != np.ndim(on_true):
@@ -407,7 +409,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
       branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
   branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
       branches_unknown, all_res_avals, res_avals_per_branch)
-  assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
+  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

   in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
@@ -419,7 +421,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
   index_tracer = trace.instantiate_const(tracers[0])
   ops_tracers = [trace.instantiate_const(t)
                  for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
-  res_tracers = _map(trace.new_instantiated_const, res)
+  res_tracers = map(trace.new_instantiated_const, res)
   out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
                  for aval in branches_unknown[0].out_avals]
   linear_unknown = ([False] * num_res +
@@ -429,10 +431,84 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
   source = source_info_util.current().replace(name_stack=name_stack)
   eqn = pe.new_eqn_recipe(
       [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
-      core.no_effects, source)
+      core.join_effects(*(j.effects for j in branches_unknown)), source)
   for t in out_tracers: t.recipe = eqn
   return util.merge_lists(out_uks, out_consts, out_tracers)

+# TODO(mattjj): de-duplicate with _cond_partial_eval
+def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
+  index_uk, *ops_uk = unks_in
+  assert not index_uk  # only possible with old-style remat
+  branches = eqn.params['branches']
+
+  # First, compute output unknowns (unks_out), where an output of the cond is
+  # unknown if it would be unknown on any of the branches.
+  unks_out: List[bool] = [False] * len(eqn.outvars)
+  for jaxpr in branches:
+    _, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
+        jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
+        ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
+    unks_out = map(operator.or_, unks_out, unks_out_)
+
+  # Next, use the computed output unknowns to build a known jaxpr and a staged
+  # jaxpr for each branch.
+  branches_known_ : List[core.ClosedJaxpr] = []
+  branches_staged_: List[core.ClosedJaxpr] = []
+  branch_res_avals: List[core.AbstractValue] = []
+  for jaxpr in branches:
+    jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
+        pe.partial_eval_jaxpr_custom(
+            jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
+            ensure_out_unknowns=unks_out, ensure_out_inst=True,
+            saveable=saveable)
+    branches_known_.append( core.ClosedJaxpr(jaxpr_known,  jaxpr.consts))
+    branches_staged_.append(core.ClosedJaxpr(jaxpr_staged, jaxpr.consts))
+    branch_res_avals.append(branches_staged_[-1].in_avals[:num_res])
+
+  # Residuals may differ across branches, so we merge them, then use the merged
+  # residuals to join the outputs of all branches to the same type.
+  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
+  num_res = len(all_res_avals)
+  num_known_outs = len(unks_out) - sum(unks_out)
+  branches_known = _join_cond_outputs(
+      branches_known_, all_res_avals, res_avals_per_branch, num_known_outs)
+  branches_staged = _join_cond_pe_staged_jaxpr_inputs(
+      branches_staged_, all_res_avals, res_avals_per_branch)
+  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
+             for j in branches_known[1:])
+
+  # Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
+  # passing in_inst argument to partial_eval_jaxpr_custom above).
+  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
+              if type(x) is core.Var and not inst]
+  inst_in = [True] * len(inst_in)
+
+  # Create residual variables.
+  newvar = core.gensym()
+  res_binders = map(newvar, all_res_avals)
+
+  # Build the known eqn.
+  ins_known, _ = partition_list(unks_in, eqn.invars)  # includes index invar
+  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
+  linear_known = [l for l, uk in zip(eqn.params['linear'], ops_uk) if not uk]
+  params_known = dict(branches=branches_known, linear=tuple(linear_known))
+  effects_known = core.join_effects(*(b.effects for b in branches_known))
+  eqn_known = pe.new_jaxpr_eqn(
+      ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
+      effects_known, eqn.source_info)
+
+  # Build the staged eqn.
+  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
+  linear_staged = [False] * len(res_binders) + list(eqn.params['linear'])
+  params_staged = dict(branches=branches_staged, linear=tuple(linear_staged))
+  effects_staged = core.join_effects(*(b.effects for b in branches_staged))
+  eqn_staged = pe.new_jaxpr_eqn(
+      [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
+      cond_p, params_staged, effects_staged, eqn.source_info)
+
+  new_vars = [*new_inst, *res_binders]
+  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
+
 # When partially evaluating conditionals, each branch produces residuals
 # depending on the computation carried out by the branch, and a corresponding
 # staged jaxpr that accepts those residuals as its first few inputs. The
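
A usage sketch (not part of the diff) of what `_cond_partial_eval_custom` enables, mirroring `test_remat_of_cond_policy` added later in this commit; `save_cos` is an illustrative policy, and `jax.checkpoint` stands in for the `new_checkpoint` alias used in the tests:

    import jax
    import jax.numpy as jnp
    from jax import lax

    # Save only cosines in the forward pass; everything else, including values
    # computed inside the cond branches, is rematerialized on the backward pass.
    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
    f = jax.checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x),
                       policy=save_cos)
    print(jax.grad(f)(3.0))
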
@@ -462,7 +538,7 @@ def _merge_branch_residuals(branch_res_avals):
   def enumerate_equal(xs):
     counts = {v: itertools.count() for v in set(xs)}
     return [(x, next(counts[x])) for x in xs]
-  branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
+  branch_res_tagged_avals = map(enumerate_equal, branch_res_avals)
   all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
   indices = {v: i for i, v in enumerate(all_tagged_avals)}
   branch_indices = [
@@ -480,13 +556,13 @@ def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
     def f_aug(*args):
       outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
       outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
-      aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
+      aug_residuals = map(ad_util.zeros_like_aval, all_res_avals)
       aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
       return outs + list(aug_residuals)

     return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

-  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
+  return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

 # This function augments branch inputs to agree with the merged residual format:
 # each branch is made to accept all residuals, even though it will ignore those
|
||||
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
|
||||
res_aval_indices_per_jaxpr):
|
||||
newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
|
||||
all_res_vars = _map(newvar, all_res_avals)
|
||||
all_res_vars = map(newvar, all_res_avals)
|
||||
|
||||
def augment_jaxpr(jaxpr, res_indices):
|
||||
num_res = len(res_indices)
|
||||
@@ -509,15 +585,51 @@ def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
     jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
     return jaxpr_aug

-  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
+  return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

 def _ordered_unique(xs):
   d = collections.OrderedDict((x, None) for x in xs)
   return list(d.keys())

+def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
+                   ) -> Tuple[List[bool], core.JaxprEqn]:
+  if not config.after_neurips:
+    return [True] * len(eqn.invars), eqn
+  closed_branches = eqn.params['branches']
+  branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]
+
+  # First, compute which inputs are used in any branch (not including `pred`).
+  used_inputs: List[bool] = [False] * (len(eqn.invars) - 1)  # -1 for pred
+  for jaxpr in branches:
+    _, used_inputs_ = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
+    used_inputs = map(operator.or_, used_inputs, used_inputs_)
+
+  # Next, compute DCEd branches, instantiating according to used_inputs.
+  dce_branches_ = [pe.dce_jaxpr(jaxpr, used_outputs, instantiate=used_inputs)[0]
+                   for jaxpr in branches]
+  dce_branches = [core.ClosedJaxpr(jaxpr, closed_jaxpr.consts)
+                  for closed_jaxpr, jaxpr in zip(closed_branches, dce_branches_)]
+
+  # Finally, update parameters and form the new eqn.
+  dce_linear = [l for l, used in zip(eqn.params['linear'], used_inputs) if used]
+  new_params = dict(eqn.params, branches=tuple(dce_branches),
+                    linear=tuple(dce_linear))
+  new_effects = core.join_effects(*(b.effects for b in dce_branches))
+  new_eqn = pe.new_jaxpr_eqn(
+      [v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
+      [v for v, used in zip(eqn.outvars, used_outputs) if used],
+      eqn.primitive, new_params, new_effects, eqn.source_info)
+
+  assert all(len(new_eqn.invars ) == 1 + len(jaxpr.in_avals )
+             for jaxpr in new_params['branches'])
+  assert all(len(new_eqn.outvars) == len(jaxpr.out_avals)
+             for jaxpr in new_params['branches'])
+  return [True, *used_inputs], new_eqn
+
+
 def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
   res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
-  primal_avals = _map(raise_to_shaped, primal_avals)
+  primal_avals = map(raise_to_shaped, primal_avals)

   @lu.wrap_init
   def transposed(*args):
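
A caller-side sketch (not part of the diff) of the `_cond_dce_rule` added above, in the same shape as the `DCETest` cases later in this commit; it assumes `pe.dce_jaxpr` is importable from `jax.interpreters.partial_eval`, as it was at the time:

    import jax
    import jax.numpy as jnp
    from jax import lax
    from jax.interpreters import partial_eval as pe

    def f(x1, x2):
      return lax.cond(x1 > 0,
                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x1)),
                      x1, x2)

    jaxpr = jax.make_jaxpr(f)(1., 2.).jaxpr
    # Keep only the first output: one sin is pruned from each branch, and the
    # rule reports x2 as unused so the caller can drop it too.
    jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, False])
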
@@ -526,13 +638,13 @@ def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
     cts_in = ad.backward_pass(
         jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
     _, cts_in = split_list(cts_in, [num_res])
-    return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
+    return map(ad.instantiate_zeros_aval, primal_avals, cts_in)

   return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)

 def _cond_transpose(reduce_axes, cts, *args, branches, linear):
   index, *ops = args
-  in_avals = _map(raise_to_shaped, branches[0].in_avals)
+  in_avals = map(raise_to_shaped, branches[0].in_avals)
   num_res = len(ops) - sum(linear)

   branches_trans = tuple(
@@ -544,12 +656,12 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear):
       for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

   res = ops[:num_res]
-  cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
+  cts = map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
   linear_trans = (False,) * num_res + (True,) * len(cts)

   out = cond_p.bind(
       index, *res, *cts, branches=branches_trans, linear=linear_trans)
-  assert all(_map(core.typecheck, lin_in_avals, out))
+  assert all(map(core.typecheck, lin_in_avals, out))

   out_iter = iter(out)
   out = [next(out_iter) if l else None for l in linear]
@@ -589,11 +701,11 @@ def _cond_typecheck(*in_atoms, branches, linear):
       raise core.JaxprTypeError(
           f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
          f'branch {i+1} outputs {len(jaxpr.out_avals)}')
-    if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
+    if not all(map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching input types: '
          f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
-    if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
+    if not all(map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching output types: '
          f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')
@@ -607,7 +719,7 @@ def _cond_typecheck(*in_atoms, branches, linear):
   if index_aval.dtype != np.int32:
     raise core.JaxprTypeError(
         f'cond called with index of type {index_aval.dtype} instead of int32')
-  if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
+  if not all(map(core.typecompat, jaxpr0.in_avals, op_avals)):
     raise core.JaxprTypeError(
         f'cond branches take input types {jaxpr0_in_avals_str}, '
         f'called with operands of type {_avals_short(op_avals)}')
@@ -620,7 +732,7 @@ def _cond_typecheck(*in_atoms, branches, linear):

 def cond_bind(*args, branches, linear):
   if config.jax_enable_checks:
-    avals = _map(core.get_aval, args)
+    avals = map(core.get_aval, args)
     in_atoms = [core.Var(0, '', a) for a in avals]  # dummies
     _cond_typecheck(*in_atoms, branches=branches, linear=linear)
     for jaxpr in branches:
@@ -638,8 +750,8 @@ pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
 batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
 xla.register_initial_style_primitive(cond_p)
 core.custom_typechecks[cond_p] = _cond_typecheck
-pe.partial_eval_jaxpr_custom_rules[cond_p] = \
-    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
+pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
+pe.dce_rules[cond_p] = _cond_dce_rule

 def _cond_lowering(ctx, index, *args, branches, linear):
   del linear  # Unused.
@@ -650,7 +762,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
   tokens_in = ctx.tokens_in.subset(ordered_effects)
   output_token_types = [mlir.token_type() for _ in ordered_effects]
   output_types = [
-      *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)]
+      *output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
   flat_output_types = util.flatten(output_types)

   # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
@@ -666,13 +778,13 @@ def _cond_lowering(ctx, index, *args, branches, linear):
         name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
     out_vals, tokens_out = mlir.jaxpr_subcomp(
         sub_ctx, jaxpr.jaxpr, tokens_in,
-        _map(mlir.ir_constants, jaxpr.consts),
-        *_map(mlir.wrap_singleton_ir_values, args))
+        map(mlir.ir_constants, jaxpr.consts),
+        *map(mlir.wrap_singleton_ir_values, args))
     out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
     out_vals = [*out_tokens, *out_vals]
     mhlo.ReturnOp(util.flatten(out_vals))

-  tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types))
+  tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
   tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
   ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
   return outputs
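
To see the `mhlo.case` op this lowering produces, a hedged sketch using the AOT inspection API of this era (`jax.jit(...).lower(...).compiler_ir` and the 'mhlo' dialect name are assumptions about the surrounding toolchain, not part of the diff):

    import jax
    import jax.numpy as jnp
    from jax import lax

    f = lambda x: lax.cond(x > 0, jnp.sin, jnp.cos, x)
    # Prints MHLO containing a case op with one region per branch.
    print(jax.jit(f).lower(1.0).compiler_ir(dialect='mhlo'))
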
@@ -800,18 +800,17 @@ def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
       used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
   else:
     assert False, "Fixpoint not reached"
-  core.check_jaxpr(jaxpr.jaxpr)
+  if config.jax_enable_checks: core.check_jaxpr(jaxpr.jaxpr)

   new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
   new_params = dict(eqn.params, num_consts=sum(used_consts),
                     num_carry=sum(used_carry_in), linear=tuple(new_linear),
                     jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts))
-  new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs)
-                              if used],
-                             [v for v, used in zip(eqn.outvars, used_outputs)
-                              if used],
-                             eqn.primitive, new_params, eqn.effects,
-                             eqn.source_info)
+  # TODO(mattjj,sharadmv): don't assume effects are never DCE'd?
+  new_eqn = pe.new_jaxpr_eqn(
+      [v for v, used in zip(eqn.invars, used_inputs) if used],
+      [v for v, used in zip(eqn.outvars, used_outputs) if used],
+      eqn.primitive, new_params, eqn.effects, eqn.source_info)
   assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals )
   assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
   return used_inputs, new_eqn
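
The jaxpr check above now runs only under the `jax_enable_checks` config flag, so a debugging session has to opt in explicitly; a small sketch (flag name from the diff, the update call is the standard jax config mechanism):

    from jax.config import config
    config.update('jax_enable_checks', True)  # re-enables core.check_jaxpr here
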
@@ -822,7 +821,8 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
   num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
   num_ys = len(jaxpr.out_avals) - num_carry

-  # Fixpoint (currently trivial on 'inst_in')
+  # Fixpoint (trivial on 'inst_in', since we might as well make all inputs
+  # available as DCE can subsequently prune any unused ones)
   const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
   for _ in range(1 + len(carry_uk)):
     unks_in = const_uk + carry_uk + xs_uk
@@ -831,7 +831,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
         jaxpr.jaxpr, in_unknowns=unks_in, in_inst=[True] * len(unks_in),
         ensure_out_unknowns=carry_uk + [False] * num_ys,
         ensure_out_inst=True, saveable=saveable)
-    carry_uk_out , ys_uk = split_list(unks_out, [num_carry])
+    carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
     if carry_uk_out == carry_uk:
       break
   else:
@@ -841,13 +841,14 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
   jaxpr_known  = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
   jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)

-  # Ensure residuals are all moved to the back.
+  # Move all residual binders to the back of jaxpr_staged so they're extensive.
   # TODO(mattjj): make jaxpr_staged only take instantiated inputs
   res_avals = jaxpr_staged.in_avals[:num_res]
   jaxpr_staged = pe.move_binders_to_back(
       jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))

-  # Instantiate all inputs (b/c jaxpr_staged takes all inputs).
+  # Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
+  # passing in_inst argument to partial_eval_jaxpr_custom above).
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is core.Var and not inst]
   inst_in = [True] * len(inst_in)
@@ -907,6 +908,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
       core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
       eqn.source_info)

+  # Create the staged eqn.
   _, out_binders_staged = partition_list(inst_out, eqn.outvars)
   linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
                    [False] * len(extensive_res))
@@ -4546,7 +4546,7 @@ class RematTest(jtu.JaxTestCase):

     # Two sine calls in the backward pass because while we don't save sines
     # within the (rematted) body function, we can save the scan carry, which
-    # effectively saves one sine. Three cosines for the Jacoian coefficients.
+    # effectively saves one sine. Three cosines for the Jacobian coefficients.
     self.assertEqual(jaxpr_text.count(' sin '), 2)
     self.assertEqual(jaxpr_text.count(' cos '), 3)
     # Six calls to dot_general in the backward pass because we save the primal
@@ -4805,7 +4805,7 @@ class RematTest(jtu.JaxTestCase):
           ('', api.remat),
           ('_new', new_checkpoint),
       ])
-  def test_const_in_jvp(self, remat):
+  def test_const_in_jvp_scan(self, remat):
     @api.custom_jvp
     def f(x):
       return x * np.arange(3.)
@@ -4980,6 +4980,233 @@ class RematTest(jtu.JaxTestCase):
     self.assertEqual(jaxpr_text.count(' sin '), 2)  # +1 b/c dce fixed point
     self.assertEqual(jaxpr_text.count(' cos '), 2)

+  @parameterized.named_parameters(
+      {"testcase_name": f"{suffix}", "remat": remat}
+      for suffix, remat in [
+          ('', api.remat),
+          ('_new', new_checkpoint),
+      ])
+  def test_remat_of_cond(self, remat):
+    true_fn  = lambda c: (jnp.sin(c), jnp.sin(c))
+    false_fn = lambda c: (jnp.sin(c), jnp.sin(c))
+    f = lambda x: lax.cond(x > 0., true_fn, false_fn, x)
+    jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
+
+    jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
+    self.assertNotIn(' sin ', str(jaxpr))
+    self.assertIn(' cos ', str(jaxpr))
+
+    true_fn  = lambda c: jnp.sin(jnp.sin(c))
+    false_fn = lambda c: c
+    f = lambda x: lax.cond(x > 0., true_fn, false_fn, x)
+    jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
+
+    jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
+    self.assertIn(' sin ', str(jaxpr))
+    self.assertIn(' cos ', str(jaxpr))
+
+  @parameterized.named_parameters(
+      {"testcase_name": f"{suffix}", "remat": remat}
+      for suffix, remat in [
+          ('', api.remat),
+          ('_new', new_checkpoint),
+      ])
+  def test_const_in_jvp_cond(self, remat):
+    @api.custom_jvp
+    def f(x):
+      return x * np.arange(3.)
+    @f.defjvp
+    def f_jvp(primals, tangents):
+      (x,), (xdot,) = primals, tangents
+      return f(x), xdot * np.arange(3.)
+
+    @remat
+    def g(x):
+      y = jax.lax.cond(x.sum() > 0, f, lambda x: x, x)
+      return y.sum()
+
+    jax.grad(g)(jnp.arange(3.))  # doesn't crash
+
+  def test_remat_checkpoint_dots_inside_cond(self):
+    x = jnp.ones((5,))
+
+    def f(W):
+      @partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
+      def f(x):
+        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
+        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
+        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
+        return x
+
+      return lax.cond(x.sum() > 0, f, lambda x: x, x)
+
+    _, f_vjp = api.vjp(f, jnp.ones((5, 5)))
+    jaxpr_text = str(f_vjp.args[0].func.args[1])
+
+    self.assertEqual(jaxpr_text.count(' sin '), 2)
+    self.assertEqual(jaxpr_text.count(' cos '), 3)
+    # Five calls to dot_general in the backward pass because we have two for
+    # each forward-pass dot, except for the first which only has one (as we are
+    # differentiating with respect to only W and not x).
+    self.assertEqual(jaxpr_text.count(' dot_'), 5)
+
+    jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
+                    modes=['fwd', 'rev'])
+
+  def test_remat_checkpoint_dots_outside_cond(self):
+    # see also above test test_remat_checkpoint_dots_inside_cond
+    # The behavior between the two tests is essentially identical, whereas for
+    # scan different things are saved based on this difference in remat
+    # placement (because of the carry).
+    x = jnp.ones((5,))
+
+    @partial(new_checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
+    def f(W):
+      def f(x):
+        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
+        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
+        x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
+        return x
+
+      return lax.cond(x.sum() > 0, f, lambda x: x, x)
+
+    _, f_vjp = api.vjp(f, jnp.ones((5, 5)))
+    jaxpr = f_vjp.args[0].func.args[1]
+    jaxpr_text = str(jaxpr)
+
+    self.assertEqual(jaxpr_text.count(' sin '), 2)
+    self.assertEqual(jaxpr_text.count(' cos '), 3)
+    self.assertEqual(jaxpr_text.count(' dot_'), 5)
+
+    jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
+                    modes=['fwd', 'rev'])
+
+  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
+  def test_remat_of_cond_policy(self):
+    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
+    f = new_checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x),
+                       policy=save_cos)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 0)
+
+  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
+  def test_remat_of_cond_funky_custom_jvp(self):
+    def cond_apply(f, x):
+      return lax.cond(x.sum() > -jnp.inf, f, lambda x: x, x)
+
+    @api.custom_jvp
+    def sin(x):
+      return jnp.sin(x)
+    def sin_jvp(primals, tangents):
+      x, = primals
+      xdot, = tangents
+      y, c = jax.jit(lambda: (jnp.sin(x), jnp.cos(x)))()
+      ydot = c * xdot
+      return y, ydot
+    sin.defjvp(sin_jvp)
+
+    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
+    f = new_checkpoint(partial(cond_apply, sin), policy=save_cos)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 0)
+
+    save_sin = lambda prim, *_, **__: str(prim) == 'sin'
+    f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 1)
+
+    f = new_checkpoint(partial(cond_apply, sin),
+                       policy=jax.checkpoint_policies.everything_saveable)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 0)
+
+    f = new_checkpoint(partial(cond_apply, sin),
+                       policy=jax.checkpoint_policies.nothing_saveable)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 1)
+
+    f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
+                       policy=jax.checkpoint_policies.nothing_saveable)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 1)
+    self.assertEqual(jaxpr_text.count(' cos '), 2)
+
+  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
+  def test_remat_of_cond_funky_custom_jvp2(self):
+    # Like the above test but instead of using jit inside custom_jvp, use cond.
+
+    def cond_apply(f, x):
+      return lax.cond(True, f, lambda x: x, x)
+
+    @api.custom_jvp
+    def sin(x):
+      return jnp.sin(x)
+    def sin_jvp(primals, tangents):
+      x, = primals
+      xdot, = tangents
+      y, c = cond_apply(lambda xs: (jnp.sin(xs[0]), jnp.cos(xs[1])), (x, x))
+      ydot = c * xdot
+      return y, ydot
+    sin.defjvp(sin_jvp)
+
+    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
+    f = new_checkpoint(partial(cond_apply, sin), policy=save_cos)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 0)
+
+    save_sin = lambda prim, *_, **__: str(prim) == 'sin'
+    f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 1)
+
+    f = new_checkpoint(partial(cond_apply, sin),
+                       policy=jax.checkpoint_policies.everything_saveable)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 0)
+
+    f = new_checkpoint(partial(cond_apply, sin),
+                       policy=jax.checkpoint_policies.nothing_saveable)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 0)
+    self.assertEqual(jaxpr_text.count(' cos '), 1)
+
+    f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
+                       policy=jax.checkpoint_policies.nothing_saveable)
+    jtu.check_grads(f, (3.,), order=2, modes=['rev'])
+    jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
+    jaxpr_text = str(jaxpr)
+    self.assertEqual(jaxpr_text.count(' sin '), 1)
+    self.assertEqual(jaxpr_text.count(' cos '), 2)
+

 class JaxprTest(jtu.JaxTestCase):

@@ -5367,6 +5594,56 @@ class DCETest(jtu.JaxTestCase):
     result2 = core.eval_jaxpr(jaxpr_pruned, consts, *pruned_args)
     self.assertAllClose(result1, result2)

+  def test_dce_jaxpr_cond_trivial(self):
+    x = jnp.array(1., dtype='float32')
+
+    # start with 7 eqns, use both outputs so nothing can be pruned
+    def f(x1, x2):
+      return lax.cond(x1 > 0,
+                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
+                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
+                      x1, x2)
+    jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
+    self.assert_dce_result(jaxpr, [True, True], [True, True], 7)
+
+    # use neither output so everything can be pruned
+    self.assert_dce_result(jaxpr, [False, False], [False, False], 0)
+
+  def test_dce_jaxpr_cond_nontrivial(self):
+    x = jnp.array(1., dtype='float32')
+
+    # start with 7 eqns, don't use an output so an eqn can be trimmed on each
+    # side and x2 _can_ be pruned
+    def f(x1, x2):
+      return lax.cond(x1 > 0,
+                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
+                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x1)),
+                      x1, x2)
+    jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
+    self.assert_dce_result(jaxpr, [True, False], [True, False], 5)
+
+    # start with 7 eqns, don't use an output so an eqn can be trimmed on each
+    # side, but x2 _can't_ be pruned b/c of a swap
+    def f(x1, x2):
+      return lax.cond(x1 > 0,
+                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x2)),
+                      lambda x1, x2: (jnp.sin(x2), jnp.sin(x1)),
+                      x1, x2)
+    jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
+    self.assert_dce_result(jaxpr, [True, False], [True, True], 5)
+
+    # start with 7 eqns, only use x1 on one side and x2 on the other, so we
+    # can't prune any inputs or eqns
+    def f(x1, x2):
+      return lax.cond(x1 > 0,
+                      lambda x1, x2: (jnp.sin(x1), jnp.sin(x1)),
+                      lambda x1, x2: (jnp.sin(x2), jnp.sin(x2)),
+                      x1, x2)
+    jaxpr = jax.make_jaxpr(f)(x, x).jaxpr
+    self.assert_dce_result(jaxpr, [True, True], [True, True], 7)
+    # use only one output, so we can prune eqns but not inputs
+    self.assert_dce_result(jaxpr, [True, False], [True, True], 5)
+

 class CustomJVPTest(jtu.JaxTestCase):

@@ -59,6 +59,22 @@ def cond_via_switch(pred, true_fun, false_fun, op, *args):
   index = lax.convert_element_type(pred, np.int32)
   return lax.switch(index, [false_fun, true_fun], op)

+def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args):
+  if args:
+    true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0]
+    op = (false_op, true_op)
+    false_fun = lambda op: _false_fun(op[0])
+    true_fun = lambda op: _true_fun(op[1])
+  index = lax.convert_element_type(pred, np.int32)
+  fn = lambda index, op: lax.switch(index, [false_fun, true_fun], op)
+  return new_checkpoint(fn)(index, op)
+
+COND_IMPLS = [
+    (lax.cond, 'cond'),
+    (cond_via_switch, 'switch'),
+    (cond_with_new_checkpoint, 'new_checkpoint'),
+]
+

 # We wanted to try all scan tests with the scan partial evaluation rule that
 # happens under ad_checkpoint.checkpoint, so we make a scan wrapper which
|
||||
def scan_with_for(f, *args, **kwargs):
|
||||
return for_loop.scan(f, *args, **kwargs)
|
||||
|
||||
COND_IMPLS = [
|
||||
(lax.cond, 'cond'),
|
||||
(cond_via_switch, 'switch'),
|
||||
]
|
||||
|
||||
|
||||
SCAN_IMPLS = [
|
||||
(lax.scan, 'unroll1'),
|
||||
(partial(lax.scan, unroll=2), 'unroll2'),
|
||||
@ -786,7 +796,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertEqual(cf(1, x), branch(x))
|
||||
|
||||
def testIssue1379(self):
|
||||
|
||||
def fun(pred):
|
||||
return lax.cond(pred, lambda x: (True, x), lambda x: (False, x), pred)
|
||||
|
||||
@@ -1092,7 +1101,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       return 2. * x

     def fun(x):
-      return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
+      return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)

     x = 3.14
     ans = jax.jvp(fun, (x,), (x,))
@@ -1222,7 +1231,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       return 2. * x

     def fun(x):
-      return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
+      return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)

     x = 3.14
     ans = jax.grad(fun)(x)
@@ -1240,6 +1249,8 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       {"testcase_name": f"_{name}", "cond": cond}
       for cond, name in COND_IMPLS)
   def testCondGrad4(self, cond):
+    if cond is cond_with_new_checkpoint and 'tpu' in jtu.device_under_test():
+      raise unittest.SkipTest("tpu bug")  # TODO(parkers): tpu bug exhibited here
     def fun_ref(x, y):
       if x < 3:
         return 2. * jnp.sin(y)
@@ -1250,7 +1261,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     def fun(x, y):
       return cond(
           x < 3,
-          (), lambda _: 2. * jnp.sin(y),
+          None, lambda _: 2. * jnp.sin(y),
           x, lambda x: 2. * x)

     y = 5.8
@@ -1761,7 +1772,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
           f"got {tree_util.tree_structure(a)} and {tree_util.tree_structure((1, 2))}.")):
       lax.scan(lambda c, x: (0, x), (1, 2), a)

-
   @parameterized.named_parameters(
       {"testcase_name": f"_{scan_name}",
        "scan": scan_impl}