autodidax: add cond and start while_loop

Matthew Johnson 2021-04-14 17:51:16 -07:00
parent 6ce4ef46b9
commit 83cd42271b
3 changed files with 1909 additions and 186 deletions

File diff suppressed because it is too large


@ -95,6 +95,7 @@ sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
@ -105,6 +106,7 @@ def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
@ -138,7 +140,7 @@ more descriptive.
```{code-cell}
from contextlib import contextmanager
from typing import Type, List, Optional, Any
from typing import Type, List, Tuple, Sequence, Optional, Any
class MainTrace(NamedTuple):
level: int
@ -209,7 +211,6 @@ like arrays.)
```{code-cell}
import numpy as np
from typing import Tuple
class Tracer:
_trace: Trace
@ -229,6 +230,7 @@ class Tracer:
def __mul__(self, other): return self.aval._mul(self, other)
def __rmul__(self, other): return self.aval._rmul(self, other)
def __gt__(self, other): return self.aval._gt(self, other)
def __lt__(self, other): return self.aval._lt(self, other)
def __bool__(self): return self.aval._bool(self)
def __nonzero__(self): return self.aval._nonzero(self)
@ -261,6 +263,7 @@ class ShapedArray:
_mul = staticmethod(mul)
_rmul = staticmethod(swap(mul))
_gt = staticmethod(greater)
_lt = staticmethod(less)
@staticmethod
def _bool(tracer):
@ -421,6 +424,7 @@ impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
@ -452,7 +456,8 @@ First, a few helper functions:
```{code-cell}
def zeros_like(val):
return np.zeros_like(val)
aval = get_aval(val)
return np.zeros(aval.shape, aval.dtype)
def unzip2(pairs):
lst1, lst2 = [], []
@ -464,6 +469,14 @@ def unzip2(pairs):
map_ = map
def map(f, *xs):
return list(map_(f, *xs))
zip_ = zip
def zip(*args):
fst, *rest = args = map(list, args)
n = len(fst)
for arg in rest:
assert len(arg) == n
return list(zip_(*args))
```
The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The
@ -533,6 +546,12 @@ def greater_jvp(primals, tangents):
out_primal = greater(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp
def less_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = less(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp
```
Finally, we add a transformation API to kick off the trace:
@ -833,16 +852,17 @@ Next we can define batching interpreter rules for each primitive:
```{code-cell}
from functools import partial
def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
def binop_batching_rule(op, axis_size, vals_in, dims_in):
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
if x_bdim != y_bdim:
if x_bdim is not_mapped:
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
x_bdim = y_bdim
else:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
(x,), (x_bdim,) = vals_in, dims_in
@ -917,7 +937,7 @@ That's it for `jvp` and `vmap`!
## Part 2: Jaxprs
The next transformations are the horizon are `jit` for just-in-time
The next transformations on the horizon are `jit` for just-in-time
compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small
wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to
carry a little bit of extra context, for both `jit` and `vjp` we need much
@ -984,8 +1004,8 @@ class Lit:
aval: ShapedArray
def __init__(self, val):
self.val = val
self.aval = raise_to_shaped(get_aval(self.val))
self.aval = aval = raise_to_shaped(get_aval(val))
self.val = np.array(val, aval.dtype)
Atom = Union[Var, Lit]
@ -1088,6 +1108,19 @@ a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one
used by `jit`, which is also used by control flow primitives like `lax.cond`,
`lax.while_loop`, and `lax.scan`.
```{code-cell}
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
assert 0 <= n <= len(lst)
return lst[:n], lst[n:]
def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
assert len(bs) == len(l)
lists = lst1, lst2 = [], []
for b, x in zip(bs, l):
lists[b].append(x)
return lst1, lst2
```
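A quick check of these helpers on made-up lists: `split_list` cuts at an index, and `partition_list` routes each element by its boolean flag:

```{code-cell}
xs, ys = split_list([1, 2, 3, 4], 2)
assert (xs, ys) == ([1, 2], [3, 4])

is_even = [x % 2 == 0 for x in [1, 2, 3, 4]]
odds, evens = partition_list(is_even, [1, 2, 3, 4])
assert (odds, evens) == ([1, 3], [2, 4])
```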
```{code-cell}
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
@ -1181,9 +1214,25 @@ class JaxprBuilder:
out_vars = [t2v(t) for t in out_tracers]
jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
typecheck_jaxpr(jaxpr)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
```
```{code-cell}
def _inline_literals(jaxpr: Jaxpr, consts: List[Any]) -> Tuple[Jaxpr, List[Any]]:
const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
new_const_binders, lit_binders = partition_list(scalars, const_binders)
new_consts, lit_vals = partition_list(scalars, consts)
literals = dict(zip(lit_binders, map(Lit, lit_vals)))
new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
new_outs = [literals.get(x, x) for x in jaxpr.outs]
new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
typecheck_jaxpr(new_jaxpr)
return new_jaxpr, new_consts
```
The rules we need for `JaxprTrace.process_primitive` are essentially typing
rules for primitive applications: given the primitive, its parameters, and
types for the inputs, the rule must produce a type for the output, which is
@ -1196,36 +1245,38 @@ rules for the other jaxpr-producing trace machinery, where the potential extra
generality is useful.
```{code-cell}
def broadcast_shapes(*shapes):
assert len(shapes) > 1
for sizes in zip(*shapes):
sizes = [d for d in sizes if d != 1]
if sizes[:-1] != sizes[1:]:
raise Exception
return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
return [ShapedArray(x.shape, x.dtype)]
def broadcasting_binop_abstract_eval_rule(*avals_in):
out_dtype = np.result_type(*map(np.result_type, avals_in))
out_shape = broadcast_shapes(*map(np.shape, avals_in))
return [ShapedArray(out_shape, out_dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval
abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule
abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if x.shape != y.shape: raise TypeError
return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval
def vectorized_unop_abstract_eval_rule(aval_in):
return [ShapedArray(np.shape(aval_in), np.result_type(aval_in))]
def vectorized_unop_abstract_eval(x: ShapedArray) -> List[ShapedArray]:
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
def reduce_sum_abstract_eval_rule(aval_in, *, axis):
new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]
return [ShapedArray(tuple(new_shape), aval_in.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: int) -> List[ShapedArray]:
new_shape = [d for i, d in enumerate(x.shape) if i != axis]
return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
def broadcast_abstract_eval(x, *, shape, axes):
return [ShapedArray(tuple(shape), np.result_type(x))]
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
axes: Sequence[int]) -> List[ShapedArray]:
return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
```
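As a quick, made-up check of two of these rules: comparisons preserve shape but produce booleans, and `reduce_sum` drops the reduced axis:

```{code-cell}
a = ShapedArray((3, 4), np.dtype('float64'))
out, = compare_abstract_eval(a, a)
print(out.shape, out.dtype)          # (3, 4) bool
out, = reduce_sum_abstract_eval(a, axis=0)
print(out.shape, out.dtype)          # (4,) float64
```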
@ -1501,6 +1552,7 @@ impl_rules[xla_call_p] = xla_call_impl
@lru_cache()
def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]):
jaxpr: Jaxpr = hashable_jaxpr.val
typecheck_jaxpr(jaxpr)
consts = [x.val for x in hashable_consts]
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
c = xb.make_computation_builder('xla_call')
@ -1534,7 +1586,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
env: Dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp:
return env[x] if type(x) is Var else xb.constant(c, x.val)
return env[x] if type(x) is Var else xb.constant(c, x.val, False)
def write(v: Var, val: xe.XlaOp) -> None:
env[v] = val
@ -1555,7 +1607,7 @@ def execute_compiled(compiled, out_avals, *args):
default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
[int, float, np.ndarray, np.float64, np.float32]}
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now.
@ -1580,6 +1632,7 @@ xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
(x_aval,), (x,) = in_avals, in_vals
@ -1784,6 +1837,7 @@ class DeviceArray:
_mul = staticmethod(mul)
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[DeviceArray] = lambda x: x.buf
jax_types.add(DeviceArray)
@ -1829,6 +1883,20 @@ y, y_dot = jvp(f, (x,), (x_dot,))
where the application of `f_lin` does not redo any of the linearization work.
We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
Tangentially, now that we have linear arrows `-o`, we can provide a slightly
more informative type for `jvp`:
```
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
```
Here we're writing `UnrestrictedUse` just to indicate that we have a special
pair where the first element can be used in an unrestricted (nonlinear) way.
In conjunction with the linear arrow, this notation is just meant to express
that the function `jvp f` uses its first input in a nonlinear way but its
second input in a linear way, producing a corresponding nonlinear output
(which can be used in a nonlinear way) paired with a linear output. This more
refined type signature encodes the data dependencies in `jvp f`, which are
useful for partial evaluation.
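We can check the linearity claim numerically: scaling the input tangent scales
the output tangent by the same factor (a small sanity check using `jvp` and
`sin` from above):

```{code-cell}
_, ydot1 = jvp(sin, (3.,), (1.,))
_, ydot2 = jvp(sin, (3.,), (2.,))
print(ydot2, 2 * ydot1)  # equal: the tangent output is linear in the tangent input
```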
To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
we evaluate all the primal values as we trace, but stage the tangent
computations into a jaxpr. This is our second way to build jaxprs. But where
@ -1839,18 +1907,15 @@ primitive binds with a data dependence on tangent inputs.
First, some utilities:
```{code-cell}
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
return lst[:n], lst[n:]
def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
assert not len(lst) % 2
return split_list(lst, len(lst) // 2)
def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
lists = lst1, lst2 = [], []
for b, x in zip(bs, l):
lists[b].append(x)
return lst1, lst2
def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:
l1, l2 = iter(l1), iter(l2)
out = [next(l2) if b else next(l1) for b in which]
assert next(l1, None) is next(l2, None) is None
return out
```
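Again a quick check on made-up lists: `split_half` bisects an even-length
list, and `merge_lists` is the inverse of `partition_list`, interleaving two
lists back together according to a boolean mask:

```{code-cell}
front, back = split_half([1, 2, 3, 4])
assert (front, back) == ([1, 2], [3, 4])
assert merge_lists([False, True, True, False], [1, 4], [2, 3]) == [1, 2, 3, 4]
```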
Next, we'll write `linearize` by combining `jvp` together with a general
@ -1895,24 +1960,22 @@ inputs, together with (2) a jaxpr representing the part of the Python
callable's computation which can only be performed after the remaining inputs
are known.
This transformation can't be summarized purely in a type signature because its
behavior relies on the data dependencies inside the given Python callable and
not just its type. Nevertheless a heuristic type signature is useful. If we
This transformation is tricky to summarize in a type signature. If we
assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
`a1` and `a2` represent the known and unknown inputs, respectively, and where
`b1` only has a data dependency on `a1` while `b2` has some data dependency on
`a2`, then we might write
```
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> (b1, res, (res, a2) -> b2)
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
```
In words, given values for the inputs of type `a1`, `partial_eval` produces
the outputs of type `b1` along with "residual" values of type `res`
representing the intermediates required to complete the computation in the
second stage. It also produces a function of type `(res, a2) -> b2` which
accepts the residual values as well as the remaining inputs and produces the
remaining outputs.
the outputs of type `b1` along with "residual" values of
existentially-quantified type `r` representing the intermediates required to
complete the computation in the second stage. It also produces a function of
type `(r, a2) -> b2` which accepts the residual values as well as the
remaining inputs and produces the remaining outputs.
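As a toy model of this signature (plain Python, not the tracing machinery
below), partially evaluating `f(x1, x2) = (x1 * 2., sin(x1) * x2)` on a known
`x1` could produce the known output, a residual, and a function that finishes
the job:

```{code-cell}
def toy_partial_eval(x1):
  b1 = x1 * 2.        # known output, computed now
  res = np.sin(x1)    # residual intermediate, saved for the second stage
  def f2(res, x2):    # remaining computation, staged out
    return res * x2
  return b1, res, f2

b1, res, f2 = toy_partial_eval(3.)
print(b1, f2(res, 4.))
```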
We like to think of partial evaluation as "unzipping" one computation into
two. For example, consider this jaxpr:
@ -1924,7 +1987,7 @@ two. For example, consider this jaxpr:
```
A jaxpr for the JVP would look like:
```
{ lambda a:float64[] b:float64 .
{ lambda a:float64[] b:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
e:float64[] = mul d b
@ -2158,6 +2221,8 @@ def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
```
```{code-cell}
:tags: [hide-input]
def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
if not out_nodes: return []
out_nodes = remove_duplicates(out_nodes)
@ -2212,6 +2277,10 @@ To handle `linearize`-of-`jit`, we still need to write a partial evaluation
rule for `xla_call_p`. Other than tracer bookkeeping, the main task is to
perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.
There are actually two rules to write: one for trace-time partial evaluation,
which we'll call `xla_call_partial_eval`, and one for partial evaluation of
jaxprs, which we'll call `xla_call_peval_eqn`.
```{code-cell}
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
del num_consts # Unused.
@ -2228,18 +2297,17 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
dict(jaxpr=jaxpr2, num_consts=0),
[v.aval for v in jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
outs1, outs2 = iter(outs1), iter(outs2)
return [next(outs2) if uk else next(outs1) for uk in out_unknowns]
return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
instantiate: Optional[List[bool]] = None,
) -> Tuple[Jaxpr, Jaxpr, List[bool], int]:
env: Dict[Var, bool] = {}
residuals = set()
def read(v: Atom) -> bool:
if type(v) is Lit: raise NotImplementedError
return env[v]
return type(v) is Var and env[v]
def write(unk: bool, v: Var) -> None:
env[v] = unk
@ -2264,6 +2332,11 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]
eqns1.append(eqn)
map(partial(write, False), eqn.out_binders)
out_unknowns = map(read, jaxpr.outs)
if instantiate is not None:
for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
if inst and not uk: new_res(v)
out_unknowns = map(op.or_, out_unknowns, instantiate)
residuals, num_res = list(residuals), len(residuals)
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
@ -2295,7 +2368,7 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
jaxpr = eqn.params['jaxpr']
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
@ -2481,10 +2554,11 @@ transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...]
) -> Tuple[Jaxpr, List[Any]]:
traceable = partial(eval_jaxpr_transposed, jaxpr)
avals_in, avals_out = typecheck_jaxpr(jaxpr)
traceable = partial(eval_jaxpr_transposed, jaxpr)
args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
typecheck_jaxpr(trans_jaxpr)
return trans_jaxpr, consts
```
@ -2571,3 +2645,441 @@ _, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
```
## Part 5: the control flow primitives `cond` and `while_loop`
Next we'll add higher-order primitives for staged-out control flow. These
resemble `jit` from Part 3, another higher-order primitive, but differ in that
they are parameterized by multiple callables rather than just one.
+++
### Adding `cond`
We introduce a `cond` primitive to represent conditional application of one
function or another inside a jaxpr. We write the type of `cond` as
`Bool -> (a -> b) -> (a -> b) -> a -> b`. In words, `cond` takes a boolean
representing the predicate and two functions of equal types. Depending on the
value of the predicate, it applies one function or the other to its final
argument.
In Python, we represent it as a function which itself takes two functions as
arguments. As with `jit`, the first step is to call `make_jaxpr` on its
callable arguments to turn them into jaxprs:
```{code-cell}
def cond(pred, true_fn, false_fn, *operands):
avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
if out_tree != out_tree_: raise TypeError
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
raise TypeError
outs = bind_cond(pred, *true_consts, *false_consts, *operands,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')
def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> Tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
consts1, rest1 = split_list(jaxpr1.in_binders, n1)
consts2, rest2 = split_list(jaxpr2.in_binders, n2)
new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
return new_jaxpr1, new_jaxpr2
def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
```
We require `true_jaxpr` and `false_jaxpr` to have the same type, but because
they might close over different constants (and because jaxprs can only
represent closed terms, i.e. can't have free variables and are instead
closure-converted) we need to use the helper `_join_jaxpr_consts` to make
consistent the input binder lists of the two jaxprs. (To be more economical we
could try to identify pairs of constants with the same shapes, but instead we
just concatenate the lists of constants.)
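For example (a small sketch using the machinery above, with array-valued
constants so they aren't inlined as literals), joining gives the two jaxprs
identical binder lists and hence identical types:

```{code-cell}
a, b = np.ones(2), np.zeros(2)
x_aval = raise_to_shaped(get_aval(np.ones(2)))
jaxpr1, consts1, _ = make_jaxpr(lambda x: x * a, x_aval)
jaxpr2, consts2, _ = make_jaxpr(lambda x: x + b, x_aval)
jaxpr1, jaxpr2 = _join_jaxpr_consts(jaxpr1, jaxpr2, len(consts1), len(consts2))
assert typecheck_jaxpr(jaxpr1) == typecheck_jaxpr(jaxpr2)
```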
Next we can turn to adding interpreter rules for `cond`. Its evaluation rule
is simple:
```{code-cell}
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
if pred:
return eval_jaxpr(true_jaxpr, operands)
else:
return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
```
```{code-cell}
out = cond(True, lambda: 3, lambda: 4)
print(out)
```
For its JVP and vmap rules, we only need to call the same `jvp_jaxpr` and
`vmap_jaxpr` utilities we created for `jit`, followed by another pass of
`_join_jaxpr_consts`:
```{code-cell}
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
pred, *primals = primals
_ , *tangents = tangents
true_jaxpr , true_consts = jvp_jaxpr(true_jaxpr)
false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
primals_out, tangents_out = split_half(outs)
return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
```
```{code-cell}
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
```
```{code-cell}
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
pred , *vals_in = vals_in
pred_dim, *dims_in = dims_in
if pred_dim is not not_mapped: raise NotImplementedError # TODO
true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
```
```{code-cell}
xs = np.array([1., 2., 3.])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
```
Notice that we're not currently supporting the case where the predicate value
itself is batched. In mainline JAX, we handle this case by transforming the
conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).
That transformation is semantically correct so long as `true_fun` and
`false_fun` do not involve any side-effecting primitives.
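To illustrate the select idea (a hypothetical sketch in plain NumPy, with
`np.where` standing in for the select primitive): both branches are evaluated,
and the batched predicate merges the results elementwise:

```{code-cell}
preds = np.array([True, False, True])
xs = np.array([1., 2., 3.])
true_out, false_out = xs + 1., np.zeros_like(xs)  # run both branches
print(np.where(preds, true_out, false_out))       # merge with the predicate
```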
Another thing not represented here, but present in mainline JAX, is that
applying transformations to two jaxprs of equal type might result in jaxprs of
different types. For example, applying the mainline JAX version of
`vmap_jaxpr` to the identity-function jaxpr
```
{ lambda a:float32[] .
let
in ( a ) }
```
would result in a jaxpr with a batched output, of type
`[float32[10]] -> [float32[10]]` if the batch size were 10, while applying it
to the zero-function jaxpr
```
{ lambda a:float32[] .
let
in ( 0. ) }
```
would result in a jaxpr with an unbatched output, of type
`[float32[10]] -> [float32[]]`. This is an optimization, aimed at not batching
values unnecessarily. But it means that in `cond` we'd need an extra step of
joining the two transformed jaxprs to have consistent output types. We don't
need this step here because we chose `vmap_jaxpr` always to batch all outputs
over the leading axis.
+++
Next we can turn to abstract evaluation and XLA lowering rules:
```{code-cell}
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
jaxpr_type = typecheck_jaxpr(true_jaxpr)
if jaxpr_type != typecheck_jaxpr(false_jaxpr):
raise TypeError
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
del in_avals # Unused.
pred, *in_vals = in_vals
flat_vals, in_tree = tree_flatten(in_vals)
operand = xops.Tuple(c, flat_vals)
operand_shape = c.get_shape(operand)
def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
c = xb.make_computation_builder(name)
operand = xb.parameter(c, 0, operand_shape)
operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
outs = jaxpr_subcomp(c, jaxpr, operands)
return c.build(xops.Tuple(c, outs))
true_comp = make_comp('true_fn', true_jaxpr)
false_comp = make_comp('false_fn', false_jaxpr)
int_etype = xc.dtype_to_etype(np.dtype('int32'))
out = xops.Conditional(xops.ConvertElementType(pred, int_etype),
[false_comp, true_comp], [operand] * 2)
return destructure_tuple(c, out)
xla_translations[cond_p] = cond_translation
```
```{code-cell}
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
```
Finally, to support reverse-mode automatic differentiation, we need partial
evaluation and transposition rules. For partial evaluation, we need to
introduce another jaxpr-munging utility, `_join_jaxpr_res`, to handle the fact
that applying partial evaluation to `true_fun` and `false_fun` will in general
result in distinct residuals. We use `_join_jaxpr_res` to make the output
types of the transformed jaxprs consistent (while `_join_jaxpr_consts` dealt
with input types).
```{code-cell}
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
pred_tracer, *tracers = tracers
assert pred_tracer.pval.is_known
pred = pred_tracer.pval.const
in_uks = [not t.pval.is_known for t in tracers]
*jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
known_tracers, unknown_tracers = partition_list(in_uks, tracers)
known_vals = [t.pval.const for t in known_tracers]
outs1_res = bind_cond(pred, *known_vals,
true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
for v in t_jaxpr2.outs]
eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
[v.aval for v in t_jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval
def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: List[bool]
) -> Tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, List[bool], int]:
_, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
_, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
out_uks = map(op.or_, t_out_uks, f_out_uks)
t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)
t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
num_res = t_nres + f_nres
return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res
def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> Tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
assert out_types1 == out_types2
outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
return new_jaxpr1, new_jaxpr2
```
```{code-cell}
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
```
```{code-cell}
def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
pred_unk, *unks_in = unks_in
assert not pred_unk
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
*jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
outs1 + residuals)
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
outs2)
return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
```
```{code-cell}
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
```
Transposition is a fairly straightforward application of `transpose_jaxpr`:
```{code-cell}
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
undef_primals = tuple([type(x) is UndefPrimal for x in invals])
true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
res = [x for x in invals if type(x) is not UndefPrimal]
outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
outs = iter(outs)
return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
```
```{code-cell}
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
```
### Adding `while_loop`
Next we'll add a primitive for looping behavior in a jaxpr. We'll use
`while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
function-valued argument represents the loop condition, the second represents
the loop body, and the final argument is the initial value of the carry.
After `cond`, adding `while_loop` is not so different:
```{code-cell}
def while_loop(cond_fn, body_fn, init_val):
init_val, in_tree = tree_flatten(init_val)
avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
if cond_tree != tree_flatten(True)[1]: raise TypeError
if in_tree != in_tree_: raise TypeError
outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
return tree_unflatten(in_tree, outs)
while_loop_p = Primitive('while_loop')
```
```{code-cell}
def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
return carry
impl_rules[while_loop_p] = while_loop_impl
```
```{code-cell}
def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)
```
```{code-cell}
out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
print(out)
```
Notice the convention that `args = [*consts, *carry]`.
The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
convention that all the binders for tangent values are appended after all the
binders for primal values, like `args = [*primals, *tangents]`. But that's in
tension with our `while_loop` convention that the carry binders come after all
the constant binders, i.e. that `args = [*consts, *carry]`, because both the
constants and the carries can have their own tangents. For this reason, we
introduce the `_loop_jvp_binders` helper to rearrange binders as needed.
```{code-cell}
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
num_consts = _loop_num_consts(body_jaxpr)
body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
cond_jaxpr, body_jaxpr = _loop_jvp_binders(
cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
outs = bind(while_loop_p, *body_consts, *primals, *tangents,
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
primals_out, tangents_out = split_half(outs)
return primals_out, tangents_out
jvp_rules[while_loop_p] = while_loop_jvp_rule
def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
) -> Tuple[Jaxpr, Jaxpr]:
# body binders [c1, c2, x1, c2dot, x1dot] ~~> [c1, c2, c2dot, x1, x1dot]
jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
primal_binders, tangent_binders = split_half(binders)
consts , carry = split_list(primal_binders , n2)
consts_dot, carry_dot = split_list(tangent_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
typecheck_jaxpr(new_body_jaxpr)
# cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]
assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
consts, carry = split_list(cond_jaxpr.in_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)
return new_cond_jaxpr, new_body_jaxpr
```
```{code-cell}
out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
(1.,), (1.,))
print(out_tan)
```
```{code-cell}
def f(x):
def cond_fn(i, _):
return i < 3
def body_fn(i, x):
return i + 1, cos(x)
_, y = while_loop(cond_fn, body_fn, (0, x))
return y
def g(x):
return cos(cos(cos(x)))
print(jvp(f, (1.,), (1.,)))
print(jvp(g, (1.,), (1.,)))
```
The vmap rule for `while_loop` presents two cases:
1. if the output of `cond_fun` is not batched, then the loop has the same
basic structure, just with a batched body;
2. but if the output of `cond_fun` is batched, we must represent a batch of
loops which might run for different numbers of iterations, as sketched below.
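Case 2 is the subtle one. Here's a plain-NumPy sketch of its semantics (a
hypothetical helper, not this commit's implementation): every batch element
keeps stepping until its own predicate is false, with a mask freezing the
finished elements:

```{code-cell}
def batched_while_sketch(cond_fn, body_fn, init):
  carry = init
  active = cond_fn(carry)
  while active.any():
    carry = np.where(active, body_fn(carry), carry)  # only active elements step
    active = active & cond_fn(carry)
  return carry

print(batched_while_sketch(lambda x: x < 10., lambda x: x * 2., np.array([1., 3., 50.])))
```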
...Stay tuned for the thrilling conclusion!


@ -85,6 +85,7 @@ sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
@ -95,6 +96,7 @@ def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
@ -128,7 +130,7 @@ def bind1(prim, *args, **params):
# +
from contextlib import contextmanager
from typing import Type, List, Optional, Any
from typing import Type, List, Tuple, Sequence, Optional, Any
class MainTrace(NamedTuple):
level: int
@ -197,7 +199,6 @@ class Trace:
# +
import numpy as np
from typing import Tuple
class Tracer:
_trace: Trace
@ -217,6 +218,7 @@ class Tracer:
def __mul__(self, other): return self.aval._mul(self, other)
def __rmul__(self, other): return self.aval._rmul(self, other)
def __gt__(self, other): return self.aval._gt(self, other)
def __lt__(self, other): return self.aval._lt(self, other)
def __bool__(self): return self.aval._bool(self)
def __nonzero__(self): return self.aval._nonzero(self)
@ -248,6 +250,7 @@ class ShapedArray:
_mul = staticmethod(mul)
_rmul = staticmethod(swap(mul))
_gt = staticmethod(greater)
_lt = staticmethod(less)
@staticmethod
def _bool(tracer):
@ -404,6 +407,7 @@ impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
@ -433,7 +437,8 @@ print(f(3.0))
# +
def zeros_like(val):
return np.zeros_like(val)
aval = get_aval(val)
return np.zeros(aval.shape, aval.dtype)
def unzip2(pairs):
lst1, lst2 = [], []
@ -445,6 +450,14 @@ def unzip2(pairs):
map_ = map
def map(f, *xs):
return list(map_(f, *xs))
zip_ = zip
def zip(*args):
fst, *rest = args = map(list, args)
n = len(fst)
for arg in rest:
assert len(arg) == n
return list(zip_(*args))
# -
# The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The
@ -514,6 +527,12 @@ def greater_jvp(primals, tangents):
out_primal = greater(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp
def less_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = less(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp
# -
# Finally, we add a transformation API to kick off the trace:
@ -797,16 +816,17 @@ vmap_rules = {}
# +
from functools import partial
def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
def binop_batching_rule(op, axis_size, vals_in, dims_in):
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
if x_bdim != y_bdim:
if x_bdim is not_mapped:
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
x_bdim = y_bdim
else:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
(x,), (x_bdim,) = vals_in, dims_in
@ -878,7 +898,7 @@ jacfwd(f, np.arange(3.))
# ## Part 2: Jaxprs
#
# The next transformations are the horizon are `jit` for just-in-time
# The next transformations on the horizon are `jit` for just-in-time
# compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small
# wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to
# carry a little bit of extra context, for both `jit` and `vjp` we need much
@ -943,8 +963,8 @@ class Lit:
aval: ShapedArray
def __init__(self, val):
self.val = val
self.aval = raise_to_shaped(get_aval(self.val))
self.aval = aval = raise_to_shaped(get_aval(val))
self.val = np.array(val, aval.dtype)
Atom = Union[Var, Lit]
@ -1046,6 +1066,18 @@ def jaxpr_as_fun(jaxpr: Jaxpr):
# used by `jit`, which is also used by control flow primitives like `lax.cond`,
# `lax.while_loop`, and `lax.scan`.
# +
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
assert 0 <= n <= len(lst)
return lst[:n], lst[n:]
def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
assert len(bs) == len(l)
lists = lst1, lst2 = [], []
for b, x in zip(bs, l):
lists[b].append(x)
return lst1, lst2
# +
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
@ -1138,8 +1170,22 @@ class JaxprBuilder:
out_vars = [t2v(t) for t in out_tracers]
jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
typecheck_jaxpr(jaxpr)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: List[Any]) -> Tuple[Jaxpr, List[Any]]:
const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
new_const_binders, lit_binders = partition_list(scalars, const_binders)
new_consts, lit_vals = partition_list(scalars, consts)
literals = dict(zip(lit_binders, map(Lit, lit_vals)))
new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
new_outs = [literals.get(x, x) for x in jaxpr.outs]
new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
typecheck_jaxpr(new_jaxpr)
return new_jaxpr, new_consts
# The rules we need for `JaxprTrace.process_primitive` are essentially typing
# rules for primitive applications: given the primitive, its parameters, and
# types for the inputs, the rule must produce a type for the output, which is
@ -1152,36 +1198,38 @@ class JaxprBuilder:
# generality is useful.
# +
def broadcast_shapes(*shapes):
assert len(shapes) > 1
for sizes in zip(*shapes):
sizes = [d for d in sizes if d != 1]
if sizes[:-1] != sizes[1:]:
raise Exception
return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
return [ShapedArray(x.shape, x.dtype)]
def broadcasting_binop_abstract_eval_rule(*avals_in):
out_dtype = np.result_type(*map(np.result_type, avals_in))
out_shape = broadcast_shapes(*map(np.shape, avals_in))
return [ShapedArray(out_shape, out_dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval
abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule
abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if x.shape != y.shape: raise TypeError
return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval
def vectorized_unop_abstract_eval_rule(aval_in):
return [ShapedArray(np.shape(aval_in), np.result_type(aval_in))]
def vectorized_unop_abstract_eval(x: ShapedArray) -> List[ShapedArray]:
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
def reduce_sum_abstract_eval_rule(aval_in, *, axis):
new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]
return [ShapedArray(tuple(new_shape), aval_in.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: int) -> List[ShapedArray]:
new_shape = [d for i, d in enumerate(x.shape) if i != axis]
return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
def broadcast_abstract_eval(x, *, shape, axes):
return [ShapedArray(tuple(shape), np.result_type(x))]
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
axes: Sequence[int]) -> List[ShapedArray]:
return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
# -
@ -1441,6 +1489,7 @@ impl_rules[xla_call_p] = xla_call_impl
@lru_cache()
def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]):
jaxpr: Jaxpr = hashable_jaxpr.val
typecheck_jaxpr(jaxpr)
consts = [x.val for x in hashable_consts]
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
c = xb.make_computation_builder('xla_call')
@ -1474,7 +1523,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
env: Dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp:
return env[x] if type(x) is Var else xb.constant(c, x.val)
return env[x] if type(x) is Var else xb.constant(c, x.val, False)
def write(v: Var, val: xe.XlaOp) -> None:
env[v] = val
@ -1495,7 +1544,7 @@ def execute_compiled(compiled, out_avals, *args):
default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
[int, float, np.ndarray, np.float64, np.float32]}
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now.
@ -1520,6 +1569,7 @@ xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
(x_aval,), (x,) = in_avals, in_vals
@ -1708,6 +1758,7 @@ class DeviceArray:
_mul = staticmethod(mul)
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[DeviceArray] = lambda x: x.buf
jax_types.add(DeviceArray)
@ -1750,6 +1801,20 @@ print(ydot)
# where the application of `f_lin` does not redo any of the linearization work.
# We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
#
# Tangentially, now that we have linear arrows `-o`, we can provide a slightly
# more informative type for `jvp`:
# ```
# jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
# ```
# Here we're writing `UnrestrictedUse` just to indicate that we have a special
# pair where the first element can be used in an unrestricted (nonlinear) way.
# In conjunction with the linear arrow, this notation is just meant to express
# that the function `jvp f` uses its first input in a nonlinear way but its
# second input in a linear way, producing a corresponding nonlinear output
# (which can be used in a nonlinear way) paired with a linear output. This more
# refined type signature encodes the data dependencies in `jvp f`, which are
# useful for partial evaluation.
#
# To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
# we evaluate all the primal values as we trace, but stage the tangent
# computations into a jaxpr. This is our second way to build jaxprs. But where
@ -1760,18 +1825,15 @@ print(ydot)
# First, some utilities:
# +
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
return lst[:n], lst[n:]
def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
assert not len(lst) % 2
return split_list(lst, len(lst) // 2)
def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
lists = lst1, lst2 = [], []
for b, x in zip(bs, l):
lists[b].append(x)
return lst1, lst2
def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:
l1, l2 = iter(l1), iter(l2)
out = [next(l2) if b else next(l1) for b in which]
assert next(l1, None) is next(l2, None) is None
return out
# -
# Next, we'll write `linearize` by combining `jvp` together with a general
@ -1816,24 +1878,22 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# callable's computation which can only be performed after the remaining inputs
# are known.
#
# This transformation can't be summarized purely in a type signature because its
# behavior relies on the data dependencies inside the given Python callable and
# not just its type. Nevertheless a heuristic type signature is useful. If we
# This transformation is tricky to summarize in a type signature. If we
# assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
# `a1` and `a2` represent the known and unknown inputs, respectively, and where
# `b1` only has a data dependency on `a1` while `b2` has some data dependency on
# `a2`, then we might write
#
# ```
# partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> (b1, res, (res, a2) -> b2)
# partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
# ```
#
# In words, given values for the inputs of type `a1`, `partial_eval` produces
# the outputs of type `b1` along with "residual" values of type `res`
# representing the intermediates required to complete the computation in the
# second stage. It also produces a function of type `(res, a2) -> b2` which
# accepts the residual values as well as the remaining inputs and produces the
# remaining outputs.
# the outputs of type `b1` along with "residual" values of
# existentially-quantified type `r` representing the intermediates required to
# complete the computation in the second stage. It also produces a function of
# type `(r, a2) -> b2` which accepts the residual values as well as the
# remaining inputs and produces the remaining outputs.
#
# We like to think of partial evaluation as "unzipping" one computation into
# two. For example, consider this jaxpr:
@ -1845,7 +1905,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# ```
# A jaxpr for the JVP would look like:
# ```
# { lambda a:float64[] b:float64 .
# { lambda a:float64[] b:float64[] .
# let c:float64[] = sin a
# d:float64[] = cos a
# e:float64[] = mul d b
@ -2073,7 +2133,7 @@ def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe
def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
# +
# + tags=["hide-input"]
def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
if not out_nodes: return []
out_nodes = remove_duplicates(out_nodes)
@ -2125,6 +2185,10 @@ print(sin_lin(1.), cos(3.))
# To handle `linearize`-of-`jit`, we still need to write a partial evaluation
# rule for `xla_call_p`. Other than tracer bookkeeping, the main task is to
# perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.
#
# There are actually two rules to write: one for trace-time partial evaluation,
# which we'll call `xla_call_partial_eval`, and one for partial evaluation of
# jaxprs, which we'll call `xla_call_peval_eqn`.
# +
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
@ -2142,18 +2206,17 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
dict(jaxpr=jaxpr2, num_consts=0),
[v.aval for v in jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
outs1, outs2 = iter(outs1), iter(outs2)
return [next(outs2) if uk else next(outs1) for uk in out_unknowns]
return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
instantiate: Optional[List[bool]] = None,
) -> Tuple[Jaxpr, Jaxpr, List[bool], int]:
env: Dict[Var, bool] = {}
residuals = set()
def read(v: Atom) -> bool:
if type(v) is Lit: raise NotImplementedError
return env[v]
return type(v) is Var and env[v]
def write(unk: bool, v: Var) -> None:
env[v] = unk
@ -2178,6 +2241,11 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]
eqns1.append(eqn)
map(partial(write, False), eqn.out_binders)
out_unknowns = map(read, jaxpr.outs)
if instantiate is not None:
for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
if inst and not uk: new_res(v)
out_unknowns = map(op.or_, out_unknowns, instantiate)
residuals, num_res = list(residuals), len(residuals)
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
@ -2209,7 +2277,7 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
jaxpr = eqn.params['jaxpr']
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
@ -2393,10 +2461,11 @@ transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...]
) -> Tuple[Jaxpr, List[Any]]:
traceable = partial(eval_jaxpr_transposed, jaxpr)
avals_in, avals_out = typecheck_jaxpr(jaxpr)
traceable = partial(eval_jaxpr_transposed, jaxpr)
args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
typecheck_jaxpr(trans_jaxpr)
return trans_jaxpr, consts
# -
@ -2477,3 +2546,411 @@ _, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
# -
# ## Part 5: the control flow primitives `cond` and `while_loop`
#
# Next we'll add higher-order primitives for staged-out control flow. These
# resemble `jit` from Part 3, another higher-order primitive, but differ in that
# they are parameterized by multiple callables rather than just one.
# ### Adding `cond`
#
# We introduce a `cond` primitive to represent conditional application of one
# function or another inside a jaxpr. We write the type of `cond` as
# `Bool -> (a -> b) -> (a -> b) -> a -> b`. In words, `cond` takes a boolean
# representing the predicate and two functions of equal types. Depending on the
# value of the predicate, it applies one function or the other to its final
# argument.
#
# In Python, we represent it as a function which itself takes two functions as
# arguments. As with `jit`, the first step is to call `make_jaxpr` on its
# callable arguments to turn them into jaxprs:
# +
def cond(pred, true_fn, false_fn, *operands):
avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
if out_tree != out_tree_: raise TypeError
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
raise TypeError
outs = bind_cond(pred, *true_consts, *false_consts, *operands,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')
def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> Tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
consts1, rest1 = split_list(jaxpr1.in_binders, n1)
consts2, rest2 = split_list(jaxpr2.in_binders, n2)
new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
return new_jaxpr1, new_jaxpr2
def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
# -
# We require `true_jaxpr` and `false_jaxpr` to have the same type, but because
# they might close over different constants (and because jaxprs can only
# represent closed terms, i.e. can't have free variables and are instead
# closure-converted) we need to use the helper `_join_jaxpr_consts` to make
# consistent the input binder lists of the two jaxprs. (To be more economical we
# could try to identify pairs of constants with the same shapes, but instead we
# just concatenate the lists of constants.)
#
# Next we can turn to adding interpreter rules for `cond`. Its evaluation rule
# is simple:
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
if pred:
return eval_jaxpr(true_jaxpr, operands)
else:
return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
out = cond(True, lambda: 3, lambda: 4)
print(out)
# For its JVP and vmap rules, we only need to call the same `jvp_jaxpr` and
# `vmap_jaxpr` utilities we created for `jit`, followed by another pass of
# `_join_jaxpr_consts`:
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
pred, *primals = primals
_ , *tangents = tangents
true_jaxpr , true_consts = jvp_jaxpr(true_jaxpr)
false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
primals_out, tangents_out = split_half(outs)
return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
pred , *vals_in = vals_in
pred_dim, *dims_in = dims_in
if pred_dim is not not_mapped: raise NotImplementedError # TODO
true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
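  # vmap_jaxpr batches every output over the leading axis, so all outputs of
  # the batched cond are mapped at dimension 0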
return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3.])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
# Notice that we're not currently supporting the case where the predicate value
# itself is batched. In mainline JAX, we handle this case by transforming the
# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).
# That transformation is semantically correct so long as `true_fn` and
# `false_fn` do not involve any side-effecting primitives.
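#
# For intuition, here's that select-style semantics as a plain-NumPy
# reference (a sketch, not the staged transformation: it evaluates both
# branches and then picks per-lane results):
def batched_cond_reference(preds, true_fn, false_fn, xs):
  trues = vmap(true_fn, (0,))(xs)
  falses = vmap(false_fn, (0,))(xs)
  return np.where(preds, trues, falses)

print(batched_cond_reference(np.array([True, False, True]),
                             lambda x: x + 1., lambda x: x * 10.,
                             np.array([1., 2., 3.])))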
#
# Another thing not represented here, but present in mainline JAX, is that
# applying transformations to two jaxprs of equal type might result in jaxprs of
# different types. For example, applying the mainline JAX version of
# `vmap_jaxpr` to the identity-function jaxpr
#
# ```
# { lambda a:float32[] .
# let
# in ( a ) }
# ```
#
# would result in a jaxpr with a batched output, of type
# `[float32[10]] -> [float32[10]]` if the batch size were 10, while applying it
# to the zero-function jaxpr
#
# ```
# { lambda a:float32[] .
# let
# in ( 0. ) }
# ```
#
# would result in a jaxpr with an unbatched output, of type
# `[float32[10]] -> [float32[]]`. This is an optimization, aimed at not batching
# values unnecessarily. But it means that in `cond` we'd need an extra step of
# joining the two transformed jaxprs to have consistent output types. We don't
# need this step here because we chose `vmap_jaxpr` always to batch all outputs
# over the leading axis.
#
# Next we can turn to abstract evaluation and XLA lowering rules:
# +
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
jaxpr_type = typecheck_jaxpr(true_jaxpr)
if jaxpr_type != typecheck_jaxpr(false_jaxpr):
raise TypeError
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
del in_avals # Unused.
pred, *in_vals = in_vals
flat_vals, in_tree = tree_flatten(in_vals)
operand = xops.Tuple(c, flat_vals)
operand_shape = c.get_shape(operand)
def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
c = xb.make_computation_builder(name)
operand = xb.parameter(c, 0, operand_shape)
operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
outs = jaxpr_subcomp(c, jaxpr, operands)
return c.build(xops.Tuple(c, outs))
true_comp = make_comp('true_fn', true_jaxpr)
false_comp = make_comp('false_fn', false_jaxpr)
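  # xops.Conditional dispatches on an integer branch index, where 0 selects
  # the first computation in the list; converting the boolean predicate gives
  # False ~> 0 and True ~> 1, hence the [false_comp, true_comp] ordering.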
int_etype = xc.dtype_to_etype(np.dtype('int32'))
out = xops.Conditional(xops.ConvertElementType(pred, int_etype),
[false_comp, true_comp], [operand] * 2)
return destructure_tuple(c, out)
xla_translations[cond_p] = cond_translation
# -
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
# Finally, to support reverse-mode automatic differentiation, we need partial
# evaluation and transposition rules. For partial evaluation, we need to
# introduce another jaxpr-munging utility, `_join_jaxpr_res`, to handle the fact
# that applying partial evaluation to `true_fun` and `false_fun` will in general
# result in distinct residuals. We use `_join_jaxpr_res` to make the output
# types of the transformed jaxprs consistent (while `_join_jaxpr_consts` dealt
# with input types).
# +
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
pred_tracer, *tracers = tracers
assert pred_tracer.pval.is_known
pred = pred_tracer.pval.const
in_uks = [not t.pval.is_known for t in tracers]
*jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
known_tracers, unknown_tracers = partition_list(in_uks, tracers)
known_vals = [t.pval.const for t in known_tracers]
outs1_res = bind_cond(pred, *known_vals,
true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
for v in t_jaxpr2.outs]
eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
[v.aval for v in t_jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval
def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: List[bool]
) -> Tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, List[bool], int]:
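  # The jaxpr1s compute the known outputs (plus residuals) from the known
  # inputs, while the jaxpr2s compute the unknown outputs from residuals and
  # unknown inputs. We take the union of the two branches' unknown-output
  # sets so that both branches are partially evaluated with the same out_uks.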
_, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
_, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
out_uks = map(op.or_, t_out_uks, f_out_uks)
t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)
t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
num_res = t_nres + f_nres
return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res
def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> Tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
assert out_types1 == out_types2
outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
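  # pad each branch's residual outputs with zeros typed like the *other*
  # branch's residuals, so both jaxprs produce outputs laid out as
  # [outs, res1, res2] and end up with equal output types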
zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
return new_jaxpr1, new_jaxpr2
# -
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
pred_unk, *unks_in = unks_in
assert not pred_unk
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
*jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
outs1 + residuals)
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
outs2)
return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
# Transposition is a fairly straightforward application of `transpose_jaxpr`:
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
undef_primals = tuple([type(x) is UndefPrimal for x in invals])
true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
res = [x for x in invals if type(x) is not UndefPrimal]
outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
outs = iter(outs)
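  # the boolean predicate gets no cotangent (hence the leading None); only
  # the UndefPrimal inputs receive cotangents, and residuals also get None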
return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
# ### Adding `while_loop`
#
# Next we'll add a primitive for looping behavior in a jaxpr. We'll use
# `while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
# function-valued argument represents the loop condition, the second represents
# the loop body, and the final argument is the initial value of the carry.
#
# After `cond`, adding `while_loop` is not so different:
def while_loop(cond_fn, body_fn, init_val):
init_val, in_tree = tree_flatten(init_val)
avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
if cond_tree != tree_flatten(True)[1]: raise TypeError
if in_tree != in_tree_: raise TypeError
outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
return tree_unflatten(in_tree, outs)
while_loop_p = Primitive('while_loop')
def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
return carry
impl_rules[while_loop_p] = while_loop_impl
def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
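  # the carry has one input binder per output, so every binder before the
  # carry is a constant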
return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)
out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
print(out)
# Notice the convention that `args = [*consts, *carry]`.
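#
# For instance, a body that closes over an array gets that array passed in
# the `consts` prefix. (A small check; `reduce_sum` here just produces a
# scalar predicate, and the closed-over `-np.ones(3)` forces a non-literal
# constant.)
out = while_loop(lambda x: reduce_sum(x, axis=0) > 0.,
                 lambda x: x + -np.ones(3),
                 np.arange(3.))
print(out)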
#
# The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
# convention that all the binders for tangent values are appended after all the
# binders for primal values, like `args = [*primals, *tangents]`. But that's in
# tension with our `while_loop` convention that the carry binders come after all
# the constant binders, i.e. that `args = [*consts, *carry]`, because both the
# constants and the carries can have their own tangents. For this reason, we
# introduce the `_loop_jvp_binders` helper to rearrange binders as needed.
# +
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
  num_consts = _loop_num_consts(body_jaxpr)
  consts, carry = split_list(primals, num_consts)
  consts_dot, carry_dot = split_list(tangents, num_consts)
  body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
  cond_jaxpr, body_jaxpr = _loop_jvp_binders(
      cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
  # pass arguments in the rearranged binder order [c1, c2, c2dot, x1, x1dot]
  outs = bind(while_loop_p, *body_consts, *consts, *consts_dot,
              *carry, *carry_dot,
              cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[while_loop_p] = while_loop_jvp_rule
def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
                      ) -> Tuple[Jaxpr, Jaxpr]:
  # body binders [c1, c2, x1, c2dot, x1dot] ~~> [c1, c2, c2dot, x1, x1dot]
jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
primal_binders, tangent_binders = split_half(binders)
consts , carry = split_list(primal_binders , n2)
consts_dot, carry_dot = split_list(tangent_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
typecheck_jaxpr(new_body_jaxpr)
# cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]
assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
consts, carry = split_list(cond_jaxpr.in_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)
return new_cond_jaxpr, new_body_jaxpr
# -
out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
(1.,), (1.,))
print(out_tan)
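
# As a check, a loop that applies `cos` three times should have the same JVP
# as the unrolled composition: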
# +
def f(x):
def cond_fn(i, _):
return i < 3
def body_fn(i, x):
return i + 1, cos(x)
_, y = while_loop(cond_fn, body_fn, (0, x))
return y
def g(x):
return cos(cos(cos(x)))
print(jvp(f, (1.,), (1.,)))
print(jvp(g, (1.,), (1.,)))
# -
# The vmap rule for `while_loop` presents two cases:
# 1. if the output of `cond_fn` is not batched, then the loop has the same
#    basic structure, just with a batched body;
# 2. but if the output of `cond_fn` is batched, we must represent a batch of
#    loops which might run for different numbers of iterations (see the
#    reference sketch below).
#
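# For intuition about the second case, here's the batched-loop semantics as a
# plain-Python reference (a sketch of the meaning only, not of a batching
# rule: step until every lane's predicate is false, updating only live lanes):
def batched_while_reference(cond_fn, body_fn, init_vals):
  vals = list(init_vals)
  while True:
    preds = [bool(cond_fn(v)) for v in vals]
    if not any(preds): break
    vals = [body_fn(v) if p else v for v, p in zip(vals, preds)]
  return vals

print(batched_while_reference(lambda x: x < 10., lambda x: x * 2.,
                              np.array([1., 5., 64.])))
#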
# ...Stay tuned for the thrilling conclusion!