add linearize code (needs text)

Matthew Johnson 2021-03-11 10:08:43 -08:00
parent 72a3036b1a
commit 3457696e80
3 changed files with 955 additions and 28 deletions

View File

@ -27,10 +27,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 0
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
"source": [
"import pdb, sys, traceback\n",
"def info(type, value, tb):\n",
" traceback.print_exception(type, value, tb)\n",
" pdb.pm()\n",
"sys.excepthook = info"
]
},
{
"cell_type": "markdown",
@ -322,6 +328,17 @@
" raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def swap(f): return lambda x, y: f(y, x)"
]
},
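{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check (an illustrative example added here, not from the original):\n",
"`swap` flips argument order, which is exactly what the reflected methods\n",
"`__radd__` and `__rmul__` need."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(swap(lambda x, y: x - y)(1., 2.))  # 1.0, i.e. 2. - 1."
]
},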
{
"cell_type": "code",
"execution_count": null,
@ -345,9 +362,9 @@
"\n",
" _neg = staticmethod(neg)\n",
" _add = staticmethod(add)\n",
" _radd = staticmethod(add)\n",
" _radd = staticmethod(swap(add))\n",
" _mul = staticmethod(mul)\n",
" _rmul = staticmethod(mul)\n",
" _rmul = staticmethod(swap(mul))\n",
" _gt = staticmethod(greater)\n",
"\n",
" @staticmethod\n",
@ -407,8 +424,22 @@
"def get_aval(x):\n",
" if isinstance(x, Tracer):\n",
" return x.aval\n",
" elif type(x) in jax_types:\n",
" return ConcreteArray(np.asarray(x))\n",
" else:\n",
" return ConcreteArray(np.asarray(x))"
" raise TypeError(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"jax_types = {bool, int, float,\n",
" np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}"
]
},
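{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (a hypothetical example, added for illustration):\n",
"`get_aval` wraps values of known jax types in a `ConcreteArray`, and now\n",
"raises a `TypeError` for anything else."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(type(get_aval(3.0)).__name__)  # ConcreteArray\n",
"try:\n",
"    get_aval('hello')                # str is not in jax_types\n",
"except TypeError as e:\n",
"    print('TypeError:', e)"
]
},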
{
@ -459,7 +490,7 @@
},
"outputs": [],
"source": [
"from operator import attrgetter"
"import operator as op"
]
},
{
@ -472,7 +503,7 @@
"source": [
"def find_top_trace(xs) -> Trace:\n",
" top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n",
" default=trace_stack[0], key=attrgetter('level'))\n",
" default=trace_stack[0], key=op.attrgetter('level'))\n",
" if dynamic_trace and dynamic_trace.level > top_main.level:\n",
" top_main = dynamic_trace\n",
" return top_main.trace_type(top_main)"
@ -526,6 +557,7 @@
"source": [
"def full_raise(trace: Trace, val: Any) -> Tracer:\n",
" if not isinstance(val, Tracer):\n",
" assert type(val) in jax_types\n",
" return trace.pure(val)\n",
" level = trace.main.level\n",
" if val._trace.main is trace.main:\n",
@ -1457,7 +1489,10 @@
"def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):\n",
" (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n",
" if x_bdim != y_bdim:\n",
" y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
" if x_bdim is not_mapped:\n",
" x = move_batch_axis(axis_size, x_bdim, y_bdim, x)\n",
" else:\n",
" y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
" return [op(x, y)], [x_bdim]\n",
"vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)\n",
"vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)"
@ -1705,6 +1740,9 @@
" eqns: List[JaxprEqn]\n",
" outs: List[Atom]\n",
"\n",
" def __hash__(self): return id(self)\n",
" __eq__ = op.is_\n",
"\n",
"def raise_to_shaped(aval):\n",
" return ShapedArray(aval.shape, aval.dtype)"
]
@ -2693,6 +2731,7 @@
},
"outputs": [],
"source": [
"@lru_cache()\n",
"def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:\n",
" def jvp_traceable(*primals_and_tangents):\n",
" n = len(primals_and_tangents) // 2\n",
@ -2714,13 +2753,14 @@
"source": [
"def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in)\n",
" new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))\n",
" outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n",
" num_consts=len(new_consts))\n",
" return outs, [0] * len(outs)\n",
"vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
"\n",
"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis]\n",
"@lru_cache()\n",
"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis]\n",
" ) -> Tuple[Jaxpr, List[Any]]:\n",
" vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
" in_avals = [unmapped_aval(axis_size, d, v.aval)\n",
@ -2749,6 +2789,7 @@
"source": [
"@jit\n",
"def f(x):\n",
" print('tracing!')\n",
" y = sin(x) * 2.\n",
" z = - y + x\n",
" return z\n",
@ -2758,6 +2799,8 @@
"print(y)\n",
"print(ydot)\n",
"\n",
"y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed\n",
"\n",
"ys = vmap(f, (0,))(np.arange(3.))\n",
"print(ys)"
]
@ -2807,7 +2850,9 @@
" _mul = staticmethod(mul)\n",
" _rmul = staticmethod(mul)\n",
" _gt = staticmethod(greater)\n",
"input_handlers[DeviceArray] = lambda x: x.buf"
"input_handlers[DeviceArray] = lambda x: x.buf\n",
"\n",
"jax_types.add(DeviceArray)"
]
},
{
@ -2827,6 +2872,317 @@
"print(y)\n",
"print(ydot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 4: `linearize` and `vjp` (and `grad`!)\n",
"\n",
"The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve\n",
"jaxprs as well. That's because both involve staging out, or delaying,\n",
"computation.\n",
"\n",
"In the case of `linearize`, we want to stage out the linear part of a `jvp`\n",
"computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n",
"then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where\n",
"```\n",
"y, f_lin = linearize(f, x)\n",
"y_dot = f_lin(x_dot)\n",
"```\n",
"gives the same result for `(y, y_dot)` as\n",
"```\n",
"y, y_dot = jvp(f, (x,), (x_dot,))\n",
"```\n",
"and where the application of `f_lin` does not redo any of the linearization\n",
"work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.\n",
"\n",
"To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:\n",
"we evaluate all the primal values as we trace, but stage the tangent\n",
"computations into a jaxpr."
]
},
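{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a toy sketch of that staging idea (for intuition only; `f_residual` is a\n",
"hypothetical name, not part of the implementation below): with `x` known and\n",
"`y` unknown, we can evaluate the `x`-only work eagerly and return a residual\n",
"function of `y` alone."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 3.0                        # known (\"primal\") input\n",
"x2 = x * 2.0                   # evaluated eagerly, at trace time\n",
"f_residual = lambda y: x2 + y  # the staged-out remainder\n",
"print(f_residual(1.0))         # 7.0, without recomputing x * 2.0"
]
},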
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def split_half(lst):\n",
" n, ragged = divmod(len(lst), 2)\n",
" assert not ragged\n",
" return lst[:n], lst[n:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def linearize_flat(f, *primals_in):\n",
" pvals_in = ([PartialVal.known(x) for x in primals_in] +\n",
" [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n",
"\n",
" def f_jvp(*primals_tangents_in):\n",
" primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n",
" return [*primals_out, *tangents_out]\n",
"\n",
" jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)\n",
" primal_pvals, _ = split_half(pvals_out)\n",
" assert all(pval.is_known for pval in primal_pvals)\n",
" primals_out = [pval.const for pval in primal_pvals]\n",
" f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])\n",
" return primals_out, f_lin\n",
"\n",
"def linearize(f, *primals_in):\n",
" primals_in_flat, in_tree = tree_flatten(primals_in)\n",
" f, out_tree = flatten_fun(f, in_tree)\n",
" primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)\n",
" primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
"\n",
" def f_lin(*tangents_in):\n",
" tangents_in_flat, in_tree2 = tree_flatten(tangents_in)\n",
" if in_tree != in_tree2: raise TypeError\n",
" tangents_out_flat = f_lin_flat(*tangents_in_flat)\n",
" return tree_unflatten(out_tree(), tangents_out_flat)\n",
"\n",
" return primals_out, f_lin\n",
"\n",
"def vspace(aval: ShapedArray) -> ShapedArray:\n",
" return raise_to_shaped(aval)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"class PartialVal(NamedTuple):\n",
" aval: ShapedArray\n",
" const: Optional[Any]\n",
"\n",
" @classmethod\n",
" def known(cls, val: Any):\n",
" return PartialVal(get_aval(val), val)\n",
"\n",
" @classmethod\n",
" def unknown(cls, aval: ShapedArray):\n",
" return PartialVal(aval, None)\n",
"\n",
" is_known = property(lambda self: self.const is not None)\n",
" is_unknown = property(lambda self: self.const is None)\n",
"\n",
"def partial_eval_flat(f, pvals_in: List[PartialVal]):\n",
" with new_main(PartialEvalTrace) as main:\n",
" trace = PartialEvalTrace(main)\n",
" tracers_in = [trace.new_arg(pval) for pval in pvals_in]\n",
" outs = f(*tracers_in)\n",
" tracers_out = [full_raise(trace, out) for out in outs]\n",
" jaxpr, consts = tracers_to_jaxpr(tracers_in, tracers_out)\n",
" pvals_out = [t.pval for t in tracers_out]\n",
" return jaxpr, pvals_out, consts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"from weakref import ref, ReferenceType\n",
"\n",
"class LambdaBindingRecipe(NamedTuple): pass\n",
"\n",
"class ConstRecipe(NamedTuple):\n",
" val: Any\n",
"\n",
"class JaxprEqnRecipe:\n",
" prim: Primitive\n",
" tracers_in: List['PartialEvalTracer']\n",
" params: Dict[str, Any]\n",
" avals_out: List[ShapedArray]\n",
" tracer_refs_out: List['ReferenceType[PartialEvalTracer]']\n",
"\n",
" def __init__(self, prim, tracers_in, params, avals_out, tracer_refs_out):\n",
" self.prim = prim\n",
" self.tracers_in = tracers_in\n",
" self.params = params\n",
" self.avals_out = avals_out\n",
" self.tracer_refs_out = tracer_refs_out\n",
"\n",
"JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]\n",
"\n",
"class PartialEvalTracer(Tracer):\n",
" pval: PartialVal\n",
" recipe: JaxprRecipe\n",
"\n",
" def __init__(self, trace, pval, recipe):\n",
" self._trace = trace\n",
" self.pval = pval\n",
" self.recipe = recipe\n",
"\n",
" @property\n",
" def aval(self):\n",
" return self.pval.aval\n",
"\n",
" def full_lower(self):\n",
" if self.pval.is_known:\n",
" return full_lower(self.pval.const)\n",
" return self\n",
"\n",
"class PartialEvalTrace(Trace):\n",
" def new_arg(self, pval: PartialVal) -> Any:\n",
" return PartialEvalTracer(self, pval, LambdaBindingRecipe())\n",
"\n",
" def lift(self, val: Any) -> PartialEvalTracer:\n",
" return PartialEvalTracer(self, PartialVal.known(val), None)\n",
" pure = lift\n",
"\n",
" def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:\n",
" if tracer.pval.is_unknown:\n",
" return tracer\n",
" else:\n",
" pval = PartialVal.unknown(raise_to_shaped(tracer.aval))\n",
" return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))\n",
"\n",
" def process_primitive(self, primitive, tracers, params):\n",
" if all(t.pval.is_known for t in tracers):\n",
" return bind(primitive, *map(full_lower, tracers), **params)\n",
" tracers_in = [self.instantiate_const(t) for t in tracers]\n",
" avals_in = [t.aval for t in tracers_in]\n",
" avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
" tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)\n",
" for aval in avals_out]\n",
" eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,\n",
" map(ref, tracers_out))\n",
" for t in tracers_out: t.recipe = eqn\n",
" return tracers_out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],\n",
" tracers_out: List[PartialEvalTracer]):\n",
" tracers_in = [t for t in tracers_in if t.pval.is_unknown]\n",
" tracers_out = [t for t in tracers_out if t.pval.is_unknown]\n",
"\n",
" tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}\n",
" constvar_to_val = {}\n",
" constid_to_var = {}\n",
" processed_eqns = set()\n",
" eqns = []\n",
" for t in toposort(tracers_out, tracer_parents):\n",
" if isinstance(t.recipe, LambdaBindingRecipe):\n",
" assert id(t) in set(map(id, tracers_in))\n",
" elif isinstance(t.recipe, ConstRecipe):\n",
" val = t.recipe.val\n",
" var = constid_to_var.get(id(val))\n",
" if var is None:\n",
" aval = raise_to_shaped(get_aval(val))\n",
" var = tracer_to_var[id(t)] = constid_to_var[id(val)] = Var(aval)\n",
" constvar_to_val[var] = val\n",
" elif isinstance(t.recipe, JaxprEqnRecipe):\n",
" if id(t.recipe) not in processed_eqns:\n",
" eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))\n",
" processed_eqns.add(id(t.recipe))\n",
" else:\n",
" raise TypeError(t.recipe)\n",
"\n",
" constvars, constvals = unzip2(constvar_to_val.items())\n",
" in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]\n",
" out_vars = [tracer_to_var[id(t)] for t in tracers_out]\n",
" jaxpr = Jaxpr(in_binders, eqns, out_vars)\n",
" typecheck_jaxpr(jaxpr)\n",
" return jaxpr, constvals\n",
"\n",
"def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe\n",
" ) -> JaxprEqn:\n",
" inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]\n",
" out_binders = [Var(aval) for aval in recipe.avals_out]\n",
" for t_ref, var in zip(recipe.tracer_refs_out, out_binders):\n",
" if t_ref() is not None: tracer_to_var[id(t_ref())] = var\n",
" return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)\n",
"\n",
"def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:\n",
" return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
" if not out_nodes: return []\n",
" out_nodes = remove_duplicates(out_nodes)\n",
"\n",
" child_counts = {}\n",
" stack = list(out_nodes)\n",
" while stack:\n",
" node = stack.pop()\n",
" if id(node) in child_counts:\n",
" child_counts[id(node)] += 1\n",
" else:\n",
" child_counts[id(node)] = 1\n",
" stack.extend(parents(node))\n",
" for node in out_nodes:\n",
" child_counts[id(node)] -= 1\n",
"\n",
" sorted_nodes = []\n",
" childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]\n",
" while childless_nodes:\n",
" node = childless_nodes.pop()\n",
" sorted_nodes.append(node)\n",
" for parent in parents(node):\n",
" if child_counts[id(parent)] == 1:\n",
" childless_nodes.append(parent)\n",
" else:\n",
" child_counts[id(parent)] -= 1\n",
"\n",
" sorted_nodes = sorted_nodes[::-1]\n",
" check_toposort(sorted_nodes, parents)\n",
" return sorted_nodes\n",
"\n",
"def remove_duplicates(lst):\n",
" seen = set()\n",
" return [x for x in lst if id(x) not in seen and not seen.add(id(x))]\n",
"\n",
"def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
" seen = set()\n",
" for node in nodes:\n",
" assert all(id(parent) in seen for parent in parents(node))\n",
" seen.add(id(node))"
]
},
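{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny sanity check of `toposort` on a diamond-shaped graph (hypothetical\n",
"nodes, not jaxpr tracers): `d` depends on `b` and `c`, both of which depend\n",
"on `a`, so parents must come out before children."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"parents_map = {'a': [], 'b': ['a'], 'c': ['a'], 'd': ['b', 'c']}\n",
"print(toposort(['d'], parents_map.__getitem__))  # ['a', 'b', 'c', 'd']"
]
},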
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y, sin_lin = linearize(sin, 3.)\n",
"print(y, sin(3.))\n",
"print(sin_lin(1.), cos(3.))"
]
}
],
"metadata": {

View File

@ -33,7 +33,11 @@ limitations under the License.
```
```{code-cell} ipython3
import pdb, sys, traceback
def info(type, value, tb):
traceback.print_exception(type, value, tb)
pdb.pm()
sys.excepthook = info
```
# Autodidax: JAX core from scratch
@ -241,6 +245,10 @@ class Tracer:
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
```
```{code-cell} ipython3
def swap(f): return lambda x, y: f(y, x)
```
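A quick check (an illustrative example added here, not from the original):
`swap` flips argument order, which is exactly what the reflected methods
`__radd__` and `__rmul__` need.
```{code-cell} ipython3
print(swap(lambda x, y: x - y)(1., 2.))  # 1.0, i.e. 2. - 1.
```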
```{code-cell} ipython3
class ShapedArray:
array_abstraction_level = 1
@ -257,9 +265,9 @@ class ShapedArray:
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(add)
_radd = staticmethod(swap(add))
_mul = staticmethod(mul)
_rmul = staticmethod(mul)
_rmul = staticmethod(swap(mul))
_gt = staticmethod(greater)
@staticmethod
@ -304,8 +312,15 @@ class ConcreteArray(ShapedArray):
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
elif type(x) in jax_types:
return ConcreteArray(np.asarray(x))
else:
raise TypeError(x)
```
```{code-cell} ipython3
jax_types = {bool, int, float,
np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
```
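A quick sanity check (a hypothetical example, added for illustration):
`get_aval` wraps values of known jax types in a `ConcreteArray`, and now
raises a `TypeError` for anything else.
```{code-cell} ipython3
print(type(get_aval(3.0)).__name__)  # ConcreteArray
try:
    get_aval('hello')                # str is not in jax_types
except TypeError as e:
    print('TypeError:', e)
```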
Notice that we actually have two `AbstractValue`s for arrays, representing
@ -332,13 +347,13 @@ top trace's `Tracer` instances, and the call to `full_lower` is an optional
optimization so that we unbox values out of `Tracer`s as much as possible.
```{code-cell} ipython3
from operator import attrgetter
import operator as op
```
```{code-cell} ipython3
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=trace_stack[0], key=attrgetter('level'))
default=trace_stack[0], key=op.attrgetter('level'))
if dynamic_trace and dynamic_trace.level > top_main.level:
top_main = dynamic_trace
return top_main.trace_type(top_main)
@ -373,6 +388,7 @@ def full_lower(val: Any):
```{code-cell} ipython3
def full_raise(trace: Trace, val: Any) -> Tracer:
if not isinstance(val, Tracer):
assert type(val) in jax_types
return trace.pure(val)
level = trace.main.level
if val._trace.main is trace.main:
@ -875,7 +891,10 @@ from functools import partial
def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
if x_bdim != y_bdim:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
if x_bdim is not_mapped:
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
else:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
@ -1058,6 +1077,9 @@ class Jaxpr(NamedTuple):
eqns: List[JaxprEqn]
outs: List[Atom]
def __hash__(self): return id(self)
__eq__ = op.is_
def raise_to_shaped(aval):
return ShapedArray(aval.shape, aval.dtype)
```
@ -1743,6 +1765,7 @@ jvp_rules[xla_call_p] = xla_call_jvp_rule
```
```{code-cell} ipython3
@lru_cache()
def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2
@ -1757,13 +1780,14 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
```{code-cell} ipython3
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
del num_consts # Unused.
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in)
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
num_consts=len(new_consts))
return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis]
@lru_cache()
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis, ...]
) -> Tuple[Jaxpr, List[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
in_avals = [unmapped_aval(axis_size, d, v.aval)
@ -1784,6 +1808,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
```{code-cell} ipython3
@jit
def f(x):
print('tracing!')
y = sin(x) * 2.
z = - y + x
return z
@ -1793,6 +1818,8 @@ y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
```
@ -1831,6 +1858,8 @@ class DeviceArray:
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
input_handlers[DeviceArray] = lambda x: x.buf
jax_types.add(DeviceArray)
```
```{code-cell} ipython3
@ -1845,3 +1874,262 @@ y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
```
## Part 4: `linearize` and `vjp` (and `grad`!)
The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve
jaxprs as well. That's because both involve staging out, or delaying,
computation.
In the case of `linearize`, we want to stage out the linear part of a `jvp`
computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where
```
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
```
gives the same result for `(y, y_dot)` as
```
y, y_dot = jvp(f, (x,), (x_dot,))
```
and where the application of `f_lin` does not redo any of the linearization
work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
we evaluate all the primal values as we trace, but stage the tangent
computations into a jaxpr.
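As a toy sketch of that staging idea (for intuition only; `f_residual` is a
hypothetical name, not part of the implementation below): with `x` known and
`y` unknown, we can evaluate the `x`-only work eagerly and return a residual
function of `y` alone.
```{code-cell} ipython3
x = 3.0                        # known ("primal") input
x2 = x * 2.0                   # evaluated eagerly, at trace time
f_residual = lambda y: x2 + y  # the staged-out remainder
print(f_residual(1.0))         # 7.0, without recomputing x * 2.0
```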
```{code-cell} ipython3
def split_half(lst):
n, ragged = divmod(len(lst), 2)
assert not ragged
return lst[:n], lst[n:]
```
```{code-cell} ipython3
def linearize_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
return primals_out, f_lin
def linearize(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_lin(*tangents_in):
tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
if in_tree != in_tree2: raise TypeError
tangents_out_flat = f_lin_flat(*tangents_in_flat)
return tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
return raise_to_shaped(aval)
```
```{code-cell} ipython3
class PartialVal(NamedTuple):
aval: ShapedArray
const: Optional[Any]
@classmethod
def known(cls, val: Any):
return PartialVal(get_aval(val), val)
@classmethod
def unknown(cls, aval: ShapedArray):
return PartialVal(aval, None)
is_known = property(lambda self: self.const is not None)
is_unknown = property(lambda self: self.const is None)
def partial_eval_flat(f, pvals_in: List[PartialVal]):
with new_main(PartialEvalTrace) as main:
trace = PartialEvalTrace(main)
tracers_in = [trace.new_arg(pval) for pval in pvals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = tracers_to_jaxpr(tracers_in, tracers_out)
pvals_out = [t.pval for t in tracers_out]
return jaxpr, pvals_out, consts
```
```{code-cell} ipython3
from weakref import ref, ReferenceType
class LambdaBindingRecipe(NamedTuple): pass
class ConstRecipe(NamedTuple):
val: Any
class JaxprEqnRecipe:
prim: Primitive
tracers_in: List['PartialEvalTracer']
params: Dict[str, Any]
avals_out: List[ShapedArray]
tracer_refs_out: List['ReferenceType[PartialEvalTracer]']
def __init__(self, prim, tracers_in, params, avals_out, tracer_refs_out):
self.prim = prim
self.tracers_in = tracers_in
self.params = params
self.avals_out = avals_out
self.tracer_refs_out = tracer_refs_out
JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
pval: PartialVal
recipe: JaxprRecipe
def __init__(self, trace, pval, recipe):
self._trace = trace
self.pval = pval
self.recipe = recipe
@property
def aval(self):
return self.pval.aval
def full_lower(self):
if self.pval.is_known:
return full_lower(self.pval.const)
return self
class PartialEvalTrace(Trace):
def new_arg(self, pval: PartialVal) -> Any:
return PartialEvalTracer(self, pval, LambdaBindingRecipe())
def lift(self, val: Any) -> PartialEvalTracer:
return PartialEvalTracer(self, PartialVal.known(val), None)
pure = lift
def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
if tracer.pval.is_unknown:
return tracer
else:
pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
def process_primitive(self, primitive, tracers, params):
if all(t.pval.is_known for t in tracers):
return bind(primitive, *map(full_lower, tracers), **params)
tracers_in = [self.instantiate_const(t) for t in tracers]
avals_in = [t.aval for t in tracers_in]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
for aval in avals_out]
eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
map(ref, tracers_out))
for t in tracers_out: t.recipe = eqn
return tracers_out
```
```{code-cell} ipython3
def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
tracers_out: List[PartialEvalTracer]):
tracers_in = [t for t in tracers_in if t.pval.is_unknown]
tracers_out = [t for t in tracers_out if t.pval.is_unknown]
tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
constvar_to_val = {}
constid_to_var = {}
processed_eqns = set()
eqns = []
for t in toposort(tracers_out, tracer_parents):
if isinstance(t.recipe, LambdaBindingRecipe):
assert id(t) in set(map(id, tracers_in))
elif isinstance(t.recipe, ConstRecipe):
val = t.recipe.val
var = constid_to_var.get(id(val))
if var is None:
aval = raise_to_shaped(get_aval(val))
var = constid_to_var[id(val)] = Var(aval)
constvar_to_val[var] = val
tracer_to_var[id(t)] = var
elif isinstance(t.recipe, JaxprEqnRecipe):
if id(t.recipe) not in processed_eqns:
eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
processed_eqns.add(id(t.recipe))
else:
raise TypeError(t.recipe)
constvars, constvals = unzip2(constvar_to_val.items())
in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
out_vars = [tracer_to_var[id(t)] for t in tracers_out]
jaxpr = Jaxpr(in_binders, eqns, out_vars)
typecheck_jaxpr(jaxpr)
return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe
) -> JaxprEqn:
inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
out_binders = [Var(aval) for aval in recipe.avals_out]
for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
if t_ref() is not None: tracer_to_var[id(t_ref())] = var
return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
```
```{code-cell} ipython3
def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
if not out_nodes: return []
out_nodes = remove_duplicates(out_nodes)
child_counts = {}
stack = list(out_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(node)] = 1
stack.extend(parents(node))
for node in out_nodes:
child_counts[id(node)] -= 1
sorted_nodes = []
childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in parents(node):
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
check_toposort(sorted_nodes, parents)
return sorted_nodes
def remove_duplicates(lst):
seen = set()
return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):
seen = set()
for node in nodes:
assert all(id(parent) in seen for parent in parents(node))
seen.add(id(node))
```
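A tiny sanity check of `toposort` on a diamond-shaped graph (hypothetical
nodes, not jaxpr tracers): `d` depends on `b` and `c`, both of which depend
on `a`, so parents must come out before children.
```{code-cell} ipython3
parents_map = {'a': [], 'b': ['a'], 'c': ['a'], 'd': ['b', 'c']}
print(toposort(['d'], parents_map.__getitem__))  # ['a', 'b', 'c', 'd']
```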
```{code-cell} ipython3
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
```

View File

@ -26,6 +26,12 @@
# name: python3
# ---
import pdb, sys, traceback
def info(type, value, tb):
traceback.print_exception(type, value, tb)
pdb.pm()
sys.excepthook = info
# # Autodidax: JAX core from scratch
#
@ -215,6 +221,8 @@ class Tracer:
except AttributeError:
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def swap(f): return lambda x, y: f(y, x)
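# A quick check (an illustrative example added here, not from the original):
# `swap` flips argument order, which is exactly what the reflected methods
# `__radd__` and `__rmul__` need.
print(swap(lambda x, y: x - y)(1., 2.))  # 1.0, i.e. 2. - 1.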
class ShapedArray:
array_abstraction_level = 1
shape: Tuple[int]
@ -230,9 +238,9 @@ class ShapedArray:
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(add)
_radd = staticmethod(swap(add))
_mul = staticmethod(mul)
_rmul = staticmethod(mul)
_rmul = staticmethod(swap(mul))
_gt = staticmethod(greater)
@staticmethod
@ -273,8 +281,13 @@ class ConcreteArray(ShapedArray):
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
elif type(x) in jax_types:
return ConcreteArray(np.asarray(x))
else:
raise TypeError(x)
jax_types = {bool, int, float,
np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
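# A quick sanity check (a hypothetical example, added for illustration):
# `get_aval` wraps values of known jax types in a `ConcreteArray`, and now
# raises a `TypeError` for anything else.
print(type(get_aval(3.0)).__name__)  # ConcreteArray
try:
    get_aval('hello')                # str is not in jax_types
except TypeError as e:
    print('TypeError:', e)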
# Notice that we actually have two `AbstractValue`s for arrays, representing
# different levels of abstraction. A `ShapedArray` represents the set of all
@ -297,11 +310,11 @@ def bind(prim, *args, **params):
# top trace's `Tracer` instances, and the call to `full_lower` is an optional
# optimization so that we unbox values out of `Tracer`s as much as possible.
from operator import attrgetter
import operator as op
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=trace_stack[0], key=attrgetter('level'))
default=trace_stack[0], key=op.attrgetter('level'))
if dynamic_trace and dynamic_trace.level > top_main.level:
top_main = dynamic_trace
return top_main.trace_type(top_main)
@ -332,6 +345,7 @@ def full_lower(val: Any):
def full_raise(trace: Trace, val: Any) -> Tracer:
if not isinstance(val, Tracer):
assert type(val) in jax_types
return trace.pure(val)
level = trace.main.level
if val._trace.main is trace.main:
@ -749,7 +763,10 @@ from functools import partial
def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
if x_bdim != y_bdim:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
if x_bdim is not_mapped:
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
else:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
@ -919,6 +936,9 @@ class Jaxpr(NamedTuple):
eqns: List[JaxprEqn]
outs: List[Atom]
def __hash__(self): return id(self)
__eq__ = op.is_
def raise_to_shaped(aval):
return ShapedArray(aval.shape, aval.dtype)
# -
@ -1551,6 +1571,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache()
def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2
@ -1564,13 +1585,14 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
# +
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
del num_consts # Unused.
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in)
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
num_consts=len(new_consts))
return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis]
@lru_cache()
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis, ...]
) -> Tuple[Jaxpr, List[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
in_avals = [unmapped_aval(axis_size, d, v.aval)
@ -1591,6 +1613,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
# +
@jit
def f(x):
print('tracing!')
y = sin(x) * 2.
z = - y + x
return z
@ -1600,6 +1623,8 @@ y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
# -
@ -1640,6 +1665,8 @@ class DeviceArray:
_gt = staticmethod(greater)
input_handlers[DeviceArray] = lambda x: x.buf
jax_types.add(DeviceArray)
# +
@jit
def f(x):
@ -1651,3 +1678,259 @@ x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
# -
# ## Part 4: `linearize` and `vjp` (and `grad`!)
#
# The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve
# jaxprs as well. That's because both involve staging out, or delaying,
# computation.
#
# In the case of `linearize`, we want to stage out the linear part of a `jvp`
# computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
# then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where
# ```
# y, f_lin = linearize(f, x)
# y_dot = f_lin(x_dot)
# ```
# gives the same result for `(y, y_dot)` as
# ```
# y, y_dot = jvp(f, (x,), (x_dot,))
# ```
# and where the application of `f_lin` does not redo any of the linearization
# work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
#
# To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
# we evaluate all the primal values as we trace, but stage the tangent
# computations into a jaxpr.
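# As a toy sketch of that staging idea (for intuition only; `f_residual` is a
# hypothetical name, not part of the implementation below): with `x` known and
# `y` unknown, we can evaluate the `x`-only work eagerly and return a residual
# function of `y` alone.
x = 3.0                        # known ("primal") input
x2 = x * 2.0                   # evaluated eagerly, at trace time
f_residual = lambda y: x2 + y  # the staged-out remainder
print(f_residual(1.0))         # 7.0, without recomputing x * 2.0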
def split_half(lst):
n, ragged = divmod(len(lst), 2)
assert not ragged
return lst[:n], lst[n:]
# +
def linearize_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
return primals_out, f_lin
def linearize(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_lin(*tangents_in):
tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
if in_tree != in_tree2: raise TypeError
tangents_out_flat = f_lin_flat(*tangents_in_flat)
return tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
return raise_to_shaped(aval)
# +
class PartialVal(NamedTuple):
aval: ShapedArray
const: Optional[Any]
@classmethod
def known(cls, val: Any):
return PartialVal(get_aval(val), val)
@classmethod
def unknown(cls, aval: ShapedArray):
return PartialVal(aval, None)
is_known = property(lambda self: self.const is not None)
is_unknown = property(lambda self: self.const is None)
def partial_eval_flat(f, pvals_in: List[PartialVal]):
with new_main(PartialEvalTrace) as main:
trace = PartialEvalTrace(main)
tracers_in = [trace.new_arg(pval) for pval in pvals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = tracers_to_jaxpr(tracers_in, tracers_out)
pvals_out = [t.pval for t in tracers_out]
return jaxpr, pvals_out, consts
# +
from weakref import ref, ReferenceType
class LambdaBindingRecipe(NamedTuple): pass
class ConstRecipe(NamedTuple):
val: Any
class JaxprEqnRecipe:
prim: Primitive
tracers_in: List['PartialEvalTracer']
params: Dict[str, Any]
avals_out: List[ShapedArray]
tracer_refs_out: List['ReferenceType[PartialEvalTracer]']
def __init__(self, prim, tracers_in, params, avals_out, tracer_refs_out):
self.prim = prim
self.tracers_in = tracers_in
self.params = params
self.avals_out = avals_out
self.tracer_refs_out = tracer_refs_out
JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
pval: PartialVal
recipe: JaxprRecipe
def __init__(self, trace, pval, recipe):
self._trace = trace
self.pval = pval
self.recipe = recipe
@property
def aval(self):
return self.pval.aval
def full_lower(self):
if self.pval.is_known:
return full_lower(self.pval.const)
return self
class PartialEvalTrace(Trace):
def new_arg(self, pval: PartialVal) -> Any:
return PartialEvalTracer(self, pval, LambdaBindingRecipe())
def lift(self, val: Any) -> PartialEvalTracer:
return PartialEvalTracer(self, PartialVal.known(val), None)
pure = lift
def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
if tracer.pval.is_unknown:
return tracer
else:
pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
def process_primitive(self, primitive, tracers, params):
if all(t.pval.is_known for t in tracers):
return bind(primitive, *map(full_lower, tracers), **params)
tracers_in = [self.instantiate_const(t) for t in tracers]
avals_in = [t.aval for t in tracers_in]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
for aval in avals_out]
eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
map(ref, tracers_out))
for t in tracers_out: t.recipe = eqn
return tracers_out
# +
def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
tracers_out: List[PartialEvalTracer]):
tracers_in = [t for t in tracers_in if t.pval.is_unknown]
tracers_out = [t for t in tracers_out if t.pval.is_unknown]
tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
constvar_to_val = {}
constid_to_var = {}
processed_eqns = set()
eqns = []
for t in toposort(tracers_out, tracer_parents):
if isinstance(t.recipe, LambdaBindingRecipe):
assert id(t) in set(map(id, tracers_in))
elif isinstance(t.recipe, ConstRecipe):
val = t.recipe.val
var = constid_to_var.get(id(val))
if var is None:
aval = raise_to_shaped(get_aval(val))
var = constid_to_var[id(val)] = Var(aval)
constvar_to_val[var] = val
tracer_to_var[id(t)] = var
elif isinstance(t.recipe, JaxprEqnRecipe):
if id(t.recipe) not in processed_eqns:
eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
processed_eqns.add(id(t.recipe))
else:
raise TypeError(t.recipe)
constvars, constvals = unzip2(constvar_to_val.items())
in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
out_vars = [tracer_to_var[id(t)] for t in tracers_out]
jaxpr = Jaxpr(in_binders, eqns, out_vars)
typecheck_jaxpr(jaxpr)
return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe
) -> JaxprEqn:
inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
out_binders = [Var(aval) for aval in recipe.avals_out]
for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
if t_ref() is not None: tracer_to_var[id(t_ref())] = var
return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
# +
def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
if not out_nodes: return []
out_nodes = remove_duplicates(out_nodes)
child_counts = {}
stack = list(out_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(node)] = 1
stack.extend(parents(node))
for node in out_nodes:
child_counts[id(node)] -= 1
sorted_nodes = []
childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in parents(node):
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
check_toposort(sorted_nodes, parents)
return sorted_nodes
def remove_duplicates(lst):
seen = set()
return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):
seen = set()
for node in nodes:
assert all(id(parent) in seen for parent in parents(node))
seen.add(id(node))
# -
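# A tiny sanity check of `toposort` on a diamond-shaped graph (hypothetical
# nodes, not jaxpr tracers): `d` depends on `b` and `c`, both of which depend
# on `a`, so parents must come out before children.
parents_map = {'a': [], 'b': ['a'], 'c': ['a'], 'd': ['b', 'c']}
print(toposort(['d'], parents_map.__getitem__))  # ['a', 'b', 'c', 'd']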
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))