autodidax: add cond and start while_loop
parent 6ce4ef46b9
commit 83cd42271b
@@ -95,6 +95,7 @@ sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")

@@ -105,6 +106,7 @@ def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)

@@ -138,7 +140,7 @@ more descriptive.

```{code-cell}
from contextlib import contextmanager
from typing import Type, List, Optional, Any
from typing import Type, List, Tuple, Sequence, Optional, Any

class MainTrace(NamedTuple):
  level: int
@@ -209,7 +211,6 @@ like arrays.)

```{code-cell}
import numpy as np
from typing import Tuple

class Tracer:
  _trace: Trace
@@ -229,6 +230,7 @@ class Tracer:
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __bool__(self): return self.aval._bool(self)
  def __nonzero__(self): return self.aval._nonzero(self)

@@ -261,6 +263,7 @@ class ShapedArray:
  _mul = staticmethod(mul)
  _rmul = staticmethod(swap(mul))
  _gt = staticmethod(greater)
  _lt = staticmethod(less)

  @staticmethod
  def _bool(tracer):
@@ -421,6 +424,7 @@ impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]

def broadcast_impl(x, *, shape, axes):
@@ -452,7 +456,8 @@ First, a few helper functions:

```{code-cell}
def zeros_like(val):
  return np.zeros_like(val)
  aval = get_aval(val)
  return np.zeros(aval.shape, aval.dtype)

def unzip2(pairs):
  lst1, lst2 = [], []
@@ -464,6 +469,14 @@ def unzip2(pairs):
map_ = map
def map(f, *xs):
  return list(map_(f, *xs))

zip_ = zip
def zip(*args):
  fst, *rest = args = map(list, args)
  n = len(fst)
  for arg in rest:
    assert len(arg) == n
  return list(zip_(*args))
```

The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The
@@ -533,6 +546,12 @@ def greater_jvp(primals, tangents):
  out_primal = greater(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp

def less_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = less(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp
```

Finally, we add a transformation API to kick off the trace:
@@ -833,16 +852,17 @@ Next we can define batching interpreter rules for each primitive:
```{code-cell}
from functools import partial

def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
def binop_batching_rule(op, axis_size, vals_in, dims_in):
  (x, y), (x_bdim, y_bdim) = vals_in, dims_in
  if x_bdim != y_bdim:
    if x_bdim is not_mapped:
      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
      x_bdim = y_bdim
    else:
      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
  return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)

def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
  (x,), (x_bdim,) = vals_in, dims_in
@@ -917,7 +937,7 @@ That's it for `jvp` and `vmap`!

## Part 2: Jaxprs

The next transformations are the horizon are `jit` for just-in-time
The next transformations on the horizon are `jit` for just-in-time
compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small
wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to
carry a little bit of extra context, for both `jit` and `vjp` we need much
@@ -984,8 +1004,8 @@ class Lit:
  aval: ShapedArray

  def __init__(self, val):
    self.val = val
    self.aval = raise_to_shaped(get_aval(self.val))
    self.aval = aval = raise_to_shaped(get_aval(val))
    self.val = np.array(val, aval.dtype)

Atom = Union[Var, Lit]

@@ -1088,6 +1108,19 @@ a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one
used by `jit`, which is also used by control flow primitives like `lax.cond`,
`lax.while_loop`, and `lax.scan`.

```{code-cell}
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
  assert 0 <= n <= len(lst)
  return lst[:n], lst[n:]

def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
  assert len(bs) == len(l)
  lists = lst1, lst2 = [], []
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lst1, lst2
```

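These helpers are easy to misread (in particular, `partition_list` puts the
`False`-flagged elements first), so here's a quick hedged example of their
behavior, relying only on the definitions just above:

```{code-cell}
print(split_list([1, 2, 3, 4], 2))
# ([1, 2], [3, 4])
print(partition_list([True, False, True], ['a', 'b', 'c']))
# (['b'], ['a', 'c'])
```
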
```{code-cell}
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
@@ -1181,9 +1214,25 @@ class JaxprBuilder:
    out_vars = [t2v(t) for t in out_tracers]
    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals
```

```{code-cell}
def _inline_literals(jaxpr: Jaxpr, consts: List[Any]) -> Tuple[Jaxpr, List[Any]]:
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  typecheck_jaxpr(new_jaxpr)
  return new_jaxpr, new_consts
```

The rules we need for `JaxprTrace.process_primitive` are essentially typing
rules for primitive applications: given the primitive, its parameters, and
types for the inputs, the rule must produce a type for the output, which is
@@ -1196,36 +1245,38 @@ rules for the other jaxpr-producing trace machinery, where the potential extra
generality is useful.

```{code-cell}
def broadcast_shapes(*shapes):
  assert len(shapes) > 1
  for sizes in zip(*shapes):
    sizes = [d for d in sizes if d != 1]
    if sizes[:-1] != sizes[1:]:
      raise Exception
  return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
  return [ShapedArray(x.shape, x.dtype)]

def broadcasting_binop_abstract_eval_rule(*avals_in):
  out_dtype = np.result_type(*map(np.result_type, avals_in))
  out_shape = broadcast_shapes(*map(np.shape, avals_in))
  return [ShapedArray(out_shape, out_dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval

abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule
abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if x.shape != y.shape: raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval

def vectorized_unop_abstract_eval_rule(aval_in):
  return [ShapedArray(np.shape(aval_in), np.result_type(aval_in))]
def vectorized_unop_abstract_eval(x: ShapedArray) -> List[ShapedArray]:
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval

def reduce_sum_abstract_eval_rule(aval_in, *, axis):
  new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]
  return [ShapedArray(tuple(new_shape), aval_in.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: int) -> List[ShapedArray]:
  new_shape = [d for i, d in enumerate(x.shape) if i != axis]
  return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval

def broadcast_abstract_eval(x, *, shape, axes):
  return [ShapedArray(tuple(shape), np.result_type(x))]
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> List[ShapedArray]:
  return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
```

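A quick hedged check of the literal inlining (it uses `make_jaxpr`, which the
full document defines between this point and the `xla_call` machinery below):
a closed-over Python scalar such as `2.` is turned into a `Lit` rather than
kept as a constant input.

```{code-cell}
jaxpr, consts, _ = make_jaxpr(lambda x: 2. * x,
                              ShapedArray((), np.dtype('float64')))
print(consts)                  # [] -- the scalar 2. became a literal
print(len(jaxpr.in_binders))   # 1 -- only the binder for x remains
```
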
@@ -1501,6 +1552,7 @@ impl_rules[xla_call_p] = xla_call_impl
@lru_cache()
def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]):
  jaxpr: Jaxpr = hashable_jaxpr.val
  typecheck_jaxpr(jaxpr)
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
  c = xb.make_computation_builder('xla_call')
@@ -1534,7 +1586,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
  env: Dict[Var, xe.XlaOp] = {}

  def read(x: Atom) -> xe.XlaOp:
    return env[x] if type(x) is Var else xb.constant(c, x.val)
    return env[x] if type(x) is Var else xb.constant(c, x.val, False)

  def write(v: Var, val: xe.XlaOp) -> None:
    env[v] = val
@@ -1555,7 +1607,7 @@ def execute_compiled(compiled, out_avals, *args):

default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [int, float, np.ndarray, np.float64, np.float32]}
                  [bool, int, float, np.ndarray, np.float64, np.float32]}

def handle_result(aval: ShapedArray, buf):
  del aval  # Unused for now.
@@ -1580,6 +1632,7 @@ xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)

def reduce_sum_translation(c, in_avals, in_vals, *, axis):
  (x_aval,), (x,) = in_avals, in_vals
@@ -1784,6 +1837,7 @@ class DeviceArray:
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[DeviceArray] = lambda x: x.buf

jax_types.add(DeviceArray)
@@ -1829,6 +1883,20 @@ y, y_dot = jvp(f, (x,), (x_dot,))
where the application of `f_lin` does not redo any of the linearization work.
We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.

Tangentially, now that we have linear arrows `-o`, we can provide a slightly
more informative type for `jvp`:
```
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
```
Here we're writing `UnrestrictedUse` just to indicate that we have a special
pair where the first element can be used in an unrestricted (nonlinear) way.
In conjunction with the linear arrow, this notation is just meant to express
that the function `jvp f` uses its first input in a nonlinear way but its
second input in a linear way, producing a corresponding nonlinear output
(which can be used in a nonlinear way) paired with a linear output. This more
refined type signature encodes the data dependencies in `jvp f`, which are
useful for partial evaluation.

To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
we evaluate all the primal values as we trace, but stage the tangent
computations into a jaxpr. This is our second way to build jaxprs. But where
@@ -1839,18 +1907,15 @@ primitive binds with a data dependence on tangent inputs.
First, some utilities:

```{code-cell}
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
  return lst[:n], lst[n:]

def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
  lists = lst1, lst2 = [], []
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lst1, lst2
def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out
```

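To see how these helpers fit together, here's a small hedged check (plain
Python, using only the definitions above): `merge_lists` inverts
`partition_list`.

```{code-cell}
which = [True, False, True]
l1, l2 = partition_list(which, ['a', 'b', 'c'])  # False-items, then True-items
print(l1, l2)                                    # ['b'] ['a', 'c']
print(merge_lists(which, l1, l2))                # ['a', 'b', 'c']
```
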
Next, we'll write `linearize` by combining `jvp` together with a general
@@ -1895,24 +1960,22 @@ inputs, together with (2) a jaxpr representing the part of the Python
callable's computation which can only be performed after the remaining inputs
are known.

This transformation can't be summarized purely in a type signature because its
behavior relies on the data dependencies inside the given Python callable and
not just its type. Nevertheless a heuristic type signature is useful. If we
This transformation is tricky to summarize in a type signature. If we
assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
`a1` and `a2` represent the known and unknown inputs, respectively, and where
`b1` only has a data dependency on `a1` while `b2` has some data dependency on
`a2`, then we might write

```
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> (b1, res, (res, a2) -> b2)
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
```

In words, given values for the inputs of type `a1`, `partial_eval` produces
the outputs of type `b1` along with "residual" values of type `res`
representing the intermediates required to complete the computation in the
second stage. It also produces a function of type `(res, a2) -> b2` which
accepts the residual values as well as the remaining inputs and produces the
remaining outputs.
the outputs of type `b1` along with "residual" values of
existentially-quantified type `r` representing the intermediates required to
complete the computation in the second stage. It also produces a function of
type `(r, a2) -> b2` which accepts the residual values as well as the
remaining inputs and produces the remaining outputs.

We like to think of partial evaluation as "unzipping" one computation into
two. For example, consider this jaxpr:
@@ -1924,7 +1987,7 @@ two. For example, consider this jaxpr:
```
A jaxpr for the JVP would look like:
```
{ lambda a:float64[] b:float64 .
{ lambda a:float64[] b:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      e:float64[] = mul d b
@@ -2158,6 +2221,8 @@ def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
```

```{code-cell}
:tags: [hide-input]

def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)
@@ -2212,6 +2277,10 @@ To handle `linearize`-of-`jit`, we still need to write a partial evaluation
rule for `xla_call_p`. Other than tracer bookkeeping, the main task is to
perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.

There are actually two rules to write: one for trace-time partial evaluation,
which we'll call `xla_call_partial_eval`, and one for partial evaluation of
jaxprs, which we'll call `xla_call_peval_eqn`.

```{code-cell}
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused.
@@ -2228,18 +2297,17 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  outs1, outs2 = iter(outs1), iter(outs2)
  return [next(outs2) if uk else next(outs1) for uk in out_unknowns]
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
                       instantiate: Optional[List[bool]] = None,
                       ) -> Tuple[Jaxpr, Jaxpr, List[bool], int]:
  env: Dict[Var, bool] = {}
  residuals = set()

  def read(v: Atom) -> bool:
    if type(v) is Lit: raise NotImplementedError
    return env[v]
    return type(v) is Var and env[v]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk
@@ -2264,6 +2332,11 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
@@ -2295,7 +2368,7 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
                       ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
@@ -2481,10 +2554,11 @@ transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...]
                    ) -> Tuple[Jaxpr, List[Any]]:
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts
```

@@ -2571,3 +2645,441 @@ _, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
```

## Part 5: the control flow primitives `cond` and `while_loop`

Next we'll add higher-order primitives for staged-out control flow. These
resemble `jit` from Part 3, another higher-order primitive, but differ in that
they are parameterized by multiple callables rather than just one.

+++

### Adding `cond`

We introduce a `cond` primitive to represent conditional application of one
function or another inside a jaxpr. We write the type of `cond` as
`Bool -> (a -> b) -> (a -> b) -> a -> b`. In words, `cond` takes a boolean
representing the predicate and two functions of equal types. Depending on the
value of the predicate, it applies one function or the other to its final
argument.

In Python, we represent it as a function which itself takes two functions as
arguments. As with `jit`, the first step is to call `make_jaxpr` on its
callable arguments to turn them into jaxprs:

```{code-cell}
def cond(pred, true_fn, false_fn, *operands):
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_: raise TypeError
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')

def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> Tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2

def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
```

We require `true_jaxpr` and `false_jaxpr` to have the same type, but because
they might close over different constants (and because jaxprs can only
represent closed terms, i.e. can't have free variables and are instead
closure-converted) we need to use the helper `_join_jaxpr_consts` to make
the input binder lists of the two jaxprs consistent. (To be more economical
we could try to identify pairs of constants with the same shapes, but instead
we just concatenate the lists of constants.)

Next we can turn to adding interpreter rules for `cond`. Its evaluation rule
is simple:

```{code-cell}
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  if pred:
    return eval_jaxpr(true_jaxpr, operands)
  else:
    return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
```

```{code-cell}
out = cond(True, lambda: 3, lambda: 4)
print(out)
```

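Here's a small hedged check of the closure conversion at work, using only the
definitions above: the two branches close over different array constants
(non-scalar, so they aren't inlined as literals), and the call still
typechecks because `_join_jaxpr_consts` pads both jaxprs with the union
`[a, b]` of the constants.

```{code-cell}
a = np.array([1., 2.])
b = np.array([10., 20.])
out = cond(True, lambda x: x * a, lambda x: x + b, np.ones(2))
print(out)  # [1. 2.]
```
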
For its JVP and vmap rules, we only need to call the same `jvp_jaxpr` and
`vmap_jaxpr` utilities we created for `jit`, followed by another pass of
`_join_jaxpr_consts`:

```{code-cell}
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  pred, *primals = primals
  _   , *tangents = tangents
  true_jaxpr , true_consts  = jvp_jaxpr(true_jaxpr)
  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
```

```{code-cell}
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
```

```{code-cell}
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  pred    , *vals_in = vals_in
  pred_dim, *dims_in = dims_in
  if pred_dim is not not_mapped: raise NotImplementedError  # TODO
  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
```

```{code-cell}
xs = np.array([1., 2., 3.])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
```

Notice that we're not currently supporting the case where the predicate value
itself is batched. In mainline JAX, we handle this case by transforming the
conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).
That transformation is semantically correct so long as `true_fun` and
`false_fun` do not involve any side-effecting primitives.

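To illustrate the select-based idea, here's a hedged NumPy-only sketch (not
wired into our `vmap` machinery): with a batched predicate we can evaluate
both branches on the whole batch and select elementwise.

```{code-cell}
preds = np.array([True, False, True])
xs = np.array([1., 2., 3.])
true_out, false_out = xs + 1., np.zeros_like(xs)  # run both branches
print(np.where(preds, true_out, false_out))       # select per element
# [2. 0. 3.]
```
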
Another thing not represented here, but present in mainline JAX, is that
applying transformations to two jaxprs of equal type might result in jaxprs of
different types. For example, applying the mainline JAX version of
`vmap_jaxpr` to the identity-function jaxpr

```
{ lambda a:float32[] .
  let
  in ( a ) }
```

would result in a jaxpr with a batched output, of type
`[float32[10]] -> [float32[10]]` if the batch size were 10, while applying it
to the zero-function jaxpr

```
{ lambda a:float32[] .
  let
  in ( 0. ) }
```

would result in a jaxpr with an unbatched output, of type
`[float32[10]] -> [float32[]]`. This is an optimization, aimed at not batching
values unnecessarily. But it means that in `cond` we'd need an extra step of
joining the two transformed jaxprs to have consistent output types. We don't
need this step here because we chose `vmap_jaxpr` always to batch all outputs
over the leading axis.

+++

Next we can turn to abstract evaluation and XLA lowering rules:

```{code-cell}
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval

def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
  del in_avals  # Unused.
  pred, *in_vals = in_vals
  flat_vals, in_tree = tree_flatten(in_vals)
  operand = xops.Tuple(c, flat_vals)
  operand_shape = c.get_shape(operand)

  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
    c = xb.make_computation_builder(name)
    operand = xb.parameter(c, 0, operand_shape)
    operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
    outs = jaxpr_subcomp(c, jaxpr, operands)
    return c.build(xops.Tuple(c, outs))

  true_comp = make_comp('true_fn', true_jaxpr)
  false_comp = make_comp('false_fn', false_jaxpr)

  int_etype = xc.dtype_to_etype(np.dtype('int32'))
  out = xops.Conditional(xops.ConvertElementType(pred, int_etype),
                         [false_comp, true_comp], [operand] * 2)
  return destructure_tuple(c, out)
xla_translations[cond_p] = cond_translation
```

```{code-cell}
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
```

Finally, to support reverse-mode automatic differentiation, we need partial
evaluation and transposition rules. For partial evaluation, we need to
introduce another jaxpr-munging utility, `_join_jaxpr_res`, to handle the fact
that applying partial evaluation to `true_fun` and `false_fun` will in general
result in distinct residuals. We use `_join_jaxpr_res` to make the output
types of the transformed jaxprs consistent (while `_join_jaxpr_consts` dealt
with input types).

```{code-cell}
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]

  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs

  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval

def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: List[bool]
                       ) -> Tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, List[bool], int]:
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, t_out_uks, f_out_uks)

  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)

  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres

  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res

def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> Tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
  assert out_types1 == out_types2
  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
  return new_jaxpr1, new_jaxpr2
```

```{code-cell}
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
```

```{code-cell}
def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
                   ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
  pred_unk, *unks_in = unks_in
  assert not pred_unk
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
```

```{code-cell}
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
```

Transposition is a fairly straightforward application of `transpose_jaxpr`:

```{code-cell}
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  undef_primals = tuple([type(x) is UndefPrimal for x in invals])
  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  res = [x for x in invals if type(x) is not UndefPrimal]
  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  outs = iter(outs)
  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
```

```{code-cell}
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
```

### Adding `while_loop`

Next we'll add a primitive for looping behavior in a jaxpr. We'll use
`while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
function-valued argument represents the loop condition, the second represents
the loop body, and the final argument is the initial value of the carry.

After `cond`, adding `while_loop` is not so different:

```{code-cell}
def while_loop(cond_fn, body_fn, init_val):
  init_val, in_tree = tree_flatten(init_val)
  avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
  cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
  body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
  cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
      cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
  if cond_tree != tree_flatten(True)[1]: raise TypeError
  if in_tree != in_tree_: raise TypeError
  outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
              cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return tree_unflatten(in_tree, outs)
while_loop_p = Primitive('while_loop')
```

```{code-cell}
def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
  consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
  while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
    carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
  return carry
impl_rules[while_loop_p] = while_loop_impl
```

```{code-cell}
def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
  return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)
```

```{code-cell}
out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
print(out)
```

Notice the convention that `args = [*consts, *carry]`.

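As a hedged illustration of that convention, reusing `make_jaxpr` and
`_loop_num_consts` from above (the constant `c` is a non-scalar array, so it
isn't inlined as a literal): a constant closed over by the body shows up as
extra leading binders.

```{code-cell}
c = np.arange(3.)
body_jaxpr, body_consts, _ = make_jaxpr(lambda x: x * c,
                                        raise_to_shaped(get_aval(c)))
print(len(body_consts))              # 1 -- the closed-over array c
print(_loop_num_consts(body_jaxpr))  # 1 -- binders are [c, x], one output
```
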
The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
convention that all the binders for tangent values are appended after all the
binders for primal values, like `args = [*primals, *tangents]`. But that's in
tension with our `while_loop` convention that the carry binders come after all
the constant binders, i.e. that `args = [*consts, *carry]`, because both the
constants and the carries can have their own tangents. For this reason, we
introduce the `_loop_jvp_binders` helper to rearrange binders as needed.

```{code-cell}
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
  num_consts = _loop_num_consts(body_jaxpr)
  body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
  cond_jaxpr, body_jaxpr = _loop_jvp_binders(
      cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
  outs = bind(while_loop_p, *body_consts, *primals, *tangents,
              cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[while_loop_p] = while_loop_jvp_rule

def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
                      ) -> Tuple[Jaxpr, Jaxpr]:
  # body binders [c1, c2, x, c2dot, xdot] ~~> [c1, c2, c2dot, x, xdot]
  jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
  primal_binders, tangent_binders = split_half(binders)
  consts    , carry     = split_list(primal_binders , n2)
  consts_dot, carry_dot = split_list(tangent_binders, n2)
  new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
  new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
  typecheck_jaxpr(new_body_jaxpr)

  # cond binders [c2, x] ~~> [c1, c2, c2dot, x, xdot]
  assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
  consts, carry = split_list(cond_jaxpr.in_binders, n2)
  new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
  new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)

  return new_cond_jaxpr, new_body_jaxpr
```

```{code-cell}
out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
                   (1.,), (1.,))
print(out_tan)
```

```{code-cell}
def f(x):
  def cond_fn(i, _):
    return i < 3
  def body_fn(i, x):
    return i + 1, cos(x)
  _, y = while_loop(cond_fn, body_fn, (0, x))
  return y

def g(x):
  return cos(cos(cos(x)))

print(jvp(f, (1.,), (1.,)))
print(jvp(g, (1.,), (1.,)))
```

The vmap rule for `while_loop` presents two cases:
1. if the output of `cond_fun` is not batched, then the loop has the same
   basic structure, just with a batched body;
2. but if the output of `cond_fun` is batched, we must represent a batch of
   loops which might run for different numbers of iterations (see the sketch
   below).

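For a taste of case 2, here's a hedged NumPy-only sketch of the masking idea,
independent of our `while_loop_p` machinery: keep iterating while any
element's condition holds, and freeze the elements whose condition has gone
false.

```{code-cell}
def batched_while_sketch(cond_fn, body_fn, init):
  carry = init
  while (alive := cond_fn(carry)).any():
    carry = np.where(alive, body_fn(carry), carry)  # mask finished elements
  return carry

print(batched_while_sketch(lambda x: x < 10., lambda x: x * 2.,
                           np.array([1., 3., 50.])))
# [16. 12. 50.]
```
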
...Stay tuned for the thrilling conclusion!
@ -85,6 +85,7 @@ sin_p = Primitive("sin")
|
||||
cos_p = Primitive("cos")
|
||||
reduce_sum_p = Primitive("reduce_sum")
|
||||
greater_p = Primitive("greater")
|
||||
less_p = Primitive("less")
|
||||
transpose_p = Primitive("transpose")
|
||||
broadcast_p = Primitive("broadcast")
|
||||
|
||||
@ -95,6 +96,7 @@ def sin(x): return bind1(sin_p, x)
|
||||
def cos(x): return bind1(cos_p, x)
|
||||
def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
|
||||
def greater(x, y): return bind1(greater_p, x, y)
|
||||
def less(x, y): return bind1(less_p, x, y)
|
||||
def transpose(x, perm): return bind1(transpose_p, perm=perm)
|
||||
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
|
||||
|
||||
@ -128,7 +130,7 @@ def bind1(prim, *args, **params):
|
||||
|
||||
# +
|
||||
from contextlib import contextmanager
|
||||
from typing import Type, List, Optional, Any
|
||||
from typing import Type, List, Tuple, Sequence, Optional, Any
|
||||
|
||||
class MainTrace(NamedTuple):
|
||||
level: int
|
||||
@ -197,7 +199,6 @@ class Trace:
|
||||
|
||||
# +
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
class Tracer:
|
||||
_trace: Trace
|
||||
@ -217,6 +218,7 @@ class Tracer:
|
||||
def __mul__(self, other): return self.aval._mul(self, other)
|
||||
def __rmul__(self, other): return self.aval._rmul(self, other)
|
||||
def __gt__(self, other): return self.aval._gt(self, other)
|
||||
def __lt__(self, other): return self.aval._lt(self, other)
|
||||
def __bool__(self): return self.aval._bool(self)
|
||||
def __nonzero__(self): return self.aval._nonzero(self)
|
||||
|
||||
@ -248,6 +250,7 @@ class ShapedArray:
|
||||
_mul = staticmethod(mul)
|
||||
_rmul = staticmethod(swap(mul))
|
||||
_gt = staticmethod(greater)
|
||||
_lt = staticmethod(less)
|
||||
|
||||
@staticmethod
|
||||
def _bool(tracer):
|
||||
@ -404,6 +407,7 @@ impl_rules[sin_p] = lambda x: [np.sin(x)]
|
||||
impl_rules[cos_p] = lambda x: [np.cos(x)]
|
||||
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
|
||||
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
|
||||
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
|
||||
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
|
||||
|
||||
def broadcast_impl(x, *, shape, axes):
|
||||
@ -433,7 +437,8 @@ print(f(3.0))
|
||||
|
||||
# +
|
||||
def zeros_like(val):
|
||||
return np.zeros_like(val)
|
||||
aval = get_aval(val)
|
||||
return np.zeros(aval.shape, aval.dtype)
|
||||
|
||||
def unzip2(pairs):
|
||||
lst1, lst2 = [], []
|
||||
@ -445,6 +450,14 @@ def unzip2(pairs):
|
||||
map_ = map
|
||||
def map(f, *xs):
|
||||
return list(map_(f, *xs))
|
||||
|
||||
zip_ = zip
|
||||
def zip(*args):
|
||||
fst, *rest = args = map(list, args)
|
||||
n = len(fst)
|
||||
for arg in rest:
|
||||
assert len(arg) == n
|
||||
return list(zip_(*args))
|
||||
# -
|
||||
|
||||
# The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The
|
||||
@ -514,6 +527,12 @@ def greater_jvp(primals, tangents):
|
||||
out_primal = greater(x, y)
|
||||
return [out_primal], [zeros_like(out_primal)]
|
||||
jvp_rules[greater_p] = greater_jvp
|
||||
|
||||
def less_jvp(primals, tangents):
|
||||
(x, y), _ = primals, tangents
|
||||
out_primal = less(x, y)
|
||||
return [out_primal], [zeros_like(out_primal)]
|
||||
jvp_rules[less_p] = less_jvp
|
||||
# -
|
||||
|
||||
# Finally, we add a transformation API to kick off the trace:
|
||||
@ -797,16 +816,17 @@ vmap_rules = {}
|
||||
# +
|
||||
from functools import partial
|
||||
|
||||
def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
|
||||
def binop_batching_rule(op, axis_size, vals_in, dims_in):
|
||||
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
|
||||
if x_bdim != y_bdim:
|
||||
if x_bdim is not_mapped:
|
||||
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
|
||||
x_bdim = y_bdim
|
||||
else:
|
||||
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
|
||||
return [op(x, y)], [x_bdim]
|
||||
vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
|
||||
vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
|
||||
vmap_rules[add_p] = partial(binop_batching_rule, add)
|
||||
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
|
||||
|
||||
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
|
||||
(x,), (x_bdim,) = vals_in, dims_in
|
||||
@ -878,7 +898,7 @@ jacfwd(f, np.arange(3.))
|
||||
|
||||
# ## Part 2: Jaxprs
|
||||
#
|
||||
# The next transformations are the horizon are `jit` for just-in-time
|
||||
# The next transformations on the horizon are `jit` for just-in-time
|
||||
# compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small
|
||||
# wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to
|
||||
# carry a little bit of extra context, for both `jit` and `vjp` we need much
|
||||
@ -943,8 +963,8 @@ class Lit:
|
||||
aval: ShapedArray
|
||||
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
self.aval = raise_to_shaped(get_aval(self.val))
|
||||
self.aval = aval = raise_to_shaped(get_aval(val))
|
||||
self.val = np.array(val, aval.dtype)
|
||||
|
||||
Atom = Union[Var, Lit]
|
||||
|
||||
@ -1046,6 +1066,18 @@ def jaxpr_as_fun(jaxpr: Jaxpr):
|
||||
# used by `jit`, which is also used by control flow primitives like `lax.cond`,
|
||||
# `lax.while_loop`, and `lax.scan`.
|
||||
|
||||
# +
|
||||
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
|
||||
assert 0 <= n <= len(lst)
|
||||
return lst[:n], lst[n:]
|
||||
|
||||
def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
|
||||
assert len(bs) == len(l)
|
||||
lists = lst1, lst2 = [], []
|
||||
for b, x in zip(bs, l):
|
||||
lists[b].append(x)
|
||||
return lst1, lst2
|
||||
|
||||
# +
|
||||
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
|
||||
class JaxprTracer(Tracer):
|
||||
@ -1138,8 +1170,22 @@ class JaxprBuilder:
|
||||
out_vars = [t2v(t) for t in out_tracers]
|
||||
jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
|
||||
typecheck_jaxpr(jaxpr)
|
||||
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
||||
return jaxpr, constvals
|
||||
|
||||
def _inline_literals(jaxpr: Jaxpr, consts: List[Any]) -> Tuple[Jaxpr, List[Any]]:
|
||||
const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
|
||||
scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
|
||||
new_const_binders, lit_binders = partition_list(scalars, const_binders)
|
||||
new_consts, lit_vals = partition_list(scalars, consts)
|
||||
literals = dict(zip(lit_binders, map(Lit, lit_vals)))
|
||||
new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
|
||||
eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
|
||||
new_outs = [literals.get(x, x) for x in jaxpr.outs]
|
||||
new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
|
||||
typecheck_jaxpr(new_jaxpr)
|
||||
return new_jaxpr, new_consts
|
||||
|
||||
# The rules we need for `JaxprTrace.process_primitive` are essentially typing
|
||||
# rules for primitive applications: given the primitive, its parameters, and
|
||||
# types for the inputs, the rule must produce a type for the output, which is
|
||||
@ -1152,36 +1198,38 @@ class JaxprBuilder:
|
||||
# generality is useful.
|
||||
|
||||
# +
|
||||
def broadcast_shapes(*shapes):
|
||||
assert len(shapes) > 1
|
||||
for sizes in zip(*shapes):
|
||||
sizes = [d for d in sizes if d != 1]
|
||||
if sizes[:-1] != sizes[1:]:
|
||||
raise Exception
|
||||
return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))
|
||||
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
|
||||
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
|
||||
raise TypeError
|
||||
if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
|
||||
return [ShapedArray(x.shape, x.dtype)]
|
||||
|
||||
def broadcasting_binop_abstract_eval_rule(*avals_in):
|
||||
out_dtype = np.result_type(*map(np.result_type, avals_in))
|
||||
out_shape = broadcast_shapes(*map(np.shape, avals_in))
|
||||
return [ShapedArray(out_shape, out_dtype)]
|
||||
abstract_eval_rules[add_p] = binop_abstract_eval
|
||||
abstract_eval_rules[mul_p] = binop_abstract_eval
|
||||
|
||||
abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule
|
||||
abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule
|
||||
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:
|
||||
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
|
||||
raise TypeError
|
||||
if x.shape != y.shape: raise TypeError
|
||||
return [ShapedArray(x.shape, np.dtype('bool'))]
|
||||
abstract_eval_rules[greater_p] = compare_abstract_eval
|
||||
abstract_eval_rules[less_p] = compare_abstract_eval
|
||||
|
||||
def vectorized_unop_abstract_eval_rule(aval_in):
|
||||
return [ShapedArray(np.shape(aval_in), np.result_type(aval_in))]
|
||||
def vectorized_unop_abstract_eval(x: ShapedArray) -> List[ShapedArray]:
|
||||
return [ShapedArray(x.shape, x.dtype)]
|
||||
|
||||
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule
|
||||
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule
|
||||
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule
|
||||
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
|
||||
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
|
||||
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
|
||||
|
||||
def reduce_sum_abstract_eval_rule(aval_in, *, axis):
|
||||
new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]
|
||||
return [ShapedArray(tuple(new_shape), aval_in.dtype)]
|
||||
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule
|
||||
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: int) -> List[ShapedArray]:
|
||||
new_shape = [d for i, d in enumerate(x.shape) if i != axis]
|
||||
return [ShapedArray(tuple(new_shape), x.dtype)]
|
||||
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
|
||||
|
||||
def broadcast_abstract_eval(x, *, shape, axes):
|
||||
return [ShapedArray(tuple(shape), np.result_type(x))]
|
||||
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
|
||||
axes: Sequence[int]) -> List[ShapedArray]:
|
||||
return [ShapedArray(tuple(shape), x.dtype)]
|
||||
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
|
||||
# -
|
||||
|
||||
@ -1441,6 +1489,7 @@ impl_rules[xla_call_p] = xla_call_impl
|
||||
@lru_cache()
|
||||
def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]):
|
||||
jaxpr: Jaxpr = hashable_jaxpr.val
|
||||
typecheck_jaxpr(jaxpr)
|
||||
consts = [x.val for x in hashable_consts]
|
||||
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
|
||||
c = xb.make_computation_builder('xla_call')
|
||||
@ -1474,7 +1523,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
|
||||
env: Dict[Var, xe.XlaOp] = {}
|
||||
|
||||
def read(x: Atom) -> xe.XlaOp:
|
||||
return env[x] if type(x) is Var else xb.constant(c, x.val)
|
||||
return env[x] if type(x) is Var else xb.constant(c, x.val, False)
|
||||
|
||||
def write(v: Var, val: xe.XlaOp) -> None:
|
||||
env[v] = val
|
||||
@ -1495,7 +1544,7 @@ def execute_compiled(compiled, out_avals, *args):
|
||||
|
||||
default_input_handler = xb.get_backend(None).buffer_from_pyval
|
||||
input_handlers = {ty: default_input_handler for ty in
|
||||
[int, float, np.ndarray, np.float64, np.float32]}
|
||||
[bool, int, float, np.ndarray, np.float64, np.float32]}
|
||||
|
||||
def handle_result(aval: ShapedArray, buf):
|
||||
del aval # Unused for now.
|
||||
@ -1520,6 +1569,7 @@ xla_translations[neg_p] = partial(direct_translation, xops.Neg)
|
||||
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
|
||||
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
|
||||
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
|
||||
xla_translations[less_p] = partial(direct_translation, xops.Lt)
|
||||
|
||||
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
|
||||
(x_aval,), (x,) = in_avals, in_vals
|
||||
@ -1708,6 +1758,7 @@ class DeviceArray:
|
||||
_mul = staticmethod(mul)
|
||||
_rmul = staticmethod(mul)
|
||||
_gt = staticmethod(greater)
|
||||
_lt = staticmethod(less)
|
||||
input_handlers[DeviceArray] = lambda x: x.buf
|
||||
|
||||
jax_types.add(DeviceArray)
|
||||
@ -1750,6 +1801,20 @@ print(ydot)
|
||||
# where the application of `f_lin` does not redo any of the linearization work.
|
||||
# We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
|
||||
#
|
||||
# Tangentially, now that we have linear arrows `-o`, we can provide a slightly
|
||||
# more informative type for `jvp`:
|
||||
# ```
|
||||
# jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
|
||||
# ```
|
||||
# Here we're writing `UnrestrictedUse` just to indicate that we have a special
|
||||
# pair where the first element can be used in an unrestricted (nonlinear) way.
|
||||
# In conjunction with the linear arrow, this notation is just meant to express
|
||||
# that the function `jvp f` uses its first input in a nonlinear way but its
|
||||
# second input in a linear way, producing a corresponding nonlinear output
|
||||
# (which can be used in a nonlinear way) paired with a linear output. This more
|
||||
# refined type signature encodes the data dependencies in `jvp f`, which are
|
||||
# useful for partial evaluation.
|
||||
#
|
||||
# To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
|
||||
# we evaluate all the primal values as we trace, but stage the tangent
|
||||
# computations into a jaxpr. This is our second way to build jaxprs. But where
|
||||
@ -1760,18 +1825,15 @@ print(ydot)
|
||||
# First, some utilities:
|
||||
|
||||
# +
|
||||
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
|
||||
return lst[:n], lst[n:]
|
||||
|
||||
def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
|
||||
assert not len(lst) % 2
|
||||
return split_list(lst, len(lst) // 2)
|
||||
|
||||
def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
|
||||
lists = lst1, lst2 = [], []
|
||||
for b, x in zip(bs, l):
|
||||
lists[b].append(x)
|
||||
return lst1, lst2
|
||||
def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:
|
||||
l1, l2 = iter(l1), iter(l2)
|
||||
out = [next(l2) if b else next(l1) for b in which]
|
||||
assert next(l1, None) is next(l2, None) is None
|
||||
return out
|
||||
# -
|
||||
|
||||
# Next, we'll write `linearize` by combining `jvp` together with a general
|
||||
@ -1816,24 +1878,22 @@ def vspace(aval: ShapedArray) -> ShapedArray:
|
||||
# callable's computation which can only be performed after the remaining inputs
|
||||
# are known.
|
||||
#
|
||||
# This transformation is tricky to summarize in a type signature. If we
# assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
# `a1` and `a2` represent the known and unknown inputs, respectively, and where
# `b1` only has a data dependency on `a1` while `b2` has some data dependency on
# `a2`, then we might write
#
# ```
# partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
# ```
#
# In words, given values for the inputs of type `a1`, `partial_eval` produces
# the outputs of type `b1` along with "residual" values of
# existentially-quantified type `r` representing the intermediates required to
# complete the computation in the second stage. It also produces a function of
# type `(r, a2) -> b2` which accepts the residual values as well as the
# remaining inputs and produces the remaining outputs.
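#
# To make that signature concrete, here is a hand-written (and hypothetical)
# partial evaluation of `f = lambda a1, a2: (a1 * 2., a1 + a2)`, where the
# first output depends only on the known input `a1`:

# +
def partial_eval_f(a1):
  b1 = a1 * 2.           # known output, computed immediately
  res = a1               # residual saved for the second stage
  def f2(res, a2):       # the delayed computation, of type (r, a2) -> b2
    return res + a2
  return b1, res, f2
# -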
#
# We like to think of partial evaluation as "unzipping" one computation into
# two. For example, consider this jaxpr:
@ -1845,7 +1905,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# ```
# A jaxpr for the JVP would look like:
# ```
# { lambda a:float64[] b:float64[] .
#   let c:float64[] = sin a
#       d:float64[] = cos a
#       e:float64[] = mul d b
@ -2073,7 +2133,7 @@ def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe
def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []

# + tags=["hide-input"]
def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)
@ -2125,6 +2185,10 @@ print(sin_lin(1.), cos(3.))
# To handle `linearize`-of-`jit`, we still need to write a partial evaluation
# rule for `xla_call_p`. Other than tracer bookkeeping, the main task is to
# perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.
#
# There are actually two rules to write: one for trace-time partial evaluation,
# which we'll call `xla_call_partial_eval`, and one for partial evaluation of
# jaxprs, which we'll call `xla_call_peval_eqn`.

# +
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
@ -2142,18 +2206,17 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
                       instantiate: Optional[List[bool]] = None,
                       ) -> Tuple[Jaxpr, Jaxpr, List[bool], int]:
  env: Dict[Var, bool] = {}
  residuals = set()

  def read(v: Atom) -> bool:
    return type(v) is Var and env[v]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk
@ -2178,6 +2241,11 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
@ -2209,7 +2277,7 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
                       ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
@ -2393,10 +2461,11 @@ transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...]
                    ) -> Tuple[Jaxpr, List[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts
# -

@ -2477,3 +2546,411 @@ _, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
# -

# ## Part 5: the control flow primitives `cond` and `while_loop`
#
# Next we'll add higher-order primitives for staged-out control flow. These
# resemble `jit` from Part 3, another higher-order primitive, but differ in that
# they are parameterized by multiple callables rather than just one.

# ### Adding `cond`
#
# We introduce a `cond` primitive to represent conditional application of one
# function or another inside a jaxpr. We write the type of `cond` as
# `Bool -> (a -> b) -> (a -> b) -> a -> b`. In words, `cond` takes a boolean
# representing the predicate and two functions of equal types. Depending on the
# value of the predicate, it applies one function or the other to its final
# argument.
#
# In Python, we represent it as a function which itself takes two functions as
# arguments. As with `jit`, the first step is to call `make_jaxpr` on its
# callable arguments to turn them into jaxprs:

# +
def cond(pred, true_fn, false_fn, *operands):
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_: raise TypeError
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')

def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> Tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2

def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
# -

# We require `true_jaxpr` and `false_jaxpr` to have the same type, but because
# they might close over different constants (and because jaxprs can only
# represent closed terms, i.e. can't have free variables and are instead
# closure-converted) we need to use the helper `_join_jaxpr_consts` to make
# consistent the input binder lists of the two jaxprs. (To be more economical we
# could try to identify pairs of constants with the same shapes, but instead we
# just concatenate the lists of constants.)
#
# Next we can turn to adding interpreter rules for `cond`. Its evaluation rule
# is simple:

def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  if pred:
    return eval_jaxpr(true_jaxpr, operands)
  else:
    return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl

out = cond(True, lambda: 3, lambda: 4)
print(out)
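
# As one more added example (not in the original text), this time passing
# `operands`, which go to whichever branch gets applied:

out = cond(False, lambda x: x * 2., lambda x: x + 1., 3.)
print(out)  # 4.0, since the false branch receives 3.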

# For its JVP and vmap rules, we only need to call the same `jvp_jaxpr` and
# `vmap_jaxpr` utilities we created for `jit`, followed by another pass of
# `_join_jaxpr_consts`:

def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  pred, *primals = primals
  _   , *tangents = tangents
  true_jaxpr , true_consts  = jvp_jaxpr(true_jaxpr)
  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule

out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)

def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  pred    , *vals_in = vals_in
  pred_dim, *dims_in = dims_in
  if pred_dim is not not_mapped: raise NotImplementedError  # TODO
  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule

xs = np.array([1., 2., 3.])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)

# Notice that we're not currently supporting the case where the predicate value
# itself is batched. In mainline JAX, we handle this case by transforming the
# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).
# That transformation is semantically correct so long as `true_fun` and
# `false_fun` do not involve any side-effecting primitives.
#
# Another thing not represented here, but present in mainline JAX, is that
# applying transformations to two jaxprs of equal type might result in jaxprs of
# different types. For example, applying the mainline JAX version of
# `vmap_jaxpr` to the identity-function jaxpr
#
# ```
# { lambda a:float32[] .
#   let
#   in ( a ) }
# ```
#
# would result in a jaxpr with a batched output, of type
# `[float32[10]] -> [float32[10]]` if the batch size were 10, while applying it
# to the zero-function jaxpr
#
# ```
# { lambda a:float32[] .
#   let
#   in ( 0. ) }
# ```
#
# would result in a jaxpr with an unbatched output, of type
# `[float32[10]] -> [float32[]]`. This is an optimization, aimed at not batching
# values unnecessarily. But it means that in `cond` we'd need an extra step of
# joining the two transformed jaxprs to have consistent output types. We don't
# need this step here because we chose `vmap_jaxpr` always to batch all outputs
# over the leading axis.

# Next we can turn to abstract evaluation and XLA lowering rules:

# +
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval

def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
  del in_avals  # Unused.
  pred, *in_vals = in_vals
  flat_vals, in_tree = tree_flatten(in_vals)
  operand = xops.Tuple(c, flat_vals)
  operand_shape = c.get_shape(operand)

  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
    c = xb.make_computation_builder(name)
    operand = xb.parameter(c, 0, operand_shape)
    operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
    outs = jaxpr_subcomp(c, jaxpr, operands)
    return c.build(xops.Tuple(c, outs))

  true_comp = make_comp('true_fn', true_jaxpr)
  false_comp = make_comp('false_fn', false_jaxpr)

  int_etype = xc.dtype_to_etype(np.dtype('int32'))
  out = xops.Conditional(xops.ConvertElementType(pred, int_etype),
                         [false_comp, true_comp], [operand] * 2)
  return destructure_tuple(c, out)
xla_translations[cond_p] = cond_translation
# -

out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
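
# As an added check, `jit` and `cond` also compose when the branches close over
# values from the enclosing `jit`-traced function:

out = jit(lambda x: cond(True, lambda: x + 1., lambda: x))(3.)
print(out)  # 4.0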

# Finally, to support reverse-mode automatic differentiation, we need partial
# evaluation and transposition rules. For partial evaluation, we need to
# introduce another jaxpr-munging utility, `_join_jaxpr_res`, to handle the fact
# that applying partial evaluation to `true_fun` and `false_fun` will in general
# result in distinct residuals. We use `_join_jaxpr_res` to make the output
# types of the transformed jaxprs consistent (while `_join_jaxpr_consts` dealt
# with input types).

# +
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]

  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs

  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval

def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: List[bool]
                       ) -> Tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, List[bool], int]:
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, t_out_uks, f_out_uks)

  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)

  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres

  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res

def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> Tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
  assert out_types1 == out_types2
  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
  return new_jaxpr1, new_jaxpr2
# -

_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)

def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
                   ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
  pred_unk, *unks_in = unks_in
  assert not pred_unk
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn

_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)

# Transposition is a fairly straightforward application of `transpose_jaxpr`:

def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  undef_primals = tuple([type(x) is UndefPrimal for x in invals])
  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  res = [x for x in invals if type(x) is not UndefPrimal]
  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  outs = iter(outs)
  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule

out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
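
# And, as an added check, the same works when the predicate selects the other
# branch:

out = grad(lambda x: cond(False, lambda: x * x, lambda: 2. * x))(3.)
print(out)  # 2.0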


# ### Adding `while_loop`
#
# Next we'll add a primitive for looping behavior in a jaxpr. We'll use
# `while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
# function-valued argument represents the loop condition, the second represents
# the loop body, and the final argument is the initial value of the carry.
#
# After `cond`, adding `while_loop` is not so different:

def while_loop(cond_fn, body_fn, init_val):
  init_val, in_tree = tree_flatten(init_val)
  avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
  cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
  body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
  cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
      cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
  if cond_tree != tree_flatten(True)[1]: raise TypeError
  if in_tree != in_tree_: raise TypeError
  outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
              cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return tree_unflatten(in_tree, outs)
while_loop_p = Primitive('while_loop')

def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
  consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
  while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
    carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
  return carry
impl_rules[while_loop_p] = while_loop_impl

def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
  return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)

out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
print(out)
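
# Here's one more added example: the loop functions can close over values from
# the enclosing scope, which `make_jaxpr` handles for us (hoisting them as
# constants, or inlining scalar literals):

n = 3.
out = while_loop(lambda x: x < 10., lambda x: x * n, 1.)
print(out)  # 27.0: the carry goes 1. -> 3. -> 9. -> 27.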

# Notice the convention that `args = [*consts, *carry]`.
#
# The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
# convention that all the binders for tangent values are appended after all the
# binders for primal values, like `args = [*primals, *tangents]`. But that's in
# tension with our `while_loop` convention that the carry binders come after all
# the constant binders, i.e. that `args = [*consts, *carry]`, because both the
# constants and the carries can have their own tangents. For this reason, we
# introduce the `_loop_jvp_binders` helper to rearrange binders as needed.

# +
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
  num_consts = _loop_num_consts(body_jaxpr)
  body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
  cond_jaxpr, body_jaxpr = _loop_jvp_binders(
      cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
  outs = bind(while_loop_p, *body_consts, *primals, *tangents,
              cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[while_loop_p] = while_loop_jvp_rule

def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
                      ) -> Tuple[Jaxpr, Jaxpr]:
  # body binders [c1, c2, x1, c2dot, x1dot] ~~> [c1, c2, c2dot, x1, x1dot]
  jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
  primal_binders, tangent_binders = split_half(binders)
  consts    , carry     = split_list(primal_binders , n2)
  consts_dot, carry_dot = split_list(tangent_binders, n2)
  new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
  new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
  typecheck_jaxpr(new_body_jaxpr)

  # cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]
  assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
  consts, carry = split_list(cond_jaxpr.in_binders, n2)
  new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
  new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)

  return new_cond_jaxpr, new_body_jaxpr
# -

out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
                   (1.,), (1.,))
print(out_tan)

# +
def f(x):
  def cond_fn(i, _):
    return i < 3
  def body_fn(i, x):
    return i + 1, cos(x)
  _, y = while_loop(cond_fn, body_fn, (0, x))
  return y

def g(x):
  return cos(cos(cos(x)))

print(jvp(f, (1.,), (1.,)))
print(jvp(g, (1.,), (1.,)))
# -
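
# The two agree (an added check): with `init_val = (0, x)` the loop runs exactly
# three iterations, so `f` computes the same triple `cos` as `g`:

y1, t1 = jvp(f, (1.,), (1.,))
y2, t2 = jvp(g, (1.,), (1.,))
assert np.allclose(y1, y2) and np.allclose(t1, t2)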

# The vmap rule for `while_loop` presents two cases:
# 1. if the output of `cond_fn` is not batched, then the loop has the same
#    basic structure, just with a batched body;
# 2. but if the output of `cond_fn` is batched, we must represent a batch of
#    loops which might run for different numbers of iterations.
#
# ...Stay tuned for the thrilling conclusion!