From 3595af2ed2d258fe1cf6fd9a9bd805cad0858f01 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 12 Mar 2021 19:42:14 -0800 Subject: [PATCH] add linearize, vjp, grad. fix bugs. --- docs/autodidax.ipynb | 1657 ++++++++++++++++++++++++------------------ docs/autodidax.md | 915 ++++++++++++++++------- docs/autodidax.py | 730 ++++++++++++++++--- setup.cfg | 6 +- 4 files changed, 2255 insertions(+), 1053 deletions(-) diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 1721dac33..555608c5f 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2,7 +2,6 @@ "cells": [ { "cell_type": "raw", - "id": "surrounded-agent", "metadata": {}, "source": [ "---\n", @@ -27,16 +26,10 @@ "cell_type": "code", "execution_count": null, "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "outputs": [], - "source": [ - "import pdb, sys, traceback\n", - "def info(type, value, tb):\n", - " traceback.print_exception(type, value, tb)\n", - " pdb.pm()\n", - "sys.excepthook = info" - ] + "source": [] }, { "cell_type": "markdown", @@ -162,48 +155,22 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "from contextlib import contextmanager\n", - "from typing import Type, List, Optional, Any" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "from typing import Type, List, Optional, Any\n", + "\n", "class MainTrace(NamedTuple):\n", " level: int\n", " trace_type: Type['Trace']\n", - " global_data: Optional[Any]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " global_data: Optional[Any]\n", + "\n", "trace_stack: List[MainTrace] = []\n", - "dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3\n", + "\n", "@contextmanager\n", "def new_main(trace_type: Type['Trace'], global_data=None):\n", " level = len(trace_stack)\n", @@ -289,17 +256,8 @@ "outputs": [], "source": [ "import numpy as np\n", - "from typing import Tuple" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "from typing import Tuple\n", + "\n", "class Tracer:\n", " _trace: Trace\n", "\n", @@ -325,17 +283,8 @@ " try:\n", " return getattr(self.aval, name)\n", " except AttributeError:\n", - " raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")\n", + "\n", "def swap(f): return lambda x, y: f(y, x)" ] }, @@ -343,6 +292,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -383,17 +333,11 @@ "\n", " def __eq__(self, other):\n", " return (type(self) is type(other) and\n", - " self.shape == other.shape and self.dtype == other.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " self.shape == 
other.shape and self.dtype == other.dtype)\n", + "\n", + " def __repr__(self):\n", + " return f\"ShapedArray(shape={self.shape}, dtype={self.dtype})\"\n", + "\n", "class ConcreteArray(ShapedArray):\n", " array_abstraction_level = 2\n", " val: np.ndarray\n", @@ -409,35 +353,16 @@ "\n", " @staticmethod\n", " def _nonzero(tracer):\n", - " return bool(tracer.aval.val)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_end_of_cell_marker": 0, - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return bool(tracer.aval.val)\n", + "\n", "def get_aval(x):\n", " if isinstance(x, Tracer):\n", " return x.aval\n", " elif type(x) in jax_types:\n", " return ConcreteArray(np.asarray(x))\n", " else:\n", - " raise TypeError(x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " raise TypeError(x)\n", + "\n", "jax_types = {bool, int, float,\n", " np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}" ] @@ -486,21 +411,13 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], "source": [ - "import operator as op" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "import operator as op\n", + "\n", "def find_top_trace(xs) -> Trace:\n", " top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n", " default=trace_stack[0], key=op.attrgetter('level'))\n", @@ -536,6 +453,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -544,17 +462,8 @@ " if isinstance(val, Tracer):\n", " return val.full_lower()\n", " else:\n", - " return val" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return val\n", + "\n", "def full_raise(trace: Trace, val: Any) -> Tracer:\n", " if not isinstance(val, Tracer):\n", " assert type(val) in jax_types\n", @@ -598,6 +507,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -606,37 +516,12 @@ " pure = lift = lambda self, x: x # no boxing in Tracers needed\n", "\n", " def process_primitive(self, primitive, tracers, params):\n", - " return impl_rules[primitive](*tracers, **params)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ - "impl_rules = {}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return impl_rules[primitive](*tracers, **params)\n", + "\n", + "trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack\n", + "\n", + "impl_rules = {}\n", + "\n", "impl_rules[add_p] = lambda x, y: [np.add(x, y)]\n", "impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]\n", "impl_rules[neg_p] = lambda x: [np.negative(x)]\n", @@ -644,17 +529,8 @@ "impl_rules[cos_p] = lambda x: [np.cos(x)]\n", "impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]\n", 
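 "# Note the convention: each impl rule takes raw array arguments (plus any\n",
 "# keyword parameters) and returns a list of outputs. For example,\n",
 "# impl_rules[reduce_sum_p](np.ones((2, 3)), axis=0) evaluates to\n",
 "# [array([2., 2., 2.])], matching np.sum over axis 0.\n",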
"impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]\n", - "impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]\n", + "\n", "def broadcast_impl(x, *, shape, axes):\n", " return [np.broadcast_to(np.expand_dims(x, axes), shape)]\n", "impl_rules[broadcast_p] = broadcast_impl" @@ -671,6 +547,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -678,25 +555,8 @@ "def f(x):\n", " y = sin(x) * 2.\n", " z = - y + x\n", - " return z" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.7177599838802657\n" - ] - } - ], - "source": [ + " return z\n", + "\n", "print(f(3.0))" ] }, @@ -721,38 +581,21 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "def zeros_like(val):\n", - " return np.zeros_like(val)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return np.zeros_like(val)\n", + "\n", "def unzip2(pairs):\n", " lst1, lst2 = [], []\n", " for x1, x2 in pairs:\n", " lst1.append(x1)\n", " lst2.append(x2)\n", - " return lst1, lst2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return lst1, lst2\n", + "\n", "map_ = map\n", "def map(f, *xs):\n", " return list(map_(f, *xs))" @@ -770,6 +613,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -782,17 +626,8 @@ "\n", " @property\n", " def aval(self):\n", - " return get_aval(self.primal)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return get_aval(self.primal)\n", + "\n", "class JVPTrace(Trace):\n", " pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))\n", "\n", @@ -800,17 +635,8 @@ " primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)\n", " jvp_rule = jvp_rules[primitive]\n", " primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)\n", - " return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]\n", + "\n", "jvp_rules = {}" ] }, @@ -828,6 +654,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -835,87 +662,33 @@ "def add_jvp(primals, tangents):\n", " (x, y), (x_dot, y_dot) = primals, tangents\n", " return [x + y], [x_dot + y_dot]\n", - "jvp_rules[add_p] = add_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "jvp_rules[add_p] = add_jvp\n", + "\n", "def mul_jvp(primals, tangents):\n", " (x, y), (x_dot, y_dot) = primals, tangents\n", " return 
[x * y], [x_dot * y + x * y_dot]\n", - "jvp_rules[mul_p] = mul_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "jvp_rules[mul_p] = mul_jvp\n", + "\n", "def sin_jvp(primals, tangents):\n", " (x,), (x_dot,) = primals, tangents\n", " return [sin(x)], [cos(x) * x_dot]\n", - "jvp_rules[sin_p] = sin_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "jvp_rules[sin_p] = sin_jvp\n", + "\n", "def cos_jvp(primals, tangents):\n", " (x,), (x_dot,) = primals, tangents\n", " return [cos(x)], [-sin(x) * x_dot]\n", - "jvp_rules[cos_p] = cos_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "jvp_rules[cos_p] = cos_jvp\n", + "\n", "def neg_jvp(primals, tangents):\n", " (x,), (x_dot,) = primals, tangents\n", " return [neg(x)], [neg(x_dot)]\n", - "jvp_rules[neg_p] = neg_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "jvp_rules[neg_p] = neg_jvp\n", + "\n", "def reduce_sum_jvp(primals, tangents, *, axis):\n", " (x,), (x_dot,) = primals, tangents\n", " return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]\n", - "jvp_rules[reduce_sum_p] = reduce_sum_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "jvp_rules[reduce_sum_p] = reduce_sum_jvp\n", + "\n", "def greater_jvp(primals, tangents):\n", " (x, y), _ = primals, tangents\n", " out_primal = greater(x, y)\n", @@ -958,17 +731,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-0.9899924966004454\n", - "-0.9899924966004454\n" - ] - } - ], + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], "source": [ "x = 3.0\n", "y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))\n", @@ -982,16 +748,7 @@ "metadata": { "lines_to_next_cell": 1 }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.7177599838802657\n", - "2.979984993200891\n" - ] - } - ], + "outputs": [], "source": [ "def f(x):\n", " y = sin(x) * 2.\n", @@ -1010,18 +767,7 @@ "metadata": { "lines_to_next_cell": 1 }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-0.9899924966004454\n", - "-0.1411200080598672\n", - "0.9899924966004454\n", - "0.1411200080598672\n" - ] - } - ], + "outputs": [], "source": [ "def deriv(f):\n", " return lambda x: jvp_v1(f, (x,), (1.,))[1]\n", @@ -1039,16 +785,7 @@ "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.0\n", - "1.0\n" - ] - } - ], + "outputs": [], "source": [ "def f(x):\n", " if x > 0.: # Python control flow\n", @@ -1164,29 +901,11 @@ " store.set_value(out_tree)\n", " return out_flat\n", "\n", - " return flat_fun, store" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return flat_fun, store\n", + "\n", "class Empty: pass\n", - "empty = Empty()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "empty = Empty()\n", + "\n", "class Store:\n", " 
val = empty\n", "\n", @@ -1203,7 +922,7 @@ "execution_count": null, "metadata": { "lines_to_end_of_cell_marker": 0, - "lines_to_next_cell": 2 + "lines_to_next_cell": 1 }, "outputs": [], "source": [ @@ -1211,15 +930,20 @@ "from typing import Callable, Type, Hashable, Dict, Iterable, Iterator\n", "\n", "class NodeType(NamedTuple):\n", + " name: str\n", " to_iterable: Callable\n", " from_iterable: Callable\n", "\n", - "node_types: Dict[Type, NodeType] = {\n", - " tuple: NodeType(lambda t: (None, t), lambda _, xs: tuple(xs)),\n", - " list: NodeType( lambda l: (None, l), lambda _, xs: list(xs)),\n", - " dict: NodeType(lambda d: map(tuple, unzip2(sorted(d.items()))),\n", - " lambda keys, vals: dict(zip(keys, vals))),\n", - "}\n", + "def register_pytree_node(ty: Type, to_iter: Callable, from_iter: Callable\n", + " ) -> None:\n", + " node_types[ty] = NodeType(str(ty), to_iter, from_iter)\n", + "\n", + "node_types: Dict[Type, NodeType] = {}\n", + "register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))\n", + "register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))\n", + "register_pytree_node(dict,\n", + " lambda d: map(tuple, unzip2(sorted(d.items()))),\n", + " lambda keys, vals: dict(zip(keys, vals)))\n", "\n", "class PyTreeDef(NamedTuple):\n", " node_type: NodeType\n", @@ -1267,24 +991,16 @@ "cell_type": "code", "execution_count": null, "metadata": { - "lines_to_next_cell": 2 + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 }, "outputs": [], "source": [ "def f(x):\n", " y = sin(x) * 2.\n", " z = - y + x\n", - " return {'hi': z, 'there': [x, y]}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return {'hi': z, 'there': [x, y]}\n", + "\n", "x, xdot = 3., 1.\n", "y, ydot = jvp(f, (x,), (xdot,))\n", "print(y)\n", @@ -1306,6 +1022,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -1313,17 +1030,8 @@ "def mapped_aval(batch_dim, aval):\n", " shape = list(aval.shape)\n", " del shape[batch_dim]\n", - " return ShapedArray(tuple(shape), aval.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return ShapedArray(tuple(shape), aval.dtype)\n", + "\n", "def move_batch_axis(axis_size, src, dst, x):\n", " if src is not_mapped:\n", " target_shape = list(np.shape(x))\n", @@ -1332,17 +1040,8 @@ " elif src == dst:\n", " return x\n", " else:\n", - " return moveaxis(x, src, dst)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return moveaxis(x, src, dst)\n", + "\n", "def moveaxis(x, src: int, dst: int):\n", " perm = [i for i in range(np.ndim(x)) if i != src]\n", " perm.insert(dst, src)\n", @@ -1365,40 +1064,13 @@ }, "outputs": [], "source": [ - "from typing import Union" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "from typing import Union\n", + "\n", "class NotMapped: pass\n", - "not_mapped = NotMapped()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ - "BatchAxis = Union[NotMapped, int]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - 
"outputs": [], - "source": [ + "not_mapped = NotMapped()\n", + "\n", + "BatchAxis = Union[NotMapped, int]\n", + "\n", "class BatchTracer(Tracer):\n", " def __init__(self, trace, val, batch_dim: BatchAxis):\n", " self._trace = trace\n", @@ -1416,17 +1088,8 @@ " if self.batch_dim is not_mapped:\n", " return full_lower(self.val)\n", " else:\n", - " return self" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return self\n", + "\n", "class BatchTrace(Trace):\n", " pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)\n", "\n", @@ -1438,15 +1101,8 @@ "\n", " @property\n", " def axis_size(self):\n", - " return self.main.global_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " return self.main.global_data\n", + "\n", "vmap_rules = {}" ] }, @@ -1471,21 +1127,13 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], "source": [ - "from functools import partial" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "from functools import partial\n", + "\n", "def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):\n", " (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n", " if x_bdim != y_bdim:\n", @@ -1495,31 +1143,15 @@ " y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n", " return [op(x, y)], [x_bdim]\n", "vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)\n", - "vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)\n", + "\n", "def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):\n", " (x,), (x_bdim,) = vals_in, dims_in\n", " return [op(x)], [x_bdim]\n", "vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)\n", "vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)\n", - "vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)\n", + "\n", "def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n", " (x,), (x_bdim,) = vals_in, dims_in\n", " new_axis = axis + (x_bdim <= axis)\n", @@ -1528,13 +1160,6 @@ "vmap_rules[reduce_sum_p] = reduce_sum_batching_rule" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "-" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -1562,17 +1187,8 @@ " vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)\n", " outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)\n", " for val_out, bdim in zip(vals_out, bdims_out)]\n", - " return outs_transposed" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return outs_transposed\n", + "\n", "def vmap(f, in_axes):\n", " def batched_f(*args):\n", " args_flat, in_tree = tree_flatten(args)\n", @@ -1587,7 +1203,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "def 
add_one_to_a_scalar(scalar):\n", @@ -1624,19 +1242,7 @@ "lines_to_next_cell": 2 }, "source": [ - "That's it for `jvp` and `vmap`! Before moving on, let's highlight a few\n", - "simplifications in what we've seen so far compared to the full JAX\n", - "implementation:\n", - "1. **Fewer, simpler primitives.** More primitives means more interpretation\n", - "rules, and for more complex primitives (like for convolution or advanced\n", - "indexing) each rule is harder to write. But the overarching design is no\n", - "different.\n", - "2. **No pytrees.** Transformations expect arrays in, and either a single array\n", - " out or a flat list of arrays out.\n", - "3. **Missing optimization: no symbolic zeros in autodiff.**\n", - "4. **No special call primitives yet.** The core machinery needs to be\n", - " generalized to handle the most flexible kind of higher-order primitive,\n", - " used by `jax.custom_jvp` and `jax.custom_vjp`." + "That's it for `jvp` and `vmap`!" ] }, { @@ -1685,7 +1291,7 @@ "binder ::= :\n", "var ::= a | b | c | ...\n", "atom ::= | \n", - "literal ::= | \n", + "literal ::= | | | \n", "\n", "eqn ::= , ... = [ ] , ...\n", "```\n", @@ -1767,13 +1373,9 @@ }, "outputs": [], "source": [ - "class JaxprType:\n", - " in_types: List[ShapedArray]\n", - " out_type: List[ShapedArray]\n", - "\n", - " def __init__(self, in_types, out_types):\n", - " self.in_types = in_types\n", - " self.out_types = out_types\n", + "class JaxprType(NamedTuple):\n", + " in_types: List[ShapedArray]\n", + " out_types: List[ShapedArray]\n", "\n", " def __repr__(self):\n", " in_types = ', '.join(aval.str_short() for aval in self.in_types)\n", @@ -1791,7 +1393,7 @@ " in_types = [typecheck_atom(env, x) for x in eqn.inputs]\n", " out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)\n", " for out_binder, out_type in zip(eqn.out_binders, out_types):\n", - " if not types_equal(out_type, out_binder.aval): raise TypeError\n", + " if not out_type == out_binder.aval: raise TypeError\n", " for out_binder in eqn.out_binders:\n", " if out_binder in env: raise TypeError\n", " env.add(out_binder)\n", @@ -1807,10 +1409,7 @@ " elif isinstance(x, Lit):\n", " return raise_to_shaped(get_aval(x.val))\n", " else:\n", - " assert False\n", - "\n", - "def types_equal(a: ShapedArray, b: ShapedArray) -> bool:\n", - " return a.shape == b.shape and a.dtype == b.dtype" + " assert False" ] }, { @@ -1825,6 +1424,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -1836,6 +1436,7 @@ " return env[x] if type(x) is Var else x.val\n", "\n", " def write(v: Var, val: Any) -> None:\n", + " assert v not in env # single-assignment\n", " env[v] = val\n", "\n", " map(write, jaxpr.in_binders, args)\n", @@ -1843,17 +1444,8 @@ " in_vals = map(read, eqn.inputs)\n", " outs = bind(eqn.primitive, *in_vals, **eqn.params)\n", " map(write, eqn.out_binders, outs)\n", - " return map(read, jaxpr.outs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return map(read, jaxpr.outs)\n", + "\n", "def jaxpr_as_fun(jaxpr: Jaxpr):\n", " return lambda *args: eval_jaxpr(jaxpr, args)" ] @@ -2013,9 +1605,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, + "metadata": {}, "outputs": [], "source": [ "def broadcast_shapes(*shapes):\n", @@ -2067,18 +1657,9 @@ }, "outputs": [], "source": [ - "from functools import lru_cache" - ] - }, 
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ - "@lru_cache()\n", + "from functools import lru_cache\n", + "\n", + "@lru_cache() # ShapedArrays are hashable\n", "def make_jaxpr_v1(f, *avals_in):\n", " avals_in, in_tree = tree_flatten(avals_in)\n", " f, out_tree = flatten_fun(f, in_tree)\n", @@ -2165,7 +1746,9 @@ " if items:\n", " return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')\n", " else:\n", - " return pp(' ')" + " return pp(' ')\n", + "\n", + "Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))" ] }, { @@ -2175,7 +1758,7 @@ "outputs": [], "source": [ "jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))\n", - "print(pp_jaxpr(jaxpr))\n", + "print(jaxpr)\n", "print(typecheck_jaxpr(jaxpr))" ] }, @@ -2197,7 +1780,7 @@ "outputs": [], "source": [ "jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))\n", - "print(pp_jaxpr(jaxpr))" + "print(jaxpr)" ] }, { @@ -2227,18 +1810,9 @@ " try:\n", " yield\n", " finally:\n", - " dynamic_trace = prev_dynamic_trace" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ - "@lru_cache() # ShapedArrays are hashable\n", + " dynamic_trace = prev_dynamic_trace\n", + "\n", + "@lru_cache()\n", "def make_jaxpr(f, *avals_in):\n", " avals_in, in_tree = tree_flatten(avals_in)\n", " f, out_tree = flatten_fun(f, in_tree)\n", @@ -2251,20 +1825,10 @@ " outs = f(*tracers_in)\n", " tracers_out = [full_raise(trace, out) for out in outs]\n", " jaxpr, consts = builder.build(tracers_in, tracers_out)\n", - " return jaxpr, consts, out_tree()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_end_of_cell_marker": 0, - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return jaxpr, consts, out_tree()\n", + "\n", "jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))\n", - "print(pp_jaxpr(jaxpr))" + "print(jaxpr)" ] }, { @@ -2284,12 +1848,12 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ "That's it for jaxprs! With jaxprs in hand, we can implement the remaining\n", - "major JAX features. But before moving on, let's highlight some\n", - "simplifications we've made:\n", - "1. **Single-output primitives and jaxprs.**" + "major JAX features." ] }, { @@ -2308,28 +1872,28 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### \"Final style\" and \"initial style\"\n", + "### On-the-fly (\"final style\") and staged (\"initial style\") processing\n", "\n", "There are two options for how to handle higher-order primitives. Each requires\n", "a different approach to tracing and engenders different tradeoffs:\n", - "1. **`bind` takes a Python callable as an argument.** We defer forming a jaxpr\n", - " until as late as possible, namely until we're running the final interpreter\n", - " at the bottom of the interpreter stack. That way we can swap a `JaxprTrace`\n", - " in at the bottom of the interpreter stack and thus stage out rather than\n", - " execute all primitive operations. With this approach, transformations in\n", - " the stack get applied as we execute the Python callable as usual. This\n", - " approach can be very tricky to implement, but it's as general as possible\n", - " because it allows higher-order primitives not to raise the abstraction\n", - " level of their arguments and thus allows data-dependent Python control\n", - " flow. 
We refer to this approach as using a \"final-style higher-order\n", - " primitive\" employing the discharge-at-tracing-time \"final-style\n", - " transformations\" we've used so far.\n", - "2. **`bind` takes a jaxpr as an argument.** Before we call `bind`, in the\n", - " primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and\n", - " be done with the Python callable entirely. In this case, `make_jaxpr` puts\n", - " its `JaxprTrace` at the top of the interpreter stack, and no\n", - " transformations lower in the stack, which might enter via closed-over\n", - " Tracers, are applied to the Python callable as we trace it.\n", + "1. **On-the-fly processing, where `bind` takes a Python callable as an\n", + " argument.** We defer forming a jaxpr until as late as possible, namely\n", + " until we're running the final interpreter at the bottom of the interpreter\n", + " stack. That way we can swap a `JaxprTrace` in at the bottom of the\n", + " interpreter stack and thus stage out rather than execute all primitive\n", + " operations. With this approach, transformations in the stack get applied as\n", + " we execute the Python callable as usual. This approach can be very tricky\n", + " to implement, but it's as general as possible because it allows\n", + " higher-order primitives not to raise the abstraction level of their\n", + " arguments and thus allows data-dependent Python control flow. We refer to\n", + " this approach as using a \"final-style higher-order primitive\" employing the\n", + " discharge-at-tracing-time \"final-style transformations\" we've used so far.\n", + "2. **Staged processing, where `bind` takes a jaxpr as an argument.** Before we\n", + " call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form\n", + " a jaxpr up-front and be done with the Python callable entirely. In this\n", + " case, `make_jaxpr` puts its `JaxprTrace` at the top of the interpreter\n", + " stack, and no transformations lower in the stack, which might enter via\n", + " closed-over Tracers, are applied to the Python callable as we trace it.\n", " (Transformations applied within the Python callable are applied as usual,\n", " being added to the stack above the JaxprTrace.) 
Instead, the\n", " transformations lower in the stack are later applied to the call primitive,\n", @@ -2516,11 +2080,9 @@ " out_bufs = compiled.execute(input_bufs)\n", " return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]\n", "\n", - "input_handlers = {\n", - " int: xb.get_backend(None).buffer_from_pyval,\n", - " float: xb.get_backend(None).buffer_from_pyval,\n", - " np.ndarray: xb.get_backend(None).buffer_from_pyval,\n", - "}\n", + "default_input_handler = xb.get_backend(None).buffer_from_pyval\n", + "input_handlers = {ty: default_input_handler for ty in\n", + " [int, float, np.ndarray, np.float64, np.float32]}\n", "\n", "def handle_result(aval: ShapedArray, buf):\n", " del aval # Unused for now.\n", @@ -2543,39 +2105,22 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "def direct_translation(op, c, in_avals, in_vals):\n", " del c, in_avals\n", - " return [op(*in_vals)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + " return [op(*in_vals)]\n", + "\n", "xla_translations[add_p] = partial(direct_translation, xops.Add)\n", "xla_translations[mul_p] = partial(direct_translation, xops.Mul)\n", "xla_translations[neg_p] = partial(direct_translation, xops.Neg)\n", "xla_translations[sin_p] = partial(direct_translation, xops.Sin)\n", "xla_translations[cos_p] = partial(direct_translation, xops.Cos)\n", - "xla_translations[greater_p] = partial(direct_translation, xops.Gt)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "xla_translations[greater_p] = partial(direct_translation, xops.Gt)\n", + "\n", "def reduce_sum_translation(c, in_avals, in_vals, *, axis):\n", " (x_aval,), (x,) = in_avals, in_vals\n", " zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))\n", @@ -2583,22 +2128,27 @@ " shape = _xla_shape(ShapedArray((), x_aval.dtype))\n", " xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n", " return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]\n", - "xla_translations[reduce_sum_p] = reduce_sum_translation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "xla_translations[reduce_sum_p] = reduce_sum_translation\n", + "\n", "def broadcast_translation(c, in_avals, in_vals, *, shape, axes):\n", " x, = in_vals\n", " dims_complement = [i for i in range(len(shape)) if i not in axes]\n", " return [xops.BroadcastInDim(x, shape, dims_complement)]\n", - "xla_translations[broadcast_p] = broadcast_translation" + "xla_translations[broadcast_p] = broadcast_translation\n", + "\n", + "def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):\n", + " del num_consts # Only used at top-level.\n", + " # Calling jaxpr_subcomp directly would inline. 
We generate a Call HLO instead.\n", + " subc = xb.make_computation_builder('inner xla_call')\n", + " xla_params = _xla_params(subc, in_avals)\n", + " outs = jaxpr_subcomp(subc, jaxpr, xla_params)\n", + " subc = subc.build(xops.Tuple(subc, outs))\n", + " return destructure_tuple(c, xops.Call(c, subc, in_vals))\n", + "xla_translations[xla_call_p] = xla_call_translation\n", + "\n", + "def destructure_tuple(c, tup):\n", + " num_elements = len(c.get_shape(tup).tuple_shapes())\n", + " return [xops.GetTupleElement(tup, i) for i in range(num_elements)]" ] }, { @@ -2720,17 +2270,8 @@ " n = len(outs) // 2\n", " primals_out, tangents_out = outs[:n], outs[n:]\n", " return primals_out, tangents_out\n", - "jvp_rules[xla_call_p] = xla_call_jvp_rule" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ + "jvp_rules[xla_call_p] = xla_call_jvp_rule\n", + "\n", "@lru_cache()\n", "def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:\n", " def jvp_traceable(*primals_and_tangents):\n", @@ -2747,7 +2288,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 1 }, "outputs": [], "source": [ @@ -2760,7 +2301,7 @@ "vmap_rules[xla_call_p] = xla_call_vmap_rule\n", "\n", "@lru_cache()\n", - "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis]\n", + "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis, ...]\n", " ) -> Tuple[Jaxpr, List[Any]]:\n", " vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n", " in_avals = [unmapped_aval(axis_size, d, v.aval)\n", @@ -2782,10 +2323,24 @@ "cell_type": "code", "execution_count": null, "metadata": { - "lines_to_end_of_cell_marker": 0, - "lines_to_next_cell": 2 + "lines_to_next_cell": 1 }, "outputs": [], + "source": [ + "def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):\n", + " del num_consts # Unused.\n", + " jaxpr_type = typecheck_jaxpr(jaxpr)\n", + " if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n", + " raise TypeError\n", + " return jaxpr_type.out_types\n", + "abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "@jit\n", "def f(x):\n", @@ -2797,10 +2352,27 @@ "x, xdot = 3., 1.\n", "y, ydot = jvp(f, (x,), (xdot,))\n", "print(y)\n", - "print(ydot)\n", - "\n", - "y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed\n", - "\n", + "print(ydot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "ys = vmap(f, (0,))(np.arange(3.))\n", "print(ys)" ] @@ -2858,7 +2430,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "@jit\n", @@ -2881,12 +2456,22 @@ "\n", "The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve\n", "jaxprs as well. That's because both involve staging out, or delaying,\n", - "computation.\n", + "computation." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `linearize`\n", "\n", "In the case of `linearize`, we want to stage out the linear part of a `jvp`\n", "computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n", - "then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where\n", - "```\n", + "then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to\n", + "mean \"the tangent type of `a`\" and using the \"lollipop\" `-o` rather than the\n", + "arrow `->` to indicate a _linear_ function. We define the semantics of\n", + "`linearize` in terms of `jvp` too:\n", + "```python\n", "y, f_lin = linearize(f, x)\n", "y_dot = f_lin(x_dot)\n", "```\n", @@ -2894,32 +2479,55 @@ "```\n", "y, y_dot = jvp(f, (x,), (x_dot,))\n", "```\n", - "and where the application of `f_lin` does not redo any of the linearization\n", - "work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.\n", + "where the application of `f_lin` does not redo any of the linearization work.\n", + "We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.\n", "\n", "To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:\n", "we evaluate all the primal values as we trace, but stage the tangent\n", - "computations into a jaxpr." + "computations into a jaxpr. This is our second way to build jaxprs. But where\n", + "`make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim\n", + "to stage out every primitive bind, this second approach stages out only those\n", + "primitive binds with a data dependence on tagent inputs.\n", + "\n", + "First, some utilities:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], "source": [ - "def split_half(lst):\n", - " n, ragged = divmod(len(lst), 2)\n", - " assert not ragged\n", - " return lst[:n], lst[n:]" + "def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:\n", + " return lst[:n], lst[n:]\n", + "\n", + "def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:\n", + " assert not len(lst) % 2\n", + " return split_list(lst, len(lst) // 2)\n", + "\n", + "def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:\n", + " lists = lst1, lst2 = [], []\n", + " for b, x in zip(bs, l):\n", + " lists[b].append(x)\n", + " return lst1, lst2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll write `linearize` by combining `jvp` together with a general\n", + "partial evaluation transformation, to be added next:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -2927,14 +2535,12 @@ "def linearize_flat(f, *primals_in):\n", " pvals_in = ([PartialVal.known(x) for x in primals_in] +\n", " [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n", - "\n", " def f_jvp(*primals_tangents_in):\n", " primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n", " return [*primals_out, *tangents_out]\n", - "\n", " jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)\n", " primal_pvals, _ = split_half(pvals_out)\n", - " assert all(pval.is_known for pval in primal_pvals)\n", + " assert all(pval.is_known for pval in primal_pvals)\n", " primals_out = [pval.const for pval in primal_pvals]\n", " f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])\n", " return primals_out, 
f_lin\n", @@ -2954,7 +2560,90 @@ " return primals_out, f_lin\n", "\n", "def vspace(aval: ShapedArray) -> ShapedArray:\n", - " return raise_to_shaped(aval)" + " return raise_to_shaped(aval) # TODO handle integers?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we turn to the general partial evaluation transformation. The goal is to\n", + "accept a Python callable and a list of inputs, some known and some unknown,\n", + "and to produce (1) all the outputs which can be computed from the known\n", + "inputs, together with (2) a jaxpr representing the part of the Python\n", + "callable's computation which can only be performed after the remaining inputs\n", + "are known.\n", + "\n", + "This transformation can't be summarized purely in a type signature because its\n", + "behavior relies on the data dependencies inside the given Python callable and\n", + "not just its type. Nevertheless a heuristic type signature is useful. If we\n", + "assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where\n", + "`a1` and `a2` represent the known and unknown inputs, respectively, and where\n", + "`b1` only has a data depenence on `a1` while `b2` has some data dependnece on\n", + "`a2`, then we might write\n", + "\n", + "```\n", + "partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> (b1, res, (res, a2) -> b2)\n", + "```\n", + "\n", + "In words, given values for the inputs of type `a1`, `partial_eval` produces\n", + "the outputs of type `b1` along with \"residual\" values of type `res`\n", + "representing the intermediates required to complete the computation in the\n", + "second stage. It also produces a function of type `(res, a2) -> b2` which\n", + "accepts the residual values as well as the remaining inputs and produces the\n", + "remaining outputs.\n", + "\n", + "We like to think of partial evaluation as \"unzipping\" one computation into\n", + "two. For example, consider this jaxpr:\n", + "```\n", + "{ lambda a:float64[] .\n", + " let b:float64[] = sin a\n", + " c:float64[] = neg b\n", + " in ( c ) }\n", + "```\n", + "A jaxpr for the JVP would look like:\n", + "```\n", + "{ lambda a:float64[] b:float64 .\n", + " let c:float64[] = sin a\n", + " d:float64[] = cos a\n", + " e:float64[] = mul d b\n", + " f:float64[] = neg c\n", + " g:float64[] = neg e\n", + " in ( f, g ) }\n", + "```\n", + "If we imagine applying partial evaluation to this jaxpr with the first input\n", + "known and the second unknown, we end up 'unzipping' the JVP jaxpr into primal\n", + "and tangent jaxprs:\n", + "```\n", + "{ lambda a:float64[] .\n", + " let c:float64[] = sin a\n", + " d:float64[] = cos a\n", + " f:float64[] = neg c\n", + " in ( f, d ) }\n", + "```\n", + "```\n", + "{ lambda d:float64[] b:float64[] .\n", + " let e:float64[] = mul d b\n", + " g:float64[] = neg e\n", + " in ( g ) }\n", + "```\n", + "This second jaxpr is represents the linear computation that we want from\n", + "`linearize`.\n", + "\n", + "However, unlike in this jaxpr example, we want the computation on known values\n", + "to occur while evaluating the input Python callable. That is, rather than\n", + "forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all\n", + "operations out of Python first before sorting out what can be evaluated now\n", + "and what must be delayed, we want only to form a jaxpr for those operations\n", + "that _must_ be delayed due to a dependence on unknown inputs. 
In the context\n", + "of automatic differentiation, this is the feature ultimately enables us to\n", + "handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control\n", + "flow works because partial evaluation keeps the primal computation in Python.\n", + "As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out\n", + "what can be evaluated and what must be staged out into a jaxpr.\n", + "\n", + "First, we start with a `PartialVal` class, which represents a value that can\n", + "be either known or unknown:" ] }, { @@ -2977,9 +2666,27 @@ " def unknown(cls, aval: ShapedArray):\n", " return PartialVal(aval, None)\n", "\n", - " is_known = property(lambda self: self.const is not None)\n", - " is_unknown = property(lambda self: self.const is None)\n", - "\n", + " is_known = property(lambda self: self.const is not None)\n", + " is_unknown = property(lambda self: self.const is None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Partial evaluation will take a list of `PartialVal`s representing inputs, and\n", + "return a list of `PartialVal` outputs along with a jaxpr representing the\n", + "dleayed computation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def partial_eval_flat(f, pvals_in: List[PartialVal]):\n", " with new_main(PartialEvalTrace) as main:\n", " trace = PartialEvalTrace(main)\n", @@ -2991,6 +2698,19 @@ " return jaxpr, pvals_out, consts" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This\n", + "interpreter will build a jaxpr on the fly while tracking data dependencies. To\n", + "do so, it builds a bipartite directed acyclic graph (DAG) between\n", + "`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n", + "nodes, representing formulas for how compute some values from others. One kind\n", + "of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n", + "application, but we also have recipe types for constants and lambda binders:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -3001,7 +2721,8 @@ "source": [ "from weakref import ref, ReferenceType\n", "\n", - "class LambdaBindingRecipe(NamedTuple): pass\n", + "class LambdaBindingRecipe(NamedTuple):\n", + " pass\n", "\n", "class ConstRecipe(NamedTuple):\n", " val: Any\n", @@ -3020,8 +2741,18 @@ " self.avals_out = avals_out\n", " self.tracer_refs_out = tracer_refs_out\n", "\n", - "JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]\n", - "\n", + "JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "class PartialEvalTracer(Tracer):\n", " pval: PartialVal\n", " recipe: JaxprRecipe\n", @@ -3038,8 +2769,42 @@ " def full_lower(self):\n", " if self.pval.is_known:\n", " return full_lower(self.pval.const)\n", - " return self\n", + " return self" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `PartialEvalTrace` contains the logic for constructing the graph of\n", + "`JaxprRecipe`s and `PartialEvalTracer`s. Each argument corresponds to a\n", + "`LambdaBindingRecipe` leaf node, and each constant is a `ConstRecipe` leaf\n", + "node holding a reference to the constant. 
All other tracers and recipes come\n", + "from `process_primitive`, which forms tracers with `JaxprEqnRecipe`s.\n", "\n", + "For most primitives, the `process_primitive` logic is straightforward: if all\n", + "inputs are known then we can bind the primitive on the known values\n", + "(evaluating it in Python) and avoid forming tracers corresponding to the\n", + "output. If instead any input is unknown then we instead stage out into a\n", + "`JaxprEqnRecipe` representing the primitive application. To build the tracers\n", + "representing unknown outputs, we need avals, which get from the abstract eval\n", + "rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s\n", + "reference tracers; we avoid circular garbage by using weakrefs.)\n", + "\n", + "That `process_primitive` logic applies to most primitives, but `xla_call_p`\n", + "requires recursive treatment. So we special-case its rule in a\n", + "`partial_eval_rules` dict." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "class PartialEvalTrace(Trace):\n", " def new_arg(self, pval: PartialVal) -> Any:\n", " return PartialEvalTracer(self, pval, LambdaBindingRecipe())\n", @@ -3058,6 +2823,8 @@ " def process_primitive(self, primitive, tracers, params):\n", " if all(t.pval.is_known for t in tracers):\n", " return bind(primitive, *map(full_lower, tracers), **params)\n", + " rule = partial_eval_rules.get(primitive)\n", + " if rule: return rule(self, tracers, **params)\n", " tracers_in = [self.instantiate_const(t) for t in tracers]\n", " avals_in = [t.aval for t in tracers_in]\n", " avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n", @@ -3066,7 +2833,18 @@ " eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,\n", " map(ref, tracers_out))\n", " for t in tracers_out: t.recipe = eqn\n", - " return tracers_out" + " return tracers_out\n", + "\n", + "partial_eval_rules = {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we can build graph representations of jaxprs with `PartialEvalTrace`,\n", + "we need a mechanism to convert the graph representation to a standard jaxpr.\n", + "The jaxpr corresponds to a topological sort of the graph." ] }, { @@ -3127,6 +2905,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_end_of_cell_marker": 0, "lines_to_next_cell": 1 }, "outputs": [], @@ -3173,29 +2952,503 @@ " seen.add(id(node))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can linearize!" + ] + }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "y, sin_lin = linearize(sin, 3.)\n", "print(y, sin(3.))\n", "print(sin_lin(1.), cos(3.))" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To handle linearize-of-jit, we still need to write a partial evaluation rule\n", + "for `xla_call_p`. Other than tracer bookkeeping, the main task is to perform\n", + "partial evaluation of a jaxpr, 'unzipping' it into two jaxprs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):\n", + " del num_consts # Unused.\n", + " in_unknowns = [not t.pval.is_known for t in tracers]\n", + " jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)\n", + " known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)\n", + " known_vals = [t.pval.const for t in known_tracers]\n", + " outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)\n", + " outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)\n", + " res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]\n", + " outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)\n", + " for v in jaxpr2.outs]\n", + " eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,\n", + " dict(jaxpr=jaxpr2, num_consts=0),\n", + " [v.aval for v in jaxpr2.outs], map(ref, outs2))\n", + " for t in outs2: t.recipe = eqn\n", + " outs1, outs2 = iter(outs1), iter(outs2)\n", + " return [next(outs2) if uk else next(outs1) for uk in out_unknowns]\n", + "partial_eval_rules[xla_call_p] = xla_call_partial_eval\n", + "\n", + "def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool]\n", + " ) -> Tuple[Jaxpr, Jaxpr, List[bool], int]:\n", + " env: Dict[Var, bool] = {}\n", + " residuals = set()\n", + "\n", + " def read(v: Atom) -> bool:\n", + " if type(v) is Lit: raise NotImplementedError\n", + " return env[v]\n", + "\n", + " def write(unk: bool, v: Var) -> None:\n", + " env[v] = unk\n", + "\n", + " def new_res(v: Var) -> Var:\n", + " return residuals.add(v) or v\n", + "\n", + " eqns1, eqns2 = [], []\n", + " map(write, in_unknowns, jaxpr.in_binders)\n", + " for eqn in jaxpr.eqns:\n", + " unks_in = map(read, eqn.inputs)\n", + " rule = partial_eval_jaxpr_rules.get(eqn.primitive)\n", + " if rule:\n", + " eqn1, eqn2, unks_out, res = rule(unks_in, eqn)\n", + " eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)\n", + " map(write, unks_out, eqn.out_binders)\n", + " elif any(unks_in):\n", + " inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]\n", + " eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))\n", + " map(partial(write, True), eqn.out_binders)\n", + " else:\n", + " eqns1.append(eqn)\n", + " map(partial(write, False), eqn.out_binders)\n", + " out_unknowns = map(read, jaxpr.outs)\n", + " residuals, num_res = list(residuals), len(residuals)\n", + "\n", + " ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)\n", + " outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)\n", + "\n", + " jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)\n", + " jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)\n", + " typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)\n", + "\n", + " return jaxpr1, jaxpr2, out_unknowns, num_res\n", + "\n", + "def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):\n", + " jaxprty = typecheck_jaxpr(jaxpr) # (a1, a2) -> (b1, b2 )\n", + " jaxpr1ty = typecheck_jaxpr(jaxpr1) # a1 -> (b1, res)\n", + " jaxpr2ty = typecheck_jaxpr(jaxpr2) # (res, a2) -> b2\n", + "\n", + " a1, a2 = partition_list(unks_in, jaxprty.in_types)\n", + " b1, b2 = partition_list(unks_out, jaxprty.out_types)\n", + " b1_, res = split_list(jaxpr1ty.out_types, len(b1))\n", + " res_, a2_ = split_list(jaxpr2ty.in_types, len(res))\n", + " b2_ = 
jaxpr2ty.out_types\n", + "\n", + " if jaxpr1ty.in_types != a1: raise TypeError\n", + " if jaxpr2ty.out_types != b2: raise TypeError\n", + " if b1 != b1_: raise TypeError\n", + " if res != res_: raise TypeError\n", + " if a2 != a2_: raise TypeError\n", + " if b2 != b2_: raise TypeError\n", + "\n", + "partial_eval_jaxpr_rules = {}\n", + "\n", + "def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn\n", + " ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:\n", + " jaxpr = eqn.params['jaxpr']\n", + " jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)\n", + " ins1, ins2 = partition_list(unks_in, eqn.inputs)\n", + " outs1, outs2 = partition_list(unks_out, eqn.out_binders)\n", + " residuals, _ = split_list(jaxpr2.in_binders, num_res)\n", + " eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),\n", + " outs1 + residuals)\n", + " eqn2 = JaxprEqn(xla_call_p, residuals + ins2,\n", + " dict(jaxpr=jaxpr2, num_consts=0), outs2)\n", + " return eqn1, eqn2, unks_out, residuals\n", + "partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With that, we can compose `linearize` and `jit` however we like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@jit\n", + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = - y + x\n", + " return z\n", + "\n", + "y, f_lin = linearize(f, 3.)\n", + "y_dot = f_lin(1.)\n", + "print(y, y_dot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@jit\n", + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = g(x, y)\n", + " return z\n", + "\n", + "@jit\n", + "def g(x, y):\n", + " return cos(x) + y\n", + "\n", + "y, f_lin = linearize(f, 3.)\n", + "y_dot = f_lin(1.)\n", + "print(y, y_dot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `vjp` and `grad`\n", + "\n", + "The `vjp` transformation works a lot like linearize. Its type signature is\n", + "analogous:\n", + "\n", + "```\n", + "linearize : (a -> b) -> a -> (b, T a -o T b)\n", + "vjp : (a -> b) -> a -> (b, T b -o T a)\n", + "```\n", + "\n", + "The only difference is that we transpose the linear part of the computation\n", + "before returning it, so that it goes from type `T a -o T b` to type `T b -o T\n", + "a`. That is, we'll implement `vjp` as, essentially,\n", + "\n", + "```\n", + "def vjp(f, x):\n", + " y, f_lin = linearize(f, x)\n", + " f_vjp = lambda y_bar: transpose(f_lin)(y_bar)\n", + " return y, f_vjp\n", + "```\n", + "\n", + "Since we have the linear computation as a jaxpr, not just a Python callable,\n", + "we can implement the transpose transformation as a jaxpr interpreter." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def vjp_flat(f, *primals_in):\n", + " pvals_in = ([PartialVal.known(x) for x in primals_in] +\n", + " [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n", + " primal_pvals_in, tangent_pvals_in = split_half(pvals_in)\n", + " def f_jvp(*primals_tangents_in):\n", + " primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n", + " return [*primals_out, *tangents_out]\n", + " jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) # linearize\n", + " primal_pvals, _ = split_half(pvals_out)\n", + " assert all(pval.is_known for pval in primal_pvals)\n", + " primals_out = [pval.const for pval in primal_pvals]\n", + " transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]\n", + " f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)\n", + " return primals_out, f_vjp\n", + "\n", + "def vjp(f, *primals_in):\n", + " primals_in_flat, in_tree = tree_flatten(primals_in)\n", + " f, out_tree = flatten_fun(f, in_tree)\n", + " primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)\n", + " primals_out = tree_unflatten(out_tree(), primals_out_flat)\n", + "\n", + " def f_vjp(*cotangents_out):\n", + " cotangents_out_flat, _ = tree_flatten(cotangents_out)\n", + " cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)\n", + " return tree_unflatten(in_tree, cotangents_in_flat)\n", + "\n", + " return primals_out, f_vjp\n", + "\n", + "class UndefPrimal(NamedTuple):\n", + " aval: ShapedArray\n", + "\n", + "register_pytree_node(UndefPrimal,\n", + " lambda u: (u.aval, ()),\n", + " lambda aval, _: UndefPrimal(aval))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use `UndefPrimal` instances to indicate which arguments with respect to\n", + "with we want to transpose. These arise because in general, being explicit\n", + "about closed-over values, we want to transpose functions of type\n", + "`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the\n", + "inputs with respect to which the function is linear could be scattered through\n", + "the argument list. 
So we indicate the linear positions using `UndefPrimal`.\n", + "We register `UndefPrimal` as a pytree node because the pytree mechanism gives\n", + "a handy way to prune these placeholders out of argument lists.\n", + "\n", + "Next, we can write `eval_jaxpr_transposed`, along with transpose rules for\n", + "all primitives which can be linear in at least one argument:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "# NB: the analogous function in JAX is called 'backward_pass'\n", + "def eval_jaxpr_transposed(jaxpr: Jaxpr, args: List[Any], cotangents: List[Any]\n", + " ) -> List[Any]:\n", + " primal_env: Dict[Var, Any] = {}\n", + " ct_env: Dict[Var, Any] = {}\n", + "\n", + " def read_primal(x: Atom) -> Any:\n", + " return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val\n", + "\n", + " def write_primal(v: Var, val: Any) -> None:\n", + " if type(val) is not UndefPrimal:\n", + " primal_env[v] = val\n", + "\n", + " def read_cotangent(v: Var) -> Any:\n", + " return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))\n", + "\n", + " def write_cotangent(x: Atom, val: Any):\n", + " if type(x) is Var and val is not None:\n", + " ct_env[x] = add(ct_env[x], val) if x in ct_env else val\n", + "\n", + " map(write_primal, jaxpr.in_binders, args)\n", + " map(write_cotangent, jaxpr.outs, cotangents)\n", + " for eqn in jaxpr.eqns[::-1]:\n", + " primals_in = map(read_primal, eqn.inputs)\n", + " cts_in = map(read_cotangent, eqn.out_binders)\n", + " rule = transpose_rules[eqn.primitive]\n", + " cts_out = rule(cts_in, *primals_in, **eqn.params)\n", + " map(write_cotangent, eqn.inputs, cts_out)\n", + "\n", + " return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)\n", + " if type(x) is UndefPrimal]\n", + "\n", + "transpose_rules = {}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def mul_transpose_rule(cts, x, y):\n", + " z_bar, = cts\n", + " assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)\n", + " return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]\n", + "transpose_rules[mul_p] = mul_transpose_rule\n", + "\n", + "def neg_transpose_rule(cts, x):\n", + " ybar, = cts\n", + " assert type(x) is UndefPrimal\n", + " return [neg(ybar)]\n", + "transpose_rules[neg_p] = neg_transpose_rule\n", + "\n", + "def add_transpose_rule(cts, x, y):\n", + " z_bar, = cts\n", + " return [z_bar, z_bar]\n", + "transpose_rules[add_p] = add_transpose_rule\n", + "\n", + "def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):\n", + " del num_consts # Unused.\n", + " undef_primals = [type(x) is UndefPrimal for x in invals]\n", + " transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))\n", + " residuals, _ = partition_list(undef_primals, invals)\n", + " outs = bind(xla_call_p, *new_consts, *residuals, *cts,\n", + " jaxpr=transposed_jaxpr, num_consts=len(new_consts))\n", + " outs = iter(outs)\n", + " return [next(outs) if undef else None for undef in undef_primals]\n", + "transpose_rules[xla_call_p] = xla_call_transpose_rule\n", + "\n", + "@lru_cache()\n", + "def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...]\n", + " ) -> Tuple[Jaxpr, List[Any]]:\n", + " traceable = partial(eval_jaxpr_transposed, jaxpr)\n", + " avals_in, avals_out = typecheck_jaxpr(jaxpr)\n", + " args = [UndefPrimal(a) if u else a for a, u 
in zip(avals_in, undef_primals)]\n", + " trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))\n", + " return trans_jaxpr, consts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we can linearize and transpose, we can finally write `grad`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def grad(f):\n", + " def gradfun(x, *xs):\n", + " y, f_vjp = vjp(f, x, *xs)\n", + " if np.shape(y) != (): raise TypeError\n", + " x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))\n", + " return x_bar\n", + " return gradfun" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "y, f_vjp = vjp(sin, 3.)\n", + "print(f_vjp(1.), cos(3.))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = - y + x\n", + " return z\n", + "\n", + "print(grad(f)(3.))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@jit\n", + "def f(x):\n", + " y = x * 2.\n", + " z = g(y)\n", + " return z\n", + "\n", + "@jit\n", + "def g(x):\n", + " return cos(x) * 2.\n", + "\n", + "print(grad(f)(3.))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here's something of a compositionality stress test:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from core_test.py fun_with_nested_calls_2\n", + "def foo(x):\n", + " @jit\n", + " def bar(y):\n", + " def baz(w):\n", + " q = jit(lambda x: y)(x)\n", + " q = q + jit(lambda: y)()\n", + " q = q + jit(lambda y: w + y)(y)\n", + " q = jit(lambda w: jit(sin)(x) * y)(1.0) + q\n", + " return q\n", + " p, t = jvp(baz, (x + 1.0,), (y,))\n", + " return t + (x * p)\n", + " return bar(x)\n", + "\n", + "def assert_allclose(*vals):\n", + " for v1, v2 in zip(vals[:-1], vals[1:]):\n", + " np.testing.assert_allclose(v1, v2)\n", + "\n", + "ans1 = f(3.)\n", + "ans2 = jit(f)(3.)\n", + "ans3, _ = jvp(f, (3.,), (5.,))\n", + "ans4, _ = jvp(jit(f), (3.,), (5.,))\n", + "assert_allclose(ans1, ans2, ans3, ans4)\n", + "\n", + "deriv1 = grad(f)(3.)\n", + "deriv2 = grad(jit(f))(3.)\n", + "deriv3 = jit(grad(jit(f)))(3.)\n", + "_, deriv4 = jvp(f, (3.,), (1.,))\n", + "_, deriv5 = jvp(jit(f), (3.,), (1.,))\n", + "assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)\n", + "\n", + "hess1 = grad(grad(f))(3.)\n", + "hess2 = grad(grad(jit(f)))(3.)\n", + "hess3 = grad(jit(grad(f)))(3.)\n", + "hess4 = jit(grad(grad(f)))(3.)\n", + "_, hess5 = jvp(grad(f), (3.,), (1.,))\n", + "_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))\n", + "_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))\n", + "assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)" + ] } ], "metadata": { - "colab": { - "collapsed_sections": [], - "last_runtime": { - "build_target": "//learning/deepmind/dm_python:dm_notebook3", - "kind": "private" - }, - "name": "Autodidax: JAX core from scratch", - "provenance": [], - "toc_visible": true - }, "jupytext": { "formats": "ipynb,md:myst,py", "main_language": "python" @@ -3203,20 +3456,8 @@ "kernelspec": { "display_name": "Python 3", "name": "python3" - }, - "language_info": { - 
"codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/docs/autodidax.md b/docs/autodidax.md index 50a93aec4..f433c8fd0 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -32,12 +32,8 @@ limitations under the License. --- ``` -```{code-cell} ipython3 -import pdb, sys, traceback -def info(type, value, tb): - traceback.print_exception(type, value, tb) - pdb.pm() -sys.excepthook = info +```{code-cell} + ``` # Autodidax: JAX core from scratch @@ -80,7 +76,7 @@ We can implement stacks of interpreters and even have them all discharge on the fly as we execute the Python function to be transformed. To start, let's define these primitives so that we can intercept their application: -```{code-cell} ipython3 +```{code-cell} from typing import NamedTuple class Primitive(NamedTuple): @@ -134,24 +130,18 @@ to the element's height in the stack), an interpreter type (which we'll call a needs. We call each element a `MainTrace`, though maybe "Interpreter" would be more descriptive. -```{code-cell} ipython3 +```{code-cell} from contextlib import contextmanager from typing import Type, List, Optional, Any -``` -```{code-cell} ipython3 class MainTrace(NamedTuple): level: int trace_type: Type['Trace'] global_data: Optional[Any] -``` -```{code-cell} ipython3 trace_stack: List[MainTrace] = [] dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 -``` -```{code-cell} ipython3 @contextmanager def new_main(trace_type: Type['Trace'], global_data=None): level = len(trace_stack) @@ -181,7 +171,7 @@ and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps carrying some extra context data used by the interpreter. A `Trace` handles boxing up vales into `Tracers` and also handles primitive application. -```{code-cell} ipython3 +```{code-cell} class Trace: main: MainTrace @@ -211,12 +201,10 @@ relationship between `Tracer`s and `AbstractValue`s is that there's one `Tracer` per transformation, and at least one `AbstractValue` per base type, like arrays.) 
-```{code-cell} ipython3 +```{code-cell} import numpy as np from typing import Tuple -``` -```{code-cell} ipython3 class Tracer: _trace: Trace @@ -243,13 +231,11 @@ class Tracer: return getattr(self.aval, name) except AttributeError: raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") -``` -```{code-cell} ipython3 def swap(f): return lambda x, y: f(y, x) ``` -```{code-cell} ipython3 +```{code-cell} class ShapedArray: array_abstraction_level = 1 shape: Tuple[int] @@ -287,9 +273,10 @@ class ShapedArray: def __eq__(self, other): return (type(self) is type(other) and self.shape == other.shape and self.dtype == other.dtype) -``` -```{code-cell} ipython3 + def __repr__(self): + return f"ShapedArray(shape={self.shape}, dtype={self.dtype})" + class ConcreteArray(ShapedArray): array_abstraction_level = 2 val: np.ndarray @@ -306,9 +293,7 @@ class ConcreteArray(ShapedArray): @staticmethod def _nonzero(tracer): return bool(tracer.aval.val) -``` -```{code-cell} ipython3 def get_aval(x): if isinstance(x, Tracer): return x.aval @@ -316,9 +301,7 @@ def get_aval(x): return ConcreteArray(np.asarray(x)) else: raise TypeError(x) -``` -```{code-cell} ipython3 jax_types = {bool, int, float, np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray} ``` @@ -331,7 +314,7 @@ singleton set consisting of a single array value. Now that we've set up the interpreter stack, the Trace/Tracer API for interpreters, and abstract values, we can come back to implement `bind`: -```{code-cell} ipython3 +```{code-cell} def bind(prim, *args, **params): top_trace = find_top_trace(args) tracers = [full_raise(top_trace, arg) for arg in args] @@ -346,11 +329,9 @@ rule. The calls to `full_raise` just ensure that the inputs are boxed in the top trace's `Tracer` instances, and the call to `full_lower` is an optional optimization so that we unbox values out of `Tracer`s as much as possible. -```{code-cell} ipython3 +```{code-cell} import operator as op -``` -```{code-cell} ipython3 def find_top_trace(xs) -> Trace: top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)), default=trace_stack[0], key=op.attrgetter('level')) @@ -377,15 +358,13 @@ operation. That's worth exploring! JAX is designed around data dependence in large part because that's so natural for automatic differentiation, and JAX's roots are in autodiff. But it may be over-fit. -```{code-cell} ipython3 +```{code-cell} def full_lower(val: Any): if isinstance(val, Tracer): return val.full_lower() else: return val -``` -```{code-cell} ipython3 def full_raise(trace: Trace, val: Any) -> Tracer: if not isinstance(val, Tracer): assert type(val) in jax_types @@ -417,23 +396,17 @@ That's it for the JAX core! Now we can start adding interpreters. We'll start with the simplest interpreter: the evaluation interpreter that will sit at the bottom of the interpreter stack. 
-```{code-cell} ipython3 +```{code-cell} class EvalTrace(Trace): pure = lift = lambda self, x: x # no boxing in Tracers needed def process_primitive(self, primitive, tracers, params): return impl_rules[primitive](*tracers, **params) -``` -```{code-cell} ipython3 trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack -``` -```{code-cell} ipython3 impl_rules = {} -``` -```{code-cell} ipython3 impl_rules[add_p] = lambda x, y: [np.add(x, y)] impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)] impl_rules[neg_p] = lambda x: [np.negative(x)] @@ -442,9 +415,7 @@ impl_rules[cos_p] = lambda x: [np.cos(x)] impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)] impl_rules[greater_p] = lambda x, y: [np.greater(x, y)] impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)] -``` -```{code-cell} ipython3 def broadcast_impl(x, *, shape, axes): return [np.broadcast_to(np.expand_dims(x, axes), shape)] impl_rules[broadcast_p] = broadcast_impl @@ -452,14 +423,12 @@ impl_rules[broadcast_p] = broadcast_impl With this interpreter, we can evaluate user functions: -```{code-cell} ipython3 +```{code-cell} def f(x): y = sin(x) * 2. z = - y + x return z -``` -```{code-cell} ipython3 print(f(3.0)) ``` @@ -472,21 +441,17 @@ that now we can add some real transformations. First, a few helper functions: -```{code-cell} ipython3 +```{code-cell} def zeros_like(val): return np.zeros_like(val) -``` -```{code-cell} ipython3 def unzip2(pairs): lst1, lst2 = [], [] for x1, x2 in pairs: lst1.append(x1) lst2.append(x2) return lst1, lst2 -``` -```{code-cell} ipython3 map_ = map def map(f, *xs): return list(map_(f, *xs)) @@ -495,7 +460,7 @@ def map(f, *xs): The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The `Trace` applies JVP rules. -```{code-cell} ipython3 +```{code-cell} class JVPTracer(Tracer): def __init__(self, trace, primal, tangent): self._trace = trace @@ -505,9 +470,7 @@ class JVPTracer(Tracer): @property def aval(self): return get_aval(self.primal) -``` -```{code-cell} ipython3 class JVPTrace(Trace): pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val)) @@ -516,9 +479,7 @@ class JVPTrace(Trace): jvp_rule = jvp_rules[primitive] primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params) return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)] -``` -```{code-cell} ipython3 jvp_rules = {} ``` @@ -527,49 +488,37 @@ minimal amount of context, which is a zero tangent value. 
Let's add some JVP rules for primitives: -```{code-cell} ipython3 +```{code-cell} def add_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents return [x + y], [x_dot + y_dot] jvp_rules[add_p] = add_jvp -``` -```{code-cell} ipython3 def mul_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents return [x * y], [x_dot * y + x * y_dot] jvp_rules[mul_p] = mul_jvp -``` -```{code-cell} ipython3 def sin_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return [sin(x)], [cos(x) * x_dot] jvp_rules[sin_p] = sin_jvp -``` -```{code-cell} ipython3 def cos_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return [cos(x)], [-sin(x) * x_dot] jvp_rules[cos_p] = cos_jvp -``` -```{code-cell} ipython3 def neg_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return [neg(x)], [neg(x_dot)] jvp_rules[neg_p] = neg_jvp -``` -```{code-cell} ipython3 def reduce_sum_jvp(primals, tangents, *, axis): (x,), (x_dot,) = primals, tangents return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)] jvp_rules[reduce_sum_p] = reduce_sum_jvp -``` -```{code-cell} ipython3 def greater_jvp(primals, tangents): (x, y), _ = primals, tangents out_primal = greater(x, y) @@ -579,7 +528,7 @@ jvp_rules[greater_p] = greater_jvp Finally, we add a transformation API to kick off the trace: -```{code-cell} ipython3 +```{code-cell} def jvp_v1(f, primals, tangents): with new_main(JVPTrace) as main: trace = JVPTrace(main) @@ -592,14 +541,14 @@ def jvp_v1(f, primals, tangents): And with that, we can differentiate! -```{code-cell} ipython3 +```{code-cell} x = 3.0 y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,)) print(sin_deriv_at_3) print(cos(3.0)) ``` -```{code-cell} ipython3 +```{code-cell} def f(x): y = sin(x) * 2. z = - y + x @@ -611,7 +560,7 @@ print(y) print(ydot) ``` -```{code-cell} ipython3 +```{code-cell} def deriv(f): return lambda x: jvp_v1(f, (x,), (1.,))[1] @@ -621,7 +570,7 @@ print(deriv(deriv(deriv(sin)))(3.)) print(deriv(deriv(deriv(deriv(sin))))(3.)) ``` -```{code-cell} ipython3 +```{code-cell} def f(x): if x > 0.: # Python control flow return 2. * x @@ -649,7 +598,7 @@ Here's how we'd like to write `jvp`, assuming the user always gives us functions that take arrays as inputs and produces a flat list of arrays as outputs: -```{code-cell} ipython3 +```{code-cell} def jvp_flat(f, primals, tangents): with new_main(JVPTrace) as main: trace = JVPTrace(main) @@ -663,7 +612,7 @@ def jvp_flat(f, primals, tangents): To support user functions that have arbitrary containers in the inputs and outputs, here's how we'd write the user-facing `jvp` wrapper: -```{code-cell} ipython3 +```{code-cell} def jvp(f, primals, tangents): primals_flat, in_tree = tree_flatten(primals) tangents_flat, in_tree2 = tree_flatten(tangents) @@ -686,7 +635,7 @@ types](https://en.wikipedia.org/wiki/Substructural_type_system).) 
All that remains is to write `tree_flatten`, `tree_unflatten`, and `flatten_fun`: -```{code-cell} ipython3 +```{code-cell} def flatten_fun(f, in_tree): store = Store() @@ -698,14 +647,10 @@ def flatten_fun(f, in_tree): return out_flat return flat_fun, store -``` -```{code-cell} ipython3 class Empty: pass empty = Empty() -``` -```{code-cell} ipython3 class Store: val = empty @@ -717,20 +662,25 @@ class Store: return self.val ``` -```{code-cell} ipython3 +```{code-cell} import itertools as it from typing import Callable, Type, Hashable, Dict, Iterable, Iterator class NodeType(NamedTuple): + name: str to_iterable: Callable from_iterable: Callable -node_types: Dict[Type, NodeType] = { - tuple: NodeType(lambda t: (None, t), lambda _, xs: tuple(xs)), - list: NodeType( lambda l: (None, l), lambda _, xs: list(xs)), - dict: NodeType(lambda d: map(tuple, unzip2(sorted(d.items()))), - lambda keys, vals: dict(zip(keys, vals))), -} +def register_pytree_node(ty: Type, to_iter: Callable, from_iter: Callable + ) -> None: + node_types[ty] = NodeType(str(ty), to_iter, from_iter) + +node_types: Dict[Type, NodeType] = {} +register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs)) +register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs)) +register_pytree_node(dict, + lambda d: map(tuple, unzip2(sorted(d.items()))), + lambda keys, vals: dict(zip(keys, vals))) class PyTreeDef(NamedTuple): node_type: NodeType @@ -769,14 +719,12 @@ With this pytree-handling `jvp` impelmentation, we can now handle arbitrary input and output containers. That'll come in handy with future transformations too! -```{code-cell} ipython3 +```{code-cell} def f(x): y = sin(x) * 2. z = - y + x return {'hi': z, 'there': [x, y]} -``` -```{code-cell} ipython3 x, xdot = 3., 1. y, ydot = jvp(f, (x,), (xdot,)) print(y) @@ -789,14 +737,12 @@ First, a couple helper functions, one for producing mapped abstract values from unmapped ones (by removing an axis), and one for moving batch dimensions around: -```{code-cell} ipython3 +```{code-cell} def mapped_aval(batch_dim, aval): shape = list(aval.shape) del shape[batch_dim] return ShapedArray(tuple(shape), aval.dtype) -``` -```{code-cell} ipython3 def move_batch_axis(axis_size, src, dst, x): if src is not_mapped: target_shape = list(np.shape(x)) @@ -806,9 +752,7 @@ def move_batch_axis(axis_size, src, dst, x): return x else: return moveaxis(x, src, dst) -``` -```{code-cell} ipython3 def moveaxis(x, src: int, dst: int): perm = [i for i in range(np.ndim(x)) if i != src] perm.insert(dst, src) @@ -818,20 +762,14 @@ def moveaxis(x, src: int, dst: int): The `Tracer` for vectorized batching carries a batched value and an optional integer indicating which axis (if any) is the batch axis. -```{code-cell} ipython3 +```{code-cell} from typing import Union -``` -```{code-cell} ipython3 class NotMapped: pass not_mapped = NotMapped() -``` -```{code-cell} ipython3 BatchAxis = Union[NotMapped, int] -``` -```{code-cell} ipython3 class BatchTracer(Tracer): def __init__(self, trace, val, batch_dim: BatchAxis): self._trace = trace @@ -850,9 +788,7 @@ class BatchTracer(Tracer): return full_lower(self.val) else: return self -``` -```{code-cell} ipython3 class BatchTrace(Trace): pure = lift = lambda self, val: BatchTracer(self, val, not_mapped) @@ -865,9 +801,7 @@ class BatchTrace(Trace): @property def axis_size(self): return self.main.global_data -``` -```{code-cell} ipython3 vmap_rules = {} ``` @@ -883,11 +817,9 @@ size. 
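+
+As a quick sanity check of the helpers above (an illustration added here, not
+part of the original), note how `move_batch_axis` treats unmapped versus
+mapped values: an unmapped value is broadcast along a fresh batch axis of the
+given size, while a mapped value just has its existing batch axis moved into
+position.
+
+```{code-cell}
+# Added example: batch size 7, with the batch axis moved to (or created at)
+# position 0.
+print(move_batch_axis(7, not_mapped, 0, np.zeros((3, 4))).shape)  # (7, 3, 4)
+print(move_batch_axis(7, 1, 0, np.zeros((3, 7, 4))).shape)        # (7, 3, 4)
+```
+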
Next we can define batching interpreter rules for each primitive: -```{code-cell} ipython3 +```{code-cell} from functools import partial -``` -```{code-cell} ipython3 def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in): (x, y), (x_bdim, y_bdim) = vals_in, dims_in if x_bdim != y_bdim: @@ -898,18 +830,14 @@ def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in): return [op(x, y)], [x_bdim] vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add) vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul) -``` -```{code-cell} ipython3 def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in): (x,), (x_bdim,) = vals_in, dims_in return [op(x)], [x_bdim] vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin) vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos) vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg) -``` -```{code-cell} ipython3 def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis): (x,), (x_bdim,) = vals_in, dims_in new_axis = axis + (x_bdim <= axis) @@ -918,13 +846,9 @@ def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis): vmap_rules[reduce_sum_p] = reduce_sum_batching_rule ``` -- - -+++ - Finally, we add a transformation API to kick off the trace: -```{code-cell} ipython3 +```{code-cell} def vmap_flat(f, in_axes, *args): axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) if ax is not not_mapped} @@ -938,9 +862,7 @@ def vmap_flat(f, in_axes, *args): outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out) for val_out, bdim in zip(vals_out, bdims_out)] return outs_transposed -``` -```{code-cell} ipython3 def vmap(f, in_axes): def batched_f(*args): args_flat, in_tree = tree_flatten(args) @@ -952,7 +874,7 @@ def vmap(f, in_axes): return batched_f ``` -```{code-cell} ipython3 +```{code-cell} def add_one_to_a_scalar(scalar): assert np.ndim(scalar) == 0 return 1 + scalar @@ -964,7 +886,7 @@ print(vector_in) print(vector_out) ``` -```{code-cell} ipython3 +```{code-cell} def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,))[1] vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2) @@ -976,19 +898,7 @@ def f(x): jacfwd(f, np.arange(3.)) ``` -That's it for `jvp` and `vmap`! Before moving on, let's highlight a few -simplifications in what we've seen so far compared to the full JAX -implementation: -1. **Fewer, simpler primitives.** More primitives means more interpretation -rules, and for more complex primitives (like for convolution or advanced -indexing) each rule is harder to write. But the overarching design is no -different. -2. **No pytrees.** Transformations expect arrays in, and either a single array - out or a flat list of arrays out. -3. **Missing optimization: no symbolic zeros in autodiff.** -4. **No special call primitives yet.** The core machinery needs to be - generalized to handle the most flexible kind of higher-order primitive, - used by `jax.custom_jvp` and `jax.custom_vjp`. +That's it for `jvp` and `vmap`! +++ @@ -1031,7 +941,7 @@ jaxpr ::= binder ::= : var ::= a | b | c | ... atom ::= | -literal ::= | +literal ::= | | | eqn ::= , ... = [ ] , ... ``` @@ -1049,7 +959,7 @@ How do we represent these as Python data structures? 
We reuse ShapedArrays to represent types, and we can represent the term syntax with a few Python structs: -```{code-cell} ipython3 +```{code-cell} from typing import Set class Var: @@ -1088,14 +998,10 @@ Type-checking a jaxpr involves checking that there are no unbound variables, that variables are only bound once, and that for each equation the type of the primitive application matches the type of the output binders. -```{code-cell} ipython3 -class JaxprType: - in_types: List[ShapedArray] - out_type: List[ShapedArray] - - def __init__(self, in_types, out_types): - self.in_types = in_types - self.out_types = out_types +```{code-cell} +class JaxprType(NamedTuple): + in_types: List[ShapedArray] + out_types: List[ShapedArray] def __repr__(self): in_types = ', '.join(aval.str_short() for aval in self.in_types) @@ -1113,7 +1019,7 @@ def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType: in_types = [typecheck_atom(env, x) for x in eqn.inputs] out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params) for out_binder, out_type in zip(eqn.out_binders, out_types): - if not types_equal(out_type, out_binder.aval): raise TypeError + if not out_type == out_binder.aval: raise TypeError for out_binder in eqn.out_binders: if out_binder in env: raise TypeError env.add(out_binder) @@ -1130,15 +1036,12 @@ def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray: return raise_to_shaped(get_aval(x.val)) else: assert False - -def types_equal(a: ShapedArray, b: ShapedArray) -> bool: - return a.shape == b.shape and a.dtype == b.dtype ``` We can apply the function represented by a jaxpr to arguments with a simple interpreter. -```{code-cell} ipython3 +```{code-cell} def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: env: Dict[Var, Any] = {} @@ -1146,6 +1049,7 @@ def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: return env[x] if type(x) is Var else x.val def write(v: Var, val: Any) -> None: + assert v not in env # single-assignment env[v] = val map(write, jaxpr.in_binders, args) @@ -1154,9 +1058,7 @@ def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: outs = bind(eqn.primitive, *in_vals, **eqn.params) map(write, eqn.out_binders, outs) return map(read, jaxpr.outs) -``` -```{code-cell} ipython3 def jaxpr_as_fun(jaxpr: Jaxpr): return lambda *args: eval_jaxpr(jaxpr, args) ``` @@ -1173,7 +1075,7 @@ a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one used by `jit`, which is also used by control flow primitives like `lax.cond`, `lax.while_loop`, and `lax.scan`. -```{code-cell} ipython3 +```{code-cell} # NB: the analogous class in JAX is called 'DynamicJaxprTracer' class JaxprTracer(Tracer): __slots__ = ['aval'] @@ -1219,7 +1121,7 @@ abstract_eval_rules = {} Notice that we keep as interpreter-global data a builder object, which keeps track of variables, constants, and eqns as we build up the jaxpr. -```{code-cell} ipython3 +```{code-cell} class JaxprBuilder: eqns: List[JaxprEqn] tracer_to_var: Dict[int, Var] @@ -1280,7 +1182,7 @@ produce ConcreteArray outputs as well). We'll reuse these abstract evaluation rules for the other jaxpr-producing trace machinery, where the potential extra generality is useful. 
-```{code-cell} ipython3 +```{code-cell} def broadcast_shapes(*shapes): assert len(shapes) > 1 for sizes in zip(*shapes): @@ -1317,12 +1219,10 @@ abstract_eval_rules[broadcast_p] = broadcast_abstract_eval To check our implementation of jaxprs, we can add a `make_jaxpr` transformation and a pretty-printer: -```{code-cell} ipython3 +```{code-cell} from functools import lru_cache -``` -```{code-cell} ipython3 -@lru_cache() +@lru_cache() # ShapedArrays are hashable def make_jaxpr_v1(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1337,7 +1237,7 @@ def make_jaxpr_v1(f, *avals_in): return jaxpr, consts, out_tree() ``` -```{code-cell} ipython3 +```{code-cell} :tags: [hide-input] from collections import defaultdict @@ -1401,11 +1301,13 @@ def pp_params(params: Dict[str, Any]) -> PPrint: return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ') else: return pp(' ') + +Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self)) ``` -```{code-cell} ipython3 +```{code-cell} jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.))) -print(pp_jaxpr(jaxpr)) +print(jaxpr) print(typecheck_jaxpr(jaxpr)) ``` @@ -1413,9 +1315,9 @@ But there's a limitation here: because of how `find_top_trace` operates by data dependence, `make_jaxpr_v1` can't stage out all the primitive operations performed by the Python callable it's given. For example: -```{code-cell} ipython3 +```{code-cell} jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.)) -print(pp_jaxpr(jaxpr)) +print(jaxpr) ``` This is precisely the issue that @@ -1425,7 +1327,7 @@ applied, regardless of whether any inputs to `bind` are boxed in corresponding `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` global defined in Part 1: -```{code-cell} ipython3 +```{code-cell} @contextmanager def new_dynamic(main: MainTrace): global dynamic_trace @@ -1434,10 +1336,8 @@ def new_dynamic(main: MainTrace): yield finally: dynamic_trace = prev_dynamic_trace -``` -```{code-cell} ipython3 -@lru_cache() # ShapedArrays are hashable +@lru_cache() def make_jaxpr(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1451,11 +1351,9 @@ def make_jaxpr(f, *avals_in): tracers_out = [full_raise(trace, out) for out in outs] jaxpr, consts = builder.build(tracers_in, tracers_out) return jaxpr, consts, out_tree() -``` -```{code-cell} ipython3 jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.)) -print(pp_jaxpr(jaxpr)) +print(jaxpr) ``` Using `dynamic_trace` this way is conceptually the same as stashing the @@ -1471,9 +1369,7 @@ system state simpler. +++ That's it for jaxprs! With jaxprs in hand, we can implement the remaining -major JAX features. But before moving on, let's highlight some -simplifications we've made: -1. **Single-output primitives and jaxprs.** +major JAX features. +++ @@ -1486,28 +1382,28 @@ by a function. +++ -### "Final style" and "initial style" +### On-the-fly ("final style") and staged ("initial style") processing There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs: -1. **`bind` takes a Python callable as an argument.** We defer forming a jaxpr - until as late as possible, namely until we're running the final interpreter - at the bottom of the interpreter stack. That way we can swap a `JaxprTrace` - in at the bottom of the interpreter stack and thus stage out rather than - execute all primitive operations. 
With this approach, transformations in - the stack get applied as we execute the Python callable as usual. This - approach can be very tricky to implement, but it's as general as possible - because it allows higher-order primitives not to raise the abstraction - level of their arguments and thus allows data-dependent Python control - flow. We refer to this approach as using a "final-style higher-order - primitive" employing the discharge-at-tracing-time "final-style - transformations" we've used so far. -2. **`bind` takes a jaxpr as an argument.** Before we call `bind`, in the - primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and - be done with the Python callable entirely. In this case, `make_jaxpr` puts - its `JaxprTrace` at the top of the interpreter stack, and no - transformations lower in the stack, which might enter via closed-over - Tracers, are applied to the Python callable as we trace it. +1. **On-the-fly processing, where `bind` takes a Python callable as an + argument.** We defer forming a jaxpr until as late as possible, namely + until we're running the final interpreter at the bottom of the interpreter + stack. That way we can swap a `JaxprTrace` in at the bottom of the + interpreter stack and thus stage out rather than execute all primitive + operations. With this approach, transformations in the stack get applied as + we execute the Python callable as usual. This approach can be very tricky + to implement, but it's as general as possible because it allows + higher-order primitives not to raise the abstraction level of their + arguments and thus allows data-dependent Python control flow. We refer to + this approach as using a "final-style higher-order primitive" employing the + discharge-at-tracing-time "final-style transformations" we've used so far. +2. **Staged processing, where `bind` takes a jaxpr as an argument.** Before we + call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form + a jaxpr up-front and be done with the Python callable entirely. In this + case, `make_jaxpr` puts its `JaxprTrace` at the top of the interpreter + stack, and no transformations lower in the stack, which might enter via + closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, @@ -1537,7 +1433,7 @@ But it's just imprecise yet sticky jargon. With the initial-style approach, here's the user-facing `jit` wrapper: -```{code-cell} ipython3 +```{code-cell} def jit(f): def f_jitted(*args): avals_in = [raise_to_shaped(get_aval(x)) for x in args] @@ -1560,7 +1456,7 @@ signature. First, some utilities. 
-```{code-cell} ipython3 +```{code-cell} class IDHashable: val: Any @@ -1576,7 +1472,7 @@ class IDHashable: Next, we'll define the evaluation rule for `xla_call`: -```{code-cell} ipython3 +```{code-cell} from jax.lib import xla_bridge as xb from jax.lib import xla_client as xc xe = xc._xla @@ -1619,7 +1515,7 @@ The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO program using `jaxpr_subcomp`, then returns a callable which executes the compiled program: -```{code-cell} ipython3 +```{code-cell} def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp] ) -> xe.XlaOp: env: Dict[Var, xe.XlaOp] = {} @@ -1644,11 +1540,9 @@ def execute_compiled(compiled, out_avals, *args): out_bufs = compiled.execute(input_bufs) return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)] -input_handlers = { - int: xb.get_backend(None).buffer_from_pyval, - float: xb.get_backend(None).buffer_from_pyval, - np.ndarray: xb.get_backend(None).buffer_from_pyval, -} +default_input_handler = xb.get_backend(None).buffer_from_pyval +input_handlers = {ty: default_input_handler for ty in + [int, float, np.ndarray, np.float64, np.float32]} def handle_result(aval: ShapedArray, buf): del aval # Unused for now. @@ -1662,22 +1556,18 @@ a common pattern: the way we process jaxprs is usually with an interpreter. And as with any interpreter, we need an interpretation rule for each primitive: -```{code-cell} ipython3 +```{code-cell} def direct_translation(op, c, in_avals, in_vals): del c, in_avals return [op(*in_vals)] -``` -```{code-cell} ipython3 xla_translations[add_p] = partial(direct_translation, xops.Add) xla_translations[mul_p] = partial(direct_translation, xops.Mul) xla_translations[neg_p] = partial(direct_translation, xops.Neg) xla_translations[sin_p] = partial(direct_translation, xops.Sin) xla_translations[cos_p] = partial(direct_translation, xops.Cos) xla_translations[greater_p] = partial(direct_translation, xops.Gt) -``` -```{code-cell} ipython3 def reduce_sum_translation(c, in_avals, in_vals, *, axis): (x_aval,), (x,) = in_avals, in_vals zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype)) @@ -1686,37 +1576,49 @@ def reduce_sum_translation(c, in_avals, in_vals, *, axis): xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape)) return [xops.Reduce(c, [x], [zero], subc.build(), [axis])] xla_translations[reduce_sum_p] = reduce_sum_translation -``` -```{code-cell} ipython3 def broadcast_translation(c, in_avals, in_vals, *, shape, axes): x, = in_vals dims_complement = [i for i in range(len(shape)) if i not in axes] return [xops.BroadcastInDim(x, shape, dims_complement)] xla_translations[broadcast_p] = broadcast_translation + +def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts): + del num_consts # Only used at top-level. + # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead. + subc = xb.make_computation_builder('inner xla_call') + xla_params = _xla_params(subc, in_avals) + outs = jaxpr_subcomp(subc, jaxpr, xla_params) + subc = subc.build(xops.Tuple(subc, outs)) + return destructure_tuple(c, xops.Call(c, subc, in_vals)) +xla_translations[xla_call_p] = xla_call_translation + +def destructure_tuple(c, tup): + num_elements = len(c.get_shape(tup).tuple_shapes()) + return [xops.GetTupleElement(tup, i) for i in range(num_elements)] ``` With that, we can now use `jit` to stage out, compile, and execute programs with XLA! 
-```{code-cell} ipython3 +```{code-cell} @jit def f(x, y): print('tracing!') return sin(x) * cos(y) ``` -```{code-cell} ipython3 +```{code-cell} z = f(3., 4.) # 'tracing!' prints the first time print(z) ``` -```{code-cell} ipython3 +```{code-cell} z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit! print(z) ``` -```{code-cell} ipython3 +```{code-cell} @jit def f(x): return reduce_sum(x, axis=0) @@ -1724,7 +1626,7 @@ def f(x): print(f(np.array([1., 2., 3.]))) ``` -```{code-cell} ipython3 +```{code-cell} def f(x): y = sin(x) * 2. z = - y + x @@ -1752,7 +1654,7 @@ its evaluation rule. That is, we can't yet do `vmap`-of-`jit` or `jvp`-of-`jit` or even `jit`-of`-jit`. Instead `jit` has to be at the "top level." Let's fix that! -```{code-cell} ipython3 +```{code-cell} def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): del num_consts # Unused. new_jaxpr, new_consts = jvp_jaxpr(jaxpr) @@ -1762,9 +1664,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): primals_out, tangents_out = outs[:n], outs[n:] return primals_out, tangents_out jvp_rules[xla_call_p] = xla_call_jvp_rule -``` -```{code-cell} ipython3 @lru_cache() def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]: def jvp_traceable(*primals_and_tangents): @@ -1777,7 +1677,7 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]: return new_jaxpr, new_consts ``` -```{code-cell} ipython3 +```{code-cell} def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): del num_consts # Unused. new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in)) @@ -1787,7 +1687,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): vmap_rules[xla_call_p] = xla_call_vmap_rule @lru_cache() -def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis] +def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis, ...] ) -> Tuple[Jaxpr, List[Any]]: vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) in_avals = [unmapped_aval(axis_size, d, v.aval) @@ -1805,7 +1705,17 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray return ShapedArray(tuple(shape), aval.dtype) ``` -```{code-cell} ipython3 +```{code-cell} +def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts): + del num_consts # Unused. + jaxpr_type = typecheck_jaxpr(jaxpr) + if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)): + raise TypeError + return jaxpr_type.out_types +abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule +``` + +```{code-cell} @jit def f(x): print('tracing!') @@ -1817,9 +1727,13 @@ x, xdot = 3., 1. y, ydot = jvp(f, (x,), (xdot,)) print(y) print(ydot) +``` +```{code-cell} y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed +``` +```{code-cell} ys = vmap(f, (0,))(np.arange(3.)) print(ys) ``` @@ -1831,7 +1745,7 @@ transfer them back for the next operation. We can do that by introducing a `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type `numpy.ndarray`s: -```{code-cell} ipython3 +```{code-cell} def handle_result(aval: ShapedArray, buf): # noqa: F811 return DeviceArray(aval, buf) @@ -1862,7 +1776,7 @@ input_handlers[DeviceArray] = lambda x: x.buf jax_types.add(DeviceArray) ``` -```{code-cell} ipython3 +```{code-cell} @jit def f(x): y = sin(x) * 2. @@ -1881,10 +1795,17 @@ The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve jaxprs as well. That's because both involve staging out, or delaying, computation. 
++++ + +### `linearize` + In the case of `linearize`, we want to stage out the linear part of a `jvp` computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`, -then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where -``` +then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to +mean "the tangent type of `a`" and using the "lollipop" `-o` rather than the +arrow `->` to indicate a _linear_ function. We define the semantics of +`linearize` in terms of `jvp` too: +```python y, f_lin = linearize(f, x) y_dot = f_lin(x_dot) ``` @@ -1892,32 +1813,46 @@ gives the same result for `(y, y_dot)` as ``` y, y_dot = jvp(f, (x,), (x_dot,)) ``` -and where the application of `f_lin` does not redo any of the linearization -work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr. +where the application of `f_lin` does not redo any of the linearization work. +We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr. To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation: we evaluate all the primal values as we trace, but stage the tangent -computations into a jaxpr. +computations into a jaxpr. This is our second way to build jaxprs. But where +`make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim +to stage out every primitive bind, this second approach stages out only those +primitive binds with a data dependence on tagent inputs. -```{code-cell} ipython3 -def split_half(lst): - n, ragged = divmod(len(lst), 2) - assert not ragged +First, some utilities: + +```{code-cell} +def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]: return lst[:n], lst[n:] + +def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]: + assert not len(lst) % 2 + return split_list(lst, len(lst) // 2) + +def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]: + lists = lst1, lst2 = [], [] + for b, x in zip(bs, l): + lists[b].append(x) + return lst1, lst2 ``` -```{code-cell} ipython3 +Next, we'll write `linearize` by combining `jvp` together with a general +partial evaluation transformation, to be added next: + +```{code-cell} def linearize_flat(f, *primals_in): pvals_in = ([PartialVal.known(x) for x in primals_in] + [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in]) - def f_jvp(*primals_tangents_in): primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in)) return [*primals_out, *tangents_out] - jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) primal_pvals, _ = split_half(pvals_out) - assert all(pval.is_known for pval in primal_pvals) + assert all(pval.is_known for pval in primal_pvals) primals_out = [pval.const for pval in primal_pvals] f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents]) return primals_out, f_lin @@ -1937,10 +1872,88 @@ def linearize(f, *primals_in): return primals_out, f_lin def vspace(aval: ShapedArray) -> ShapedArray: - return raise_to_shaped(aval) + return raise_to_shaped(aval) # TODO handle integers? ``` -```{code-cell} ipython3 +Now we turn to the general partial evaluation transformation. The goal is to +accept a Python callable and a list of inputs, some known and some unknown, +and to produce (1) all the outputs which can be computed from the known +inputs, together with (2) a jaxpr representing the part of the Python +callable's computation which can only be performed after the remaining inputs +are known. 
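+
+Since the list helpers above do the bookkeeping of separating known from
+unknown values throughout what follows, here is a quick illustration of their
+conventions (an added example, not part of the original): `partition_list`
+puts the `False`-flagged elements in the first list and the `True`-flagged
+elements in the second, which is why known values consistently come first.
+
+```{code-cell}
+# Added example of the helper conventions used below.
+print(partition_list([True, False, True], ['a', 'b', 'c']))  # (['b'], ['a', 'c'])
+print(split_half([1, 2, 3, 4]))                              # ([1, 2], [3, 4])
+```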
+
+This transformation can't be summarized purely in a type signature because its
+behavior relies on the data dependencies inside the given Python callable and
+not just its type. Nevertheless a heuristic type signature is useful. If we
+assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
+`a1` and `a2` represent the known and unknown inputs, respectively, and where
+`b1` only has a data dependence on `a1` while `b2` has some data dependence on
+`a2`, then we might write
+
+```
+partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> (b1, res, (res, a2) -> b2)
+```
+
+In words, given values for the inputs of type `a1`, `partial_eval` produces
+the outputs of type `b1` along with "residual" values of type `res`
+representing the intermediates required to complete the computation in the
+second stage. It also produces a function of type `(res, a2) -> b2` which
+accepts the residual values as well as the remaining inputs and produces the
+remaining outputs.
+
+We like to think of partial evaluation as "unzipping" one computation into
+two. For example, consider this jaxpr:
+```
+{ lambda a:float64[] .
+  let b:float64[] = sin a
+      c:float64[] = neg b
+  in ( c ) }
+```
+A jaxpr for the JVP would look like:
+```
+{ lambda a:float64[] b:float64[] .
+  let c:float64[] = sin a
+      d:float64[] = cos a
+      e:float64[] = mul d b
+      f:float64[] = neg c
+      g:float64[] = neg e
+  in ( f, g ) }
+```
+If we imagine applying partial evaluation to this jaxpr with the first input
+known and the second unknown, we end up 'unzipping' the JVP jaxpr into primal
+and tangent jaxprs:
+```
+{ lambda a:float64[] .
+  let c:float64[] = sin a
+      d:float64[] = cos a
+      f:float64[] = neg c
+  in ( f, d ) }
+```
+```
+{ lambda d:float64[] b:float64[] .
+  let e:float64[] = mul d b
+      g:float64[] = neg e
+  in ( g ) }
+```
+This second jaxpr represents the linear computation that we want from
+`linearize`.
+
+However, unlike in this jaxpr example, we want the computation on known values
+to occur while evaluating the input Python callable. That is, rather than
+forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all
+operations out of Python first before sorting out what can be evaluated now
+and what must be delayed, we want only to form a jaxpr for those operations
+that _must_ be delayed due to a dependence on unknown inputs. In the context
+of automatic differentiation, this is the feature that ultimately enables us
+to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python
+control flow works because partial evaluation keeps the primal computation in
+Python. As a consequence, our `Trace` and `Tracer` subclasses must sort out on
+the fly what can be evaluated and what must be staged out into a jaxpr.
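+
+As a concrete preview (an added sketch, not in the original; it only runs once
+`grad` is defined at the end of this part, and it uses `x * x` since these
+tracers don't overload `**`), the point is that code like the following works:
+the `x > 0` comparison is evaluated in Python on the known primal value, and
+only the tangent computation of the branch actually taken is staged out.
+
+```python
+f = lambda x: x * x if x > 0 else 0.
+print(grad(f)(3.))  # 6.0 -- only the x > 0 branch is linearized and transposed
+```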
+
+First, we start with a `PartialVal` class, which represents a value that can
+be either known or unknown:
+
+```{code-cell}
+class PartialVal(NamedTuple):
   aval: ShapedArray
   const: Optional[Any]
@@ -1953,9 +1966,15 @@ class PartialVal(NamedTuple):
   def unknown(cls, aval: ShapedArray):
     return PartialVal(aval, None)
 
-  is_known = property(lambda self: self.const is not None)
-  is_unknown = property(lambda self: self.const is None)
+  is_known   = property(lambda self: self.const is not None)
+  is_unknown = property(lambda self: self.const is None)
+```
 
+Partial evaluation will take a list of `PartialVal`s representing inputs, and
+return a list of `PartialVal` outputs along with a jaxpr representing the
+delayed computation:
+
+```{code-cell}
 def partial_eval_flat(f, pvals_in: List[PartialVal]):
   with new_main(PartialEvalTrace) as main:
     trace = PartialEvalTrace(main)
@@ -1967,10 +1986,19 @@ def partial_eval_flat(f, pvals_in: List[PartialVal]):
   return jaxpr, pvals_out, consts
 ```
 
-```{code-cell} ipython3
+Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This
+interpreter will build a jaxpr on the fly while tracking data dependencies. To
+do so, it builds a bipartite directed acyclic graph (DAG) between
+`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
+nodes, representing formulas for how to compute some values from others. One
+kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s
+primitive application, but we also have recipe types for constants and lambda
+binders:
+
+```{code-cell}
 from weakref import ref, ReferenceType
 
-class LambdaBindingRecipe(NamedTuple): pass
+class LambdaBindingRecipe(NamedTuple):
+  pass
 
 class ConstRecipe(NamedTuple):
   val: Any
@@ -1990,7 +2018,9 @@ class JaxprEqnRecipe:
     self.tracer_refs_out = tracer_refs_out
 
 JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
+```
 
+```{code-cell}
 class PartialEvalTracer(Tracer):
   pval: PartialVal
   recipe: JaxprRecipe
@@ -2008,7 +2038,28 @@ class PartialEvalTracer(Tracer):
     if self.pval.is_known:
       return full_lower(self.pval.const)
     return self
+```
 
+The `PartialEvalTrace` contains the logic for constructing the graph of
+`JaxprRecipe`s and `PartialEvalTracer`s. Each argument corresponds to a
+`LambdaBindingRecipe` leaf node, and each constant is a `ConstRecipe` leaf
+node holding a reference to the constant. All other tracers and recipes come
+from `process_primitive`, which forms tracers with `JaxprEqnRecipe`s.
+
+For most primitives, the `process_primitive` logic is straightforward: if all
+inputs are known then we can bind the primitive on the known values
+(evaluating it in Python) and avoid forming tracers corresponding to the
+output. If instead any input is unknown, we stage out into a `JaxprEqnRecipe`
+representing the primitive application. To build the tracers representing
+unknown outputs, we need avals, which we get from the abstract eval rules.
+(Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s
+reference tracers; we avoid circular garbage by using weakrefs.)
+
+That `process_primitive` logic applies to most primitives, but `xla_call_p`
+requires recursive treatment. So we special-case its rule in a
+`partial_eval_rules` dict.
+ +```{code-cell} class PartialEvalTrace(Trace): def new_arg(self, pval: PartialVal) -> Any: return PartialEvalTracer(self, pval, LambdaBindingRecipe()) @@ -2027,6 +2078,8 @@ class PartialEvalTrace(Trace): def process_primitive(self, primitive, tracers, params): if all(t.pval.is_known for t in tracers): return bind(primitive, *map(full_lower, tracers), **params) + rule = partial_eval_rules.get(primitive) + if rule: return rule(self, tracers, **params) tracers_in = [self.instantiate_const(t) for t in tracers] avals_in = [t.aval for t in tracers_in] avals_out = abstract_eval_rules[primitive](*avals_in, **params) @@ -2036,9 +2089,15 @@ class PartialEvalTrace(Trace): map(ref, tracers_out)) for t in tracers_out: t.recipe = eqn return tracers_out + +partial_eval_rules = {} ``` -```{code-cell} ipython3 +Now that we can build graph representations of jaxprs with `PartialEvalTrace`, +we need a mechanism to convert the graph representation to a standard jaxpr. +The jaxpr corresponds to a topological sort of the graph. + +```{code-cell} def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer], tracers_out: List[PartialEvalTracer]): tracers_in = [t for t in tracers_in if t.pval.is_unknown] @@ -2085,7 +2144,7 @@ def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]: return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else [] ``` -```{code-cell} ipython3 +```{code-cell} def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]): if not out_nodes: return [] out_nodes = remove_duplicates(out_nodes) @@ -2128,8 +2187,374 @@ def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]): seen.add(id(node)) ``` -```{code-cell} ipython3 +Now we can linearize! + +```{code-cell} y, sin_lin = linearize(sin, 3.) print(y, sin(3.)) print(sin_lin(1.), cos(3.)) ``` + +To handle linearize-of-jit, we still need to write a partial evaluation rule +for `xla_call_p`. Other than tracer bookkeeping, the main task is to perform +partial evaluation of a jaxpr, 'unzipping' it into two jaxprs. + +```{code-cell} +def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts): + del num_consts # Unused. 
+ in_unknowns = [not t.pval.is_known for t in tracers] + jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns) + known_tracers, unknown_tracers = partition_list(in_unknowns, tracers) + known_vals = [t.pval.const for t in known_tracers] + outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0) + outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res) + res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res] + outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None) + for v in jaxpr2.outs] + eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers, + dict(jaxpr=jaxpr2, num_consts=0), + [v.aval for v in jaxpr2.outs], map(ref, outs2)) + for t in outs2: t.recipe = eqn + outs1, outs2 = iter(outs1), iter(outs2) + return [next(outs2) if uk else next(outs1) for uk in out_unknowns] +partial_eval_rules[xla_call_p] = xla_call_partial_eval + +def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool] + ) -> Tuple[Jaxpr, Jaxpr, List[bool], int]: + env: Dict[Var, bool] = {} + residuals = set() + + def read(v: Atom) -> bool: + if type(v) is Lit: raise NotImplementedError + return env[v] + + def write(unk: bool, v: Var) -> None: + env[v] = unk + + def new_res(v: Var) -> Var: + return residuals.add(v) or v + + eqns1, eqns2 = [], [] + map(write, in_unknowns, jaxpr.in_binders) + for eqn in jaxpr.eqns: + unks_in = map(read, eqn.inputs) + rule = partial_eval_jaxpr_rules.get(eqn.primitive) + if rule: + eqn1, eqn2, unks_out, res = rule(unks_in, eqn) + eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res) + map(write, unks_out, eqn.out_binders) + elif any(unks_in): + inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)] + eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders)) + map(partial(write, True), eqn.out_binders) + else: + eqns1.append(eqn) + map(partial(write, False), eqn.out_binders) + out_unknowns = map(read, jaxpr.outs) + residuals, num_res = list(residuals), len(residuals) + + ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders) + outs1, outs2 = partition_list(out_unknowns, jaxpr.outs) + + jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals) + jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2) + typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2) + + return jaxpr1, jaxpr2, out_unknowns, num_res + +def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2): + jaxprty = typecheck_jaxpr(jaxpr) # (a1, a2) -> (b1, b2 ) + jaxpr1ty = typecheck_jaxpr(jaxpr1) # a1 -> (b1, res) + jaxpr2ty = typecheck_jaxpr(jaxpr2) # (res, a2) -> b2 + + a1, a2 = partition_list(unks_in, jaxprty.in_types) + b1, b2 = partition_list(unks_out, jaxprty.out_types) + b1_, res = split_list(jaxpr1ty.out_types, len(b1)) + res_, a2_ = split_list(jaxpr2ty.in_types, len(res)) + b2_ = jaxpr2ty.out_types + + if jaxpr1ty.in_types != a1: raise TypeError + if jaxpr2ty.out_types != b2: raise TypeError + if b1 != b1_: raise TypeError + if res != res_: raise TypeError + if a2 != a2_: raise TypeError + if b2 != b2_: raise TypeError + +partial_eval_jaxpr_rules = {} + +def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn + ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]: + jaxpr = eqn.params['jaxpr'] + jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in) + ins1, ins2 = partition_list(unks_in, eqn.inputs) + outs1, outs2 = partition_list(unks_out, eqn.out_binders) + residuals, _ = split_list(jaxpr2.in_binders, num_res) + eqn1 = JaxprEqn(xla_call_p, ins1, 
dict(jaxpr=jaxpr1, num_consts=0), + outs1 + residuals) + eqn2 = JaxprEqn(xla_call_p, residuals + ins2, + dict(jaxpr=jaxpr2, num_consts=0), outs2) + return eqn1, eqn2, unks_out, residuals +partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn +``` + +With that, we can compose `linearize` and `jit` however we like: + +```{code-cell} +@jit +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +y, f_lin = linearize(f, 3.) +y_dot = f_lin(1.) +print(y, y_dot) +``` + +```{code-cell} +@jit +def f(x): + y = sin(x) * 2. + z = g(x, y) + return z + +@jit +def g(x, y): + return cos(x) + y + +y, f_lin = linearize(f, 3.) +y_dot = f_lin(1.) +print(y, y_dot) +``` + +### `vjp` and `grad` + +The `vjp` transformation works a lot like linearize. Its type signature is +analogous: + +``` +linearize : (a -> b) -> a -> (b, T a -o T b) +vjp : (a -> b) -> a -> (b, T b -o T a) +``` + +The only difference is that we transpose the linear part of the computation +before returning it, so that it goes from type `T a -o T b` to type `T b -o T +a`. That is, we'll implement `vjp` as, essentially, + +``` +def vjp(f, x): + y, f_lin = linearize(f, x) + f_vjp = lambda y_bar: transpose(f_lin)(y_bar) + return y, f_vjp +``` + +Since we have the linear computation as a jaxpr, not just a Python callable, +we can implement the transpose transformation as a jaxpr interpreter. + +```{code-cell} +def vjp_flat(f, *primals_in): + pvals_in = ([PartialVal.known(x) for x in primals_in] + + [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in]) + primal_pvals_in, tangent_pvals_in = split_half(pvals_in) + def f_jvp(*primals_tangents_in): + primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in)) + return [*primals_out, *tangents_out] + jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) # linearize + primal_pvals, _ = split_half(pvals_out) + assert all(pval.is_known for pval in primal_pvals) + primals_out = [pval.const for pval in primal_pvals] + transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in] + f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts) + return primals_out, f_vjp + +def vjp(f, *primals_in): + primals_in_flat, in_tree = tree_flatten(primals_in) + f, out_tree = flatten_fun(f, in_tree) + primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat) + primals_out = tree_unflatten(out_tree(), primals_out_flat) + + def f_vjp(*cotangents_out): + cotangents_out_flat, _ = tree_flatten(cotangents_out) + cotangents_in_flat = f_vjp_flat(*cotangents_out_flat) + return tree_unflatten(in_tree, cotangents_in_flat) + + return primals_out, f_vjp + +class UndefPrimal(NamedTuple): + aval: ShapedArray + +register_pytree_node(UndefPrimal, + lambda u: (u.aval, ()), + lambda aval, _: UndefPrimal(aval)) +``` + +We use `UndefPrimal` instances to indicate which arguments with respect to +with we want to transpose. These arise because in general, being explicit +about closed-over values, we want to transpose functions of type +`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the +inputs with respect to which the function is linear could be scattered through +the argument list. So we indicate the linear positions using `UndefPrimal`. +We register `UndefPrimal` as a pytree node because the pytree mechanism gives +a handy way to prune these placeholders out of argument lists. 
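+
+For example, flattening an argument list that contains `UndefPrimal`
+placeholders keeps only the defined values as leaves, since an `UndefPrimal`
+node has no children:
+
+```{code-cell}
+# A small sketch; assumes `ShapedArray(shape, dtype)` as constructed earlier.
+aval = ShapedArray((), np.dtype('float64'))
+flat, treedef = tree_flatten([1.0, UndefPrimal(aval), 2.0])
+print(flat)  # only the defined values appear as leaves
+```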
+ +Next, we can write `eval_jaxpr_transposed`, along with transpose rules for +all primitives which can be linear in at least one argument: + +```{code-cell} +# NB: the analogous function in JAX is called 'backward_pass' +def eval_jaxpr_transposed(jaxpr: Jaxpr, args: List[Any], cotangents: List[Any] + ) -> List[Any]: + primal_env: Dict[Var, Any] = {} + ct_env: Dict[Var, Any] = {} + + def read_primal(x: Atom) -> Any: + return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val + + def write_primal(v: Var, val: Any) -> None: + if type(val) is not UndefPrimal: + primal_env[v] = val + + def read_cotangent(v: Var) -> Any: + return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype)) + + def write_cotangent(x: Atom, val: Any): + if type(x) is Var and val is not None: + ct_env[x] = add(ct_env[x], val) if x in ct_env else val + + map(write_primal, jaxpr.in_binders, args) + map(write_cotangent, jaxpr.outs, cotangents) + for eqn in jaxpr.eqns[::-1]: + primals_in = map(read_primal, eqn.inputs) + cts_in = map(read_cotangent, eqn.out_binders) + rule = transpose_rules[eqn.primitive] + cts_out = rule(cts_in, *primals_in, **eqn.params) + map(write_cotangent, eqn.inputs, cts_out) + + return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args) + if type(x) is UndefPrimal] + +transpose_rules = {} +``` + +```{code-cell} +def mul_transpose_rule(cts, x, y): + z_bar, = cts + assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal) + return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)] +transpose_rules[mul_p] = mul_transpose_rule + +def neg_transpose_rule(cts, x): + ybar, = cts + assert type(x) is UndefPrimal + return [neg(ybar)] +transpose_rules[neg_p] = neg_transpose_rule + +def add_transpose_rule(cts, x, y): + z_bar, = cts + return [z_bar, z_bar] +transpose_rules[add_p] = add_transpose_rule + +def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts): + del num_consts # Unused. + undef_primals = [type(x) is UndefPrimal for x in invals] + transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals)) + residuals, _ = partition_list(undef_primals, invals) + outs = bind(xla_call_p, *new_consts, *residuals, *cts, + jaxpr=transposed_jaxpr, num_consts=len(new_consts)) + outs = iter(outs) + return [next(outs) if undef else None for undef in undef_primals] +transpose_rules[xla_call_p] = xla_call_transpose_rule + +@lru_cache() +def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...] + ) -> Tuple[Jaxpr, List[Any]]: + traceable = partial(eval_jaxpr_transposed, jaxpr) + avals_in, avals_out = typecheck_jaxpr(jaxpr) + args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)] + trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out)) + return trans_jaxpr, consts +``` + +Now that we can linearize and transpose, we can finally write `grad`: + +```{code-cell} +def grad(f): + def gradfun(x, *xs): + y, f_vjp = vjp(f, x, *xs) + if np.shape(y) != (): raise TypeError + x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y))) + return x_bar + return gradfun +``` + +```{code-cell} +y, f_vjp = vjp(sin, 3.) +print(f_vjp(1.), cos(3.)) +``` + +```{code-cell} +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +print(grad(f)(3.)) +``` + +```{code-cell} +@jit +def f(x): + y = x * 2. + z = g(y) + return z + +@jit +def g(x): + return cos(x) * 2. 
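+
+# grad through nested jits uses the xla_call partial-eval and transpose rules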
+ +print(grad(f)(3.)) +``` + +Here's something of a compositionality stress test: + +```{code-cell} +# from core_test.py fun_with_nested_calls_2 +def foo(x): + @jit + def bar(y): + def baz(w): + q = jit(lambda x: y)(x) + q = q + jit(lambda: y)() + q = q + jit(lambda y: w + y)(y) + q = jit(lambda w: jit(sin)(x) * y)(1.0) + q + return q + p, t = jvp(baz, (x + 1.0,), (y,)) + return t + (x * p) + return bar(x) + +def assert_allclose(*vals): + for v1, v2 in zip(vals[:-1], vals[1:]): + np.testing.assert_allclose(v1, v2) + +ans1 = f(3.) +ans2 = jit(f)(3.) +ans3, _ = jvp(f, (3.,), (5.,)) +ans4, _ = jvp(jit(f), (3.,), (5.,)) +assert_allclose(ans1, ans2, ans3, ans4) + +deriv1 = grad(f)(3.) +deriv2 = grad(jit(f))(3.) +deriv3 = jit(grad(jit(f)))(3.) +_, deriv4 = jvp(f, (3.,), (1.,)) +_, deriv5 = jvp(jit(f), (3.,), (1.,)) +assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5) + +hess1 = grad(grad(f))(3.) +hess2 = grad(grad(jit(f)))(3.) +hess3 = grad(jit(grad(f)))(3.) +hess4 = jit(grad(grad(f)))(3.) +_, hess5 = jvp(grad(f), (3.,), (1.,)) +_, hess6 = jvp(jit(grad(f)), (3.,), (1.,)) +_, hess7 = jvp(jit(grad(f)), (3.,), (1.,)) +assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7) +``` diff --git a/docs/autodidax.py b/docs/autodidax.py index c504d88bb..71bcfad3a 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -26,12 +26,6 @@ # name: python3 # --- -import pdb, sys, traceback -def info(type, value, tb): - traceback.print_exception(type, value, tb) - pdb.pm() -sys.excepthook = info - # # Autodidax: JAX core from scratch # @@ -124,6 +118,7 @@ def bind1(prim, *args, **params): # needs. We call each element a `MainTrace`, though maybe "Interpreter" would be # more descriptive. +# + from contextlib import contextmanager from typing import Type, List, Optional, Any @@ -145,6 +140,7 @@ def new_main(trace_type: Type['Trace'], global_data=None): yield main finally: trace_stack.pop() +# - # When we're about to apply a transformation, we'll push another interpreter # onto the stack using `new_main`. Then, as we apply primitives in the function, @@ -191,6 +187,7 @@ class Trace: # `Tracer` per transformation, and at least one `AbstractValue` per base type, # like arrays.) +# + import numpy as np from typing import Tuple @@ -223,6 +220,7 @@ class Tracer: def swap(f): return lambda x, y: f(y, x) +# + class ShapedArray: array_abstraction_level = 1 shape: Tuple[int] @@ -261,6 +259,9 @@ class ShapedArray: return (type(self) is type(other) and self.shape == other.shape and self.dtype == other.dtype) + def __repr__(self): + return f"ShapedArray(shape={self.shape}, dtype={self.dtype})" + class ConcreteArray(ShapedArray): array_abstraction_level = 2 val: np.ndarray @@ -288,6 +289,7 @@ def get_aval(x): jax_types = {bool, int, float, np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray} +# - # Notice that we actually have two `AbstractValue`s for arrays, representing # different levels of abstraction. A `ShapedArray` represents the set of all @@ -310,6 +312,7 @@ def bind(prim, *args, **params): # top trace's `Tracer` instances, and the call to `full_lower` is an optional # optimization so that we unbox values out of `Tracer`s as much as possible. 
+# + import operator as op def find_top_trace(xs) -> Trace: @@ -318,6 +321,7 @@ def find_top_trace(xs) -> Trace: if dynamic_trace and dynamic_trace.level > top_main.level: top_main = dynamic_trace return top_main.trace_type(top_main) +# - # In words, ignoring the `dynamic_trace` step until Part 3, `find_top_trace` # returns the highest-level interpreter associated with the `Tracer`s on its @@ -337,6 +341,7 @@ def find_top_trace(xs) -> Trace: # large part because that's so natural for automatic differentiation, and JAX's # roots are in autodiff. But it may be over-fit. +# + def full_lower(val: Any): if isinstance(val, Tracer): return val.full_lower() @@ -356,6 +361,7 @@ def full_raise(trace: Trace, val: Any) -> Tracer: raise Exception(f"Can't lift level {val._trace.main.level} to {level}.") else: # val._trace.level == level raise Exception(f"Different traces at same level: {val._trace}, {trace}.") +# - # The logic in `full_raise` serves to box values into `Tracer`s for a particular # `Trace`, calling different methods on the `Trace` based on context: @@ -371,6 +377,7 @@ def full_raise(trace: Trace, val: Any) -> Tracer: # We'll start with the simplest interpreter: the evaluation interpreter that # will sit at the bottom of the interpreter stack. +# + class EvalTrace(Trace): pure = lift = lambda self, x: x # no boxing in Tracers needed @@ -393,15 +400,18 @@ impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)] def broadcast_impl(x, *, shape, axes): return [np.broadcast_to(np.expand_dims(x, axes), shape)] impl_rules[broadcast_p] = broadcast_impl +# - # With this interpreter, we can evaluate user functions: +# + def f(x): y = sin(x) * 2. z = - y + x return z print(f(3.0)) +# - # Woo! Like going around in a big circle. But the point of this indirection is # that now we can add some real transformations. @@ -410,6 +420,7 @@ print(f(3.0)) # # First, a few helper functions: +# + def zeros_like(val): return np.zeros_like(val) @@ -423,10 +434,12 @@ def unzip2(pairs): map_ = map def map(f, *xs): return list(map_(f, *xs)) +# - # The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The # `Trace` applies JVP rules. +# + class JVPTracer(Tracer): def __init__(self, trace, primal, tangent): self._trace = trace @@ -447,12 +460,14 @@ class JVPTrace(Trace): return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)] jvp_rules = {} +# - # Notice both `lift` and `sublift` package a value into a `JVPTracer` with the # minimal amount of context, which is a zero tangent value. # # Let's add some JVP rules for primitives: +# + def add_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents return [x + y], [x_dot + y_dot] @@ -488,6 +503,7 @@ def greater_jvp(primals, tangents): out_primal = greater(x, y) return [out_primal], [zeros_like(out_primal)] jvp_rules[greater_p] = greater_jvp +# - # Finally, we add a transformation API to kick off the trace: @@ -507,7 +523,6 @@ y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,)) print(sin_deriv_at_3) print(cos(3.0)) - # + def f(x): y = sin(x) * 2. 
@@ -587,6 +602,7 @@ def jvp(f, primals, tangents): # All that remains is to write `tree_flatten`, `tree_unflatten`, and # `flatten_fun`: +# + def flatten_fun(f, in_tree): store = Store() @@ -617,15 +633,20 @@ import itertools as it from typing import Callable, Type, Hashable, Dict, Iterable, Iterator class NodeType(NamedTuple): + name: str to_iterable: Callable from_iterable: Callable -node_types: Dict[Type, NodeType] = { - tuple: NodeType(lambda t: (None, t), lambda _, xs: tuple(xs)), - list: NodeType( lambda l: (None, l), lambda _, xs: list(xs)), - dict: NodeType(lambda d: map(tuple, unzip2(sorted(d.items()))), - lambda keys, vals: dict(zip(keys, vals))), -} +def register_pytree_node(ty: Type, to_iter: Callable, from_iter: Callable + ) -> None: + node_types[ty] = NodeType(str(ty), to_iter, from_iter) + +node_types: Dict[Type, NodeType] = {} +register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs)) +register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs)) +register_pytree_node(dict, + lambda d: map(tuple, unzip2(sorted(d.items()))), + lambda keys, vals: dict(zip(keys, vals))) class PyTreeDef(NamedTuple): node_type: NodeType @@ -660,21 +681,21 @@ def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any: return treedef.node_type.from_iterable(treedef.node_metadata, children) # - - # With this pytree-handling `jvp` impelmentation, we can now handle arbitrary # input and output containers. That'll come in handy with future transformations # too! +# + def f(x): y = sin(x) * 2. z = - y + x return {'hi': z, 'there': [x, y]} - x, xdot = 3., 1. y, ydot = jvp(f, (x,), (xdot,)) print(y) print(ydot) +# - # ### Vectorized batching with `vmap` # @@ -682,6 +703,7 @@ print(ydot) # from unmapped ones (by removing an axis), and one for moving batch dimensions # around: +# + def mapped_aval(batch_dim, aval): shape = list(aval.shape) del shape[batch_dim] @@ -701,10 +723,12 @@ def moveaxis(x, src: int, dst: int): perm = [i for i in range(np.ndim(x)) if i != src] perm.insert(dst, src) return transpose(x, perm) +# - # The `Tracer` for vectorized batching carries a batched value and an optional # integer indicating which axis (if any) is the batch axis. +# + from typing import Union class NotMapped: pass @@ -745,6 +769,7 @@ class BatchTrace(Trace): return self.main.global_data vmap_rules = {} +# - # Here we've implemented the optional `Tracer.full_lower` method, which lets us # peel off a batching tracer if it's not needed because it doesn't represent a @@ -758,6 +783,7 @@ vmap_rules = {} # # Next we can define batching interpreter rules for each primitive: +# + from functools import partial def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in): @@ -784,12 +810,11 @@ def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis): out_bdim = x_bdim - (new_axis < x_bdim) return [reduce_sum(x, new_axis)], [out_bdim] vmap_rules[reduce_sum_p] = reduce_sum_batching_rule - - # - # Finally, we add a transformation API to kick off the trace: +# + def vmap_flat(f, in_axes, *args): axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) if ax is not not_mapped} @@ -825,7 +850,6 @@ vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in) print(vector_in) print(vector_out) - # + def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,))[1] @@ -838,19 +862,7 @@ def f(x): jacfwd(f, np.arange(3.)) # - -# That's it for `jvp` and `vmap`! Before moving on, let's highlight a few -# simplifications in what we've seen so far compared to the full JAX -# implementation: -# 1. 
**Fewer, simpler primitives.** More primitives means more interpretation -# rules, and for more complex primitives (like for convolution or advanced -# indexing) each rule is harder to write. But the overarching design is no -# different. -# 2. **No pytrees.** Transformations expect arrays in, and either a single array -# out or a flat list of arrays out. -# 3. **Missing optimization: no symbolic zeros in autodiff.** -# 4. **No special call primitives yet.** The core machinery needs to be -# generalized to handle the most flexible kind of higher-order primitive, -# used by `jax.custom_jvp` and `jax.custom_vjp`. +# That's it for `jvp` and `vmap`! # ## Part 2: Jaxprs @@ -890,7 +902,7 @@ jacfwd(f, np.arange(3.)) # binder ::= : # var ::= a | b | c | ... # atom ::= | -# literal ::= | +# literal ::= | | | # # eqn ::= , ... = [ ] , ... # ``` @@ -949,13 +961,9 @@ def raise_to_shaped(aval): # + -class JaxprType: - in_types: List[ShapedArray] - out_type: List[ShapedArray] - - def __init__(self, in_types, out_types): - self.in_types = in_types - self.out_types = out_types +class JaxprType(NamedTuple): + in_types: List[ShapedArray] + out_types: List[ShapedArray] def __repr__(self): in_types = ', '.join(aval.str_short() for aval in self.in_types) @@ -973,7 +981,7 @@ def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType: in_types = [typecheck_atom(env, x) for x in eqn.inputs] out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params) for out_binder, out_type in zip(eqn.out_binders, out_types): - if not types_equal(out_type, out_binder.aval): raise TypeError + if not out_type == out_binder.aval: raise TypeError for out_binder in eqn.out_binders: if out_binder in env: raise TypeError env.add(out_binder) @@ -990,14 +998,12 @@ def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray: return raise_to_shaped(get_aval(x.val)) else: assert False - -def types_equal(a: ShapedArray, b: ShapedArray) -> bool: - return a.shape == b.shape and a.dtype == b.dtype # - # We can apply the function represented by a jaxpr to arguments with a simple # interpreter. +# + def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: env: Dict[Var, Any] = {} @@ -1005,6 +1011,7 @@ def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: return env[x] if type(x) is Var else x.val def write(v: Var, val: Any) -> None: + assert v not in env # single-assignment env[v] = val map(write, jaxpr.in_binders, args) @@ -1016,6 +1023,7 @@ def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: def jaxpr_as_fun(jaxpr: Jaxpr): return lambda *args: eval_jaxpr(jaxpr, args) +# - # By using `bind` in the interpreter, this interpreter itself is traceable. @@ -1169,9 +1177,10 @@ abstract_eval_rules[broadcast_p] = broadcast_abstract_eval # To check our implementation of jaxprs, we can add a `make_jaxpr` # transformation and a pretty-printer: +# + from functools import lru_cache -@lru_cache() +@lru_cache() # ShapedArrays are hashable def make_jaxpr_v1(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1247,10 +1256,12 @@ def pp_params(params: Dict[str, Any]) -> PPrint: return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ') else: return pp(' ') + +Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self)) # - jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. 
* x, raise_to_shaped(get_aval(3.))) -print(pp_jaxpr(jaxpr)) +print(jaxpr) print(typecheck_jaxpr(jaxpr)) # But there's a limitation here: because of how `find_top_trace` operates by @@ -1258,7 +1269,7 @@ print(typecheck_jaxpr(jaxpr)) # performed by the Python callable it's given. For example: jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.)) -print(pp_jaxpr(jaxpr)) +print(jaxpr) # This is precisely the issue that # [omnistaging](https://github.com/google/jax/pull/3370) fixed. @@ -1267,6 +1278,7 @@ print(pp_jaxpr(jaxpr)) # `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` # global defined in Part 1: +# + @contextmanager def new_dynamic(main: MainTrace): global dynamic_trace @@ -1276,7 +1288,7 @@ def new_dynamic(main: MainTrace): finally: dynamic_trace = prev_dynamic_trace -@lru_cache() # ShapedArrays are hashable +@lru_cache() def make_jaxpr(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1292,7 +1304,8 @@ def make_jaxpr(f, *avals_in): return jaxpr, consts, out_tree() jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.)) -print(pp_jaxpr(jaxpr)) +print(jaxpr) +# - # Using `dynamic_trace` this way is conceptually the same as stashing the # current interpreter stack and starting a new one with the `JaxprTrace` at the @@ -1305,9 +1318,8 @@ print(pp_jaxpr(jaxpr)) # system state simpler. # That's it for jaxprs! With jaxprs in hand, we can implement the remaining -# major JAX features. But before moving on, let's highlight some -# simplifications we've made: -# 1. **Single-output primitives and jaxprs.** +# major JAX features. + # ## Part 3: `jit`, simplified # @@ -1316,28 +1328,28 @@ print(pp_jaxpr(jaxpr)) # than a transformation. A primitive is _higher-order_ when it's parameterized # by a function. -# ### "Final style" and "initial style" +# ### On-the-fly ("final style") and staged ("initial style") processing # # There are two options for how to handle higher-order primitives. Each requires # a different approach to tracing and engenders different tradeoffs: -# 1. **`bind` takes a Python callable as an argument.** We defer forming a jaxpr -# until as late as possible, namely until we're running the final interpreter -# at the bottom of the interpreter stack. That way we can swap a `JaxprTrace` -# in at the bottom of the interpreter stack and thus stage out rather than -# execute all primitive operations. With this approach, transformations in -# the stack get applied as we execute the Python callable as usual. This -# approach can be very tricky to implement, but it's as general as possible -# because it allows higher-order primitives not to raise the abstraction -# level of their arguments and thus allows data-dependent Python control -# flow. We refer to this approach as using a "final-style higher-order -# primitive" employing the discharge-at-tracing-time "final-style -# transformations" we've used so far. -# 2. **`bind` takes a jaxpr as an argument.** Before we call `bind`, in the -# primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and -# be done with the Python callable entirely. In this case, `make_jaxpr` puts -# its `JaxprTrace` at the top of the interpreter stack, and no -# transformations lower in the stack, which might enter via closed-over -# Tracers, are applied to the Python callable as we trace it. +# 1. 
**On-the-fly processing, where `bind` takes a Python callable as an +# argument.** We defer forming a jaxpr until as late as possible, namely +# until we're running the final interpreter at the bottom of the interpreter +# stack. That way we can swap a `JaxprTrace` in at the bottom of the +# interpreter stack and thus stage out rather than execute all primitive +# operations. With this approach, transformations in the stack get applied as +# we execute the Python callable as usual. This approach can be very tricky +# to implement, but it's as general as possible because it allows +# higher-order primitives not to raise the abstraction level of their +# arguments and thus allows data-dependent Python control flow. We refer to +# this approach as using a "final-style higher-order primitive" employing the +# discharge-at-tracing-time "final-style transformations" we've used so far. +# 2. **Staged processing, where `bind` takes a jaxpr as an argument.** Before we +# call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form +# a jaxpr up-front and be done with the Python callable entirely. In this +# case, `make_jaxpr` puts its `JaxprTrace` at the top of the interpreter +# stack, and no transformations lower in the stack, which might enter via +# closed-over Tracers, are applied to the Python callable as we trace it. # (Transformations applied within the Python callable are applied as usual, # being added to the stack above the JaxprTrace.) Instead, the # transformations lower in the stack are later applied to the call primitive, @@ -1470,11 +1482,9 @@ def execute_compiled(compiled, out_avals, *args): out_bufs = compiled.execute(input_bufs) return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)] -input_handlers = { - int: xb.get_backend(None).buffer_from_pyval, - float: xb.get_backend(None).buffer_from_pyval, - np.ndarray: xb.get_backend(None).buffer_from_pyval, -} +default_input_handler = xb.get_backend(None).buffer_from_pyval +input_handlers = {ty: default_input_handler for ty in + [int, float, np.ndarray, np.float64, np.float32]} def handle_result(aval: ShapedArray, buf): del aval # Unused for now. @@ -1488,6 +1498,7 @@ xla_translations = {} # And as with any interpreter, we need an interpretation rule for each # primitive: +# + def direct_translation(op, c, in_avals, in_vals): del c, in_avals return [op(*in_vals)] @@ -1514,6 +1525,21 @@ def broadcast_translation(c, in_avals, in_vals, *, shape, axes): return [xops.BroadcastInDim(x, shape, dims_complement)] xla_translations[broadcast_p] = broadcast_translation +def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts): + del num_consts # Only used at top-level. + # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead. + subc = xb.make_computation_builder('inner xla_call') + xla_params = _xla_params(subc, in_avals) + outs = jaxpr_subcomp(subc, jaxpr, xla_params) + subc = subc.build(xops.Tuple(subc, outs)) + return destructure_tuple(c, xops.Call(c, subc, in_vals)) +xla_translations[xla_call_p] = xla_call_translation + +def destructure_tuple(c, tup): + num_elements = len(c.get_shape(tup).tuple_shapes()) + return [xops.GetTupleElement(tup, i) for i in range(num_elements)] +# - + # With that, we can now use `jit` to stage out, compile, and execute programs # with XLA! @@ -1561,6 +1587,7 @@ print(jit(deriv(deriv(f)))(3.)) # `jvp`-of-`jit` or even `jit`-of`-jit`. Instead `jit` has to be at the "top # level." Let's fix that! 
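+# Concretely, at this point forward-mode through `jit` fails because
+# `jvp_rules` has no entry for `xla_call_p`; for instance (a sketch, not
+# meant to be run yet):
+#
+# ```python
+# jvp(jit(f), (3.,), (1.,))  # fails: no JVP rule registered for xla_call_p
+# ```
+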
+# + def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): del num_consts # Unused. new_jaxpr, new_consts = jvp_jaxpr(jaxpr) @@ -1592,7 +1619,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): vmap_rules[xla_call_p] = xla_call_vmap_rule @lru_cache() -def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis] +def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis, ...] ) -> Tuple[Jaxpr, List[Any]]: vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) in_avals = [unmapped_aval(axis_size, d, v.aval) @@ -1610,6 +1637,16 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray return ShapedArray(tuple(shape), aval.dtype) +# - + +def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts): + del num_consts # Unused. + jaxpr_type = typecheck_jaxpr(jaxpr) + if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)): + raise TypeError + return jaxpr_type.out_types +abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule + # + @jit def f(x): @@ -1622,13 +1659,12 @@ x, xdot = 3., 1. y, ydot = jvp(f, (x,), (xdot,)) print(y) print(ydot) +# - y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed ys = vmap(f, (0,))(np.arange(3.)) print(ys) -# - - # One piece missing is device memory persistence for arrays. That is, we've # defined `handle_result` to transfer results back to CPU memory as NumPy @@ -1678,8 +1714,6 @@ x, xdot = 3., 1. y, ydot = jvp(f, (x,), (xdot,)) print(y) print(ydot) - - # - # ## Part 4: `linearize` and `vjp` (and `grad`!) @@ -1687,11 +1721,16 @@ print(ydot) # The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve # jaxprs as well. That's because both involve staging out, or delaying, # computation. + +# ### `linearize` # # In the case of `linearize`, we want to stage out the linear part of a `jvp` # computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`, -# then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where -# ``` +# then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to +# mean "the tangent type of `a`" and using the "lollipop" `-o` rather than the +# arrow `->` to indicate a _linear_ function. We define the semantics of +# `linearize` in terms of `jvp` too: +# ```python # y, f_lin = linearize(f, x) # y_dot = f_lin(x_dot) # ``` @@ -1699,30 +1738,46 @@ print(ydot) # ``` # y, y_dot = jvp(f, (x,), (x_dot,)) # ``` -# and where the application of `f_lin` does not redo any of the linearization -# work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr. +# where the application of `f_lin` does not redo any of the linearization work. +# We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr. # # To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation: # we evaluate all the primal values as we trace, but stage the tangent -# computations into a jaxpr. +# computations into a jaxpr. This is our second way to build jaxprs. But where +# `make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim +# to stage out every primitive bind, this second approach stages out only those +# primitive binds with a data dependence on tagent inputs. 
+# +# First, some utilities: -def split_half(lst): - n, ragged = divmod(len(lst), 2) - assert not ragged +# + +def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]: return lst[:n], lst[n:] +def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]: + assert not len(lst) % 2 + return split_list(lst, len(lst) // 2) + +def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]: + lists = lst1, lst2 = [], [] + for b, x in zip(bs, l): + lists[b].append(x) + return lst1, lst2 +# - + +# Next, we'll write `linearize` by combining `jvp` together with a general +# partial evaluation transformation, to be added next: + # + def linearize_flat(f, *primals_in): pvals_in = ([PartialVal.known(x) for x in primals_in] + [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in]) - def f_jvp(*primals_tangents_in): primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in)) return [*primals_out, *tangents_out] - jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) primal_pvals, _ = split_half(pvals_out) - assert all(pval.is_known for pval in primal_pvals) + assert all(pval.is_known for pval in primal_pvals) primals_out = [pval.const for pval in primal_pvals] f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents]) return primals_out, f_lin @@ -1742,9 +1797,87 @@ def linearize(f, *primals_in): return primals_out, f_lin def vspace(aval: ShapedArray) -> ShapedArray: - return raise_to_shaped(aval) + return raise_to_shaped(aval) # TODO handle integers? +# - + +# Now we turn to the general partial evaluation transformation. The goal is to +# accept a Python callable and a list of inputs, some known and some unknown, +# and to produce (1) all the outputs which can be computed from the known +# inputs, together with (2) a jaxpr representing the part of the Python +# callable's computation which can only be performed after the remaining inputs +# are known. +# +# This transformation can't be summarized purely in a type signature because its +# behavior relies on the data dependencies inside the given Python callable and +# not just its type. Nevertheless a heuristic type signature is useful. If we +# assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where +# `a1` and `a2` represent the known and unknown inputs, respectively, and where +# `b1` only has a data depenence on `a1` while `b2` has some data dependnece on +# `a2`, then we might write +# +# ``` +# partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> (b1, res, (res, a2) -> b2) +# ``` +# +# In words, given values for the inputs of type `a1`, `partial_eval` produces +# the outputs of type `b1` along with "residual" values of type `res` +# representing the intermediates required to complete the computation in the +# second stage. It also produces a function of type `(res, a2) -> b2` which +# accepts the residual values as well as the remaining inputs and produces the +# remaining outputs. +# +# We like to think of partial evaluation as "unzipping" one computation into +# two. For example, consider this jaxpr: +# ``` +# { lambda a:float64[] . +# let b:float64[] = sin a +# c:float64[] = neg b +# in ( c ) } +# ``` +# A jaxpr for the JVP would look like: +# ``` +# { lambda a:float64[] b:float64 . 
+# let c:float64[] = sin a +# d:float64[] = cos a +# e:float64[] = mul d b +# f:float64[] = neg c +# g:float64[] = neg e +# in ( f, g ) } +# ``` +# If we imagine applying partial evaluation to this jaxpr with the first input +# known and the second unknown, we end up 'unzipping' the JVP jaxpr into primal +# and tangent jaxprs: +# ``` +# { lambda a:float64[] . +# let c:float64[] = sin a +# d:float64[] = cos a +# f:float64[] = neg c +# in ( f, d ) } +# ``` +# ``` +# { lambda d:float64[] b:float64[] . +# let e:float64[] = mul d b +# g:float64[] = neg e +# in ( g ) } +# ``` +# This second jaxpr is represents the linear computation that we want from +# `linearize`. +# +# However, unlike in this jaxpr example, we want the computation on known values +# to occur while evaluating the input Python callable. That is, rather than +# forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all +# operations out of Python first before sorting out what can be evaluated now +# and what must be delayed, we want only to form a jaxpr for those operations +# that _must_ be delayed due to a dependence on unknown inputs. In the context +# of automatic differentiation, this is the feature ultimately enables us to +# handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control +# flow works because partial evaluation keeps the primal computation in Python. +# As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out +# what can be evaluated and what must be staged out into a jaxpr. +# +# First, we start with a `PartialVal` class, which represents a value that can +# be either known or unknown: -# + class PartialVal(NamedTuple): aval: ShapedArray const: Optional[Any] @@ -1757,8 +1890,12 @@ class PartialVal(NamedTuple): def unknown(cls, aval: ShapedArray): return PartialVal(aval, None) - is_known = property(lambda self: self.const is not None) - is_unknown = property(lambda self: self.const is None) + is_known = property(lambda self: self.const is not None) + is_unknown = property(lambda self: self.const is None) + +# Partial evaluation will take a list of `PartialVal`s representing inputs, and +# return a list of `PartialVal` outputs along with a jaxpr representing the +# dleayed computation: def partial_eval_flat(f, pvals_in: List[PartialVal]): with new_main(PartialEvalTrace) as main: @@ -1770,10 +1907,19 @@ def partial_eval_flat(f, pvals_in: List[PartialVal]): pvals_out = [t.pval for t in tracers_out] return jaxpr, pvals_out, consts +# Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This +# interpreter will build a jaxpr on the fly while tracking data dependencies. To +# do so, it builds a bipartite directed acyclic graph (DAG) between +# `PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe` +# nodes, representing formulas for how compute some values from others. 
One kind +# of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive +# application, but we also have recipe types for constants and lambda binders: + # + from weakref import ref, ReferenceType -class LambdaBindingRecipe(NamedTuple): pass +class LambdaBindingRecipe(NamedTuple): + pass class ConstRecipe(NamedTuple): val: Any @@ -1794,6 +1940,9 @@ class JaxprEqnRecipe: JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe] + +# - + class PartialEvalTracer(Tracer): pval: PartialVal recipe: JaxprRecipe @@ -1812,6 +1961,26 @@ class PartialEvalTracer(Tracer): return full_lower(self.pval.const) return self +# The `PartialEvalTrace` contains the logic for constructing the graph of +# `JaxprRecipe`s and `PartialEvalTracer`s. Each argument corresponds to a +# `LambdaBindingRecipe` leaf node, and each constant is a `ConstRecipe` leaf +# node holding a reference to the constant. All other tracers and recipes come +# from `process_primitive`, which forms tracers with `JaxprEqnRecipe`s. +# +# For most primitives, the `process_primitive` logic is straightforward: if all +# inputs are known then we can bind the primitive on the known values +# (evaluating it in Python) and avoid forming tracers corresponding to the +# output. If instead any input is unknown then we instead stage out into a +# `JaxprEqnRecipe` representing the primitive application. To build the tracers +# representing unknown outputs, we need avals, which get from the abstract eval +# rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s +# reference tracers; we avoid circular garbage by using weakrefs.) +# +# That `process_primitive` logic applies to most primitives, but `xla_call_p` +# requires recursive treatment. So we special-case its rule in a +# `partial_eval_rules` dict. + +# + class PartialEvalTrace(Trace): def new_arg(self, pval: PartialVal) -> Any: return PartialEvalTracer(self, pval, LambdaBindingRecipe()) @@ -1830,6 +1999,8 @@ class PartialEvalTrace(Trace): def process_primitive(self, primitive, tracers, params): if all(t.pval.is_known for t in tracers): return bind(primitive, *map(full_lower, tracers), **params) + rule = partial_eval_rules.get(primitive) + if rule: return rule(self, tracers, **params) tracers_in = [self.instantiate_const(t) for t in tracers] avals_in = [t.aval for t in tracers_in] avals_out = abstract_eval_rules[primitive](*avals_in, **params) @@ -1840,6 +2011,13 @@ class PartialEvalTrace(Trace): for t in tracers_out: t.recipe = eqn return tracers_out +partial_eval_rules = {} +# - + +# Now that we can build graph representations of jaxprs with `PartialEvalTrace`, +# we need a mechanism to convert the graph representation to a standard jaxpr. +# The jaxpr corresponds to a topological sort of the graph. + # + def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer], tracers_out: List[PartialEvalTracer]): @@ -1927,10 +2105,366 @@ def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]): for node in nodes: assert all(id(parent) in seen for parent in parents(node)) seen.add(id(node)) - - # - +# Now we can linearize! + y, sin_lin = linearize(sin, 3.) print(y, sin(3.)) print(sin_lin(1.), cos(3.)) + +# To handle linearize-of-jit, we still need to write a partial evaluation rule +# for `xla_call_p`. Other than tracer bookkeeping, the main task is to perform +# partial evaluation of a jaxpr, 'unzipping' it into two jaxprs. + +# + +def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts): + del num_consts # Unused. 
+ in_unknowns = [not t.pval.is_known for t in tracers] + jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns) + known_tracers, unknown_tracers = partition_list(in_unknowns, tracers) + known_vals = [t.pval.const for t in known_tracers] + outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0) + outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res) + res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res] + outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None) + for v in jaxpr2.outs] + eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers, + dict(jaxpr=jaxpr2, num_consts=0), + [v.aval for v in jaxpr2.outs], map(ref, outs2)) + for t in outs2: t.recipe = eqn + outs1, outs2 = iter(outs1), iter(outs2) + return [next(outs2) if uk else next(outs1) for uk in out_unknowns] +partial_eval_rules[xla_call_p] = xla_call_partial_eval + +def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool] + ) -> Tuple[Jaxpr, Jaxpr, List[bool], int]: + env: Dict[Var, bool] = {} + residuals = set() + + def read(v: Atom) -> bool: + if type(v) is Lit: raise NotImplementedError + return env[v] + + def write(unk: bool, v: Var) -> None: + env[v] = unk + + def new_res(v: Var) -> Var: + return residuals.add(v) or v + + eqns1, eqns2 = [], [] + map(write, in_unknowns, jaxpr.in_binders) + for eqn in jaxpr.eqns: + unks_in = map(read, eqn.inputs) + rule = partial_eval_jaxpr_rules.get(eqn.primitive) + if rule: + eqn1, eqn2, unks_out, res = rule(unks_in, eqn) + eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res) + map(write, unks_out, eqn.out_binders) + elif any(unks_in): + inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)] + eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders)) + map(partial(write, True), eqn.out_binders) + else: + eqns1.append(eqn) + map(partial(write, False), eqn.out_binders) + out_unknowns = map(read, jaxpr.outs) + residuals, num_res = list(residuals), len(residuals) + + ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders) + outs1, outs2 = partition_list(out_unknowns, jaxpr.outs) + + jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals) + jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2) + typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2) + + return jaxpr1, jaxpr2, out_unknowns, num_res + +def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2): + jaxprty = typecheck_jaxpr(jaxpr) # (a1, a2) -> (b1, b2 ) + jaxpr1ty = typecheck_jaxpr(jaxpr1) # a1 -> (b1, res) + jaxpr2ty = typecheck_jaxpr(jaxpr2) # (res, a2) -> b2 + + a1, a2 = partition_list(unks_in, jaxprty.in_types) + b1, b2 = partition_list(unks_out, jaxprty.out_types) + b1_, res = split_list(jaxpr1ty.out_types, len(b1)) + res_, a2_ = split_list(jaxpr2ty.in_types, len(res)) + b2_ = jaxpr2ty.out_types + + if jaxpr1ty.in_types != a1: raise TypeError + if jaxpr2ty.out_types != b2: raise TypeError + if b1 != b1_: raise TypeError + if res != res_: raise TypeError + if a2 != a2_: raise TypeError + if b2 != b2_: raise TypeError + +partial_eval_jaxpr_rules = {} + +def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn + ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]: + jaxpr = eqn.params['jaxpr'] + jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in) + ins1, ins2 = partition_list(unks_in, eqn.inputs) + outs1, outs2 = partition_list(unks_out, eqn.out_binders) + residuals, _ = split_list(jaxpr2.in_binders, num_res) + eqn1 = JaxprEqn(xla_call_p, ins1, 
dict(jaxpr=jaxpr1, num_consts=0), + outs1 + residuals) + eqn2 = JaxprEqn(xla_call_p, residuals + ins2, + dict(jaxpr=jaxpr2, num_consts=0), outs2) + return eqn1, eqn2, unks_out, residuals +partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn +# - + +# With that, we can compose `linearize` and `jit` however we like: + +# + +@jit +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +y, f_lin = linearize(f, 3.) +y_dot = f_lin(1.) +print(y, y_dot) + +# + +@jit +def f(x): + y = sin(x) * 2. + z = g(x, y) + return z + +@jit +def g(x, y): + return cos(x) + y + +y, f_lin = linearize(f, 3.) +y_dot = f_lin(1.) +print(y, y_dot) +# - + +# ### `vjp` and `grad` +# +# The `vjp` transformation works a lot like linearize. Its type signature is +# analogous: +# +# ``` +# linearize : (a -> b) -> a -> (b, T a -o T b) +# vjp : (a -> b) -> a -> (b, T b -o T a) +# ``` +# +# The only difference is that we transpose the linear part of the computation +# before returning it, so that it goes from type `T a -o T b` to type `T b -o T +# a`. That is, we'll implement `vjp` as, essentially, +# +# ``` +# def vjp(f, x): +# y, f_lin = linearize(f, x) +# f_vjp = lambda y_bar: transpose(f_lin)(y_bar) +# return y, f_vjp +# ``` +# +# Since we have the linear computation as a jaxpr, not just a Python callable, +# we can implement the transpose transformation as a jaxpr interpreter. + +# + +def vjp_flat(f, *primals_in): + pvals_in = ([PartialVal.known(x) for x in primals_in] + + [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in]) + primal_pvals_in, tangent_pvals_in = split_half(pvals_in) + def f_jvp(*primals_tangents_in): + primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in)) + return [*primals_out, *tangents_out] + jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) # linearize + primal_pvals, _ = split_half(pvals_out) + assert all(pval.is_known for pval in primal_pvals) + primals_out = [pval.const for pval in primal_pvals] + transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in] + f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts) + return primals_out, f_vjp + +def vjp(f, *primals_in): + primals_in_flat, in_tree = tree_flatten(primals_in) + f, out_tree = flatten_fun(f, in_tree) + primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat) + primals_out = tree_unflatten(out_tree(), primals_out_flat) + + def f_vjp(*cotangents_out): + cotangents_out_flat, _ = tree_flatten(cotangents_out) + cotangents_in_flat = f_vjp_flat(*cotangents_out_flat) + return tree_unflatten(in_tree, cotangents_in_flat) + + return primals_out, f_vjp + +class UndefPrimal(NamedTuple): + aval: ShapedArray + +register_pytree_node(UndefPrimal, + lambda u: (u.aval, ()), + lambda aval, _: UndefPrimal(aval)) +# - + +# We use `UndefPrimal` instances to indicate which arguments with respect to +# with we want to transpose. These arise because in general, being explicit +# about closed-over values, we want to transpose functions of type +# `a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the +# inputs with respect to which the function is linear could be scattered through +# the argument list. So we indicate the linear positions using `UndefPrimal`. +# We register `UndefPrimal` as a pytree node because the pytree mechanism gives +# a handy way to prune these placeholders out of argument lists. 
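+#
+# For example, flattening an argument list that contains `UndefPrimal`
+# placeholders keeps only the defined values as leaves, since an
+# `UndefPrimal` node has no children:
+
+# A small sketch; assumes `ShapedArray(shape, dtype)` as constructed earlier.
+aval = ShapedArray((), np.dtype('float64'))
+flat, treedef = tree_flatten([1.0, UndefPrimal(aval), 2.0])
+print(flat)  # only the defined values appear as leaves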
+# +# Next, we can write `eval_jaxpr_transposed`, along with transpose rules for +# all primitives which can be linear in at least one argument: + +# + +# NB: the analogous function in JAX is called 'backward_pass' +def eval_jaxpr_transposed(jaxpr: Jaxpr, args: List[Any], cotangents: List[Any] + ) -> List[Any]: + primal_env: Dict[Var, Any] = {} + ct_env: Dict[Var, Any] = {} + + def read_primal(x: Atom) -> Any: + return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val + + def write_primal(v: Var, val: Any) -> None: + if type(val) is not UndefPrimal: + primal_env[v] = val + + def read_cotangent(v: Var) -> Any: + return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype)) + + def write_cotangent(x: Atom, val: Any): + if type(x) is Var and val is not None: + ct_env[x] = add(ct_env[x], val) if x in ct_env else val + + map(write_primal, jaxpr.in_binders, args) + map(write_cotangent, jaxpr.outs, cotangents) + for eqn in jaxpr.eqns[::-1]: + primals_in = map(read_primal, eqn.inputs) + cts_in = map(read_cotangent, eqn.out_binders) + rule = transpose_rules[eqn.primitive] + cts_out = rule(cts_in, *primals_in, **eqn.params) + map(write_cotangent, eqn.inputs, cts_out) + + return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args) + if type(x) is UndefPrimal] + +transpose_rules = {} + +# + +def mul_transpose_rule(cts, x, y): + z_bar, = cts + assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal) + return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)] +transpose_rules[mul_p] = mul_transpose_rule + +def neg_transpose_rule(cts, x): + ybar, = cts + assert type(x) is UndefPrimal + return [neg(ybar)] +transpose_rules[neg_p] = neg_transpose_rule + +def add_transpose_rule(cts, x, y): + z_bar, = cts + return [z_bar, z_bar] +transpose_rules[add_p] = add_transpose_rule + +def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts): + del num_consts # Unused. + undef_primals = [type(x) is UndefPrimal for x in invals] + transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals)) + residuals, _ = partition_list(undef_primals, invals) + outs = bind(xla_call_p, *new_consts, *residuals, *cts, + jaxpr=transposed_jaxpr, num_consts=len(new_consts)) + outs = iter(outs) + return [next(outs) if undef else None for undef in undef_primals] +transpose_rules[xla_call_p] = xla_call_transpose_rule + +@lru_cache() +def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...] + ) -> Tuple[Jaxpr, List[Any]]: + traceable = partial(eval_jaxpr_transposed, jaxpr) + avals_in, avals_out = typecheck_jaxpr(jaxpr) + args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)] + trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out)) + return trans_jaxpr, consts +# - + +# Now that we can linearize and transpose, we can finally write `grad`: + +def grad(f): + def gradfun(x, *xs): + y, f_vjp = vjp(f, x, *xs) + if np.shape(y) != (): raise TypeError + x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y))) + return x_bar + return gradfun + +y, f_vjp = vjp(sin, 3.) +print(f_vjp(1.), cos(3.)) + +# + +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +print(grad(f)(3.)) + +# + +@jit +def f(x): + y = x * 2. + z = g(y) + return z + +@jit +def g(x): + return cos(x) * 2. 
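+
+# grad through nested jits uses the xla_call partial-eval and transpose rules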
+ +print(grad(f)(3.)) +# - + +# Here's something of a compositionality stress test: + +# + +# from core_test.py fun_with_nested_calls_2 +def foo(x): + @jit + def bar(y): + def baz(w): + q = jit(lambda x: y)(x) + q = q + jit(lambda: y)() + q = q + jit(lambda y: w + y)(y) + q = jit(lambda w: jit(sin)(x) * y)(1.0) + q + return q + p, t = jvp(baz, (x + 1.0,), (y,)) + return t + (x * p) + return bar(x) + +def assert_allclose(*vals): + for v1, v2 in zip(vals[:-1], vals[1:]): + np.testing.assert_allclose(v1, v2) + +ans1 = f(3.) +ans2 = jit(f)(3.) +ans3, _ = jvp(f, (3.,), (5.,)) +ans4, _ = jvp(jit(f), (3.,), (5.,)) +assert_allclose(ans1, ans2, ans3, ans4) + +deriv1 = grad(f)(3.) +deriv2 = grad(jit(f))(3.) +deriv3 = jit(grad(jit(f)))(3.) +_, deriv4 = jvp(f, (3.,), (1.,)) +_, deriv5 = jvp(jit(f), (3.,), (1.,)) +assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5) + +hess1 = grad(grad(f))(3.) +hess2 = grad(grad(jit(f)))(3.) +hess3 = grad(jit(grad(f)))(3.) +hess4 = jit(grad(grad(f)))(3.) +_, hess5 = jvp(grad(f), (3.,), (1.,)) +_, hess6 = jvp(jit(grad(f)), (3.,), (1.,)) +_, hess7 = jvp(jit(grad(f)), (3.,), (1.,)) +assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7) diff --git a/setup.cfg b/setup.cfg index a18b95160..683500d2f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,9 @@ ignore = W503, W504 # line breaks around binary operators max-complexity = 18 select = B,C,F,W,T4,B9 -exclude = +exclude = .git, build, - __pycache__ \ No newline at end of file + __pycache__ +per-file-ignores = + docs/autodidax.py:F811