{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "---\n",
    "Copyright 2021 The JAX Authors.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    https://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)"
   ]
  },
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autodidax: JAX core from scratch\n",
    "\n",
    "Ever want to learn how JAX works, but the implementation seemed impenetrable?\n",
    "Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n",
    "JAX's core system. You'll even get clued into our weird jargon!\n",
    "\n",
    "**This is a work-in-progress draft.** There are some important ingredients\n",
    "missing, still to come in parts 5 and 6 (and more?). There are also some\n",
    "simplifications here that we haven't yet applied to the main system, but we\n",
    "will."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 1: Transformations as interpreters: standard evaluation, `jvp`, and `vmap`\n",
    "\n",
    "We want to transform functions that look like this:\n",
    "\n",
    "```python\n",
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "```\n",
    "\n",
    "Think of functions like `sin` and the arithmetic operations underlying the\n",
    "infix operators (`mul`, `add`, and `neg`) as primitive operations, meaning\n",
    "atomic units of processing rather than compositions.\n",
    "\n",
    "\"Transform\" means \"interpret differently.\" Instead of standard interpretation\n",
    "where we apply primitive operations to numerical inputs to produce numerical\n",
    "outputs, we want to override primitive application and let different values\n",
    "flow through our program. For example, we might want to replace the\n",
    "application of every primitive with an application of [its JVP\n",
    "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n",
    "and let primal-tangent pairs flow through our program. Moreover, we want to be\n",
    "able to compose multiple transformations, leading to stacks of interpreters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JAX core machinery\n",
    "\n",
    "We can implement stacks of interpreters and even have them all discharge on\n",
    "the fly as we execute the Python function to be transformed. To start, let's\n",
    "define these primitives so that we can intercept their application:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import NamedTuple\n",
    "\n",
    "class Primitive(NamedTuple):\n",
    "  name: str\n",
    "\n",
    "add_p = Primitive('add')\n",
    "mul_p = Primitive('mul')\n",
    "neg_p = Primitive(\"neg\")\n",
    "sin_p = Primitive(\"sin\")\n",
    "cos_p = Primitive(\"cos\")\n",
    "reduce_sum_p = Primitive(\"reduce_sum\")\n",
    "greater_p = Primitive(\"greater\")\n",
    "less_p = Primitive(\"less\")\n",
    "transpose_p = Primitive(\"transpose\")\n",
    "broadcast_p = Primitive(\"broadcast\")\n",
    "\n",
    "def add(x, y): return bind1(add_p, x, y)\n",
    "def mul(x, y): return bind1(mul_p, x, y)\n",
    "def neg(x): return bind1(neg_p, x)\n",
    "def sin(x): return bind1(sin_p, x)\n",
    "def cos(x): return bind1(cos_p, x)\n",
    "def greater(x, y): return bind1(greater_p, x, y)\n",
    "def less(x, y): return bind1(less_p, x, y)\n",
    "def transpose(x, perm): return bind1(transpose_p, x, perm=perm)\n",
    "def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)\n",
    "def reduce_sum(x, axis=None):\n",
    "  if axis is None:\n",
    "    axis = tuple(range(np.ndim(x)))\n",
    "  if type(axis) is int:\n",
    "    axis = (axis,)\n",
    "  return bind1(reduce_sum_p, x, axis=axis)\n",
    "\n",
    "def bind1(prim, *args, **params):\n",
    "  out, = bind(prim, *args, **params)\n",
    "  return out"
   ]
  },
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll set up array data types and infix operator methods in a moment.\n",
    "\n",
    "A `Primitive` is just an object with a name, to which we attach our\n",
    "interpretation rules (one for each transformation). The `bind` function is our\n",
    "interception point: it'll figure out which transformation rule to apply, based\n",
    "on how the arguments are boxed in tracers and what interpreters are active.\n",
    "\n",
    "The functions that user code calls, like `add` and `sin`, are just wrappers\n",
    "around calls to `bind`. These wrappers let us control how arguments are passed\n",
    "to `bind`, and in particular we follow a handy internal convention: when we\n",
    "call `bind`, we pass values representing array data as positional arguments,\n",
    "and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword.\n",
    "This calling convention simplifies some core logic (since e.g. instances of\n",
    "the `Tracer` class to be defined below can only occur in positional arguments\n",
    "to `bind`). The wrappers can also provide docstrings!\n",
    "\n",
    "We represent active interpreters as a stack. The stack is just a simple\n",
    "`list`, and each element is a container with an integer level (corresponding\n",
    "to the element's height in the stack), an interpreter type (which we'll call a\n",
    "`trace_type`), and an optional field for any global data the interpreter\n",
    "needs. We call each element a `MainTrace`, though maybe \"Interpreter\" would be\n",
    "more descriptive."
   ]
  },
|
|
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import contextmanager\n",
    "from typing import Type, List, Tuple, Sequence, Optional, Any\n",
    "\n",
    "class MainTrace(NamedTuple):\n",
    "  level: int\n",
    "  trace_type: Type['Trace']\n",
    "  global_data: Optional[Any]\n",
    "\n",
    "trace_stack: List[MainTrace] = []\n",
    "dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3\n",
    "\n",
    "@contextmanager\n",
    "def new_main(trace_type: Type['Trace'], global_data=None):\n",
    "  level = len(trace_stack)\n",
    "  main = MainTrace(level, trace_type, global_data)\n",
    "  trace_stack.append(main)\n",
    "\n",
    "  try:\n",
    "    yield main\n",
    "  finally:\n",
    "    trace_stack.pop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we're about to apply a transformation, we'll push another interpreter\n",
    "onto the stack using `new_main`. Then, as we apply primitives in the function,\n",
    "we can think of the `bind` first being interpreted by the trace at the top of\n",
    "the stack (i.e. with the highest level). If that first interpreter itself\n",
    "binds other primitives in its interpretation rule for the primitive, like how\n",
    "the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, then those `bind`\n",
    "calls will be handled by the interpreter at the next level down.\n",
    "\n",
    "What goes at the bottom of the interpreter stack? At the bottom, we know all\n",
    "the transformation interpreters are finished, and we just want to do standard\n",
    "evaluation. So at the bottom we'll put an evaluation interpreter.\n",
    "\n",
    "Let's sketch out the interface for interpreters, which is based on the `Trace`\n",
    "and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps\n",
    "carrying some extra context data used by the interpreter. A `Trace` handles\n",
    "boxing up values into `Tracers` and also handles primitive application."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trace:\n",
    "  main: MainTrace\n",
    "\n",
    "  def __init__(self, main: MainTrace) -> None:\n",
    "    self.main = main\n",
    "\n",
    "  def pure(self, val): assert False  # must override\n",
    "  def lift(self, val): assert False  # must override\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    assert False  # must override"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first two methods are about boxing up values in `Tracer`s, which are the\n",
    "objects that flow through the Python programs we transform. The last method is\n",
    "the callback we'll use to interpret primitive application.\n",
    "\n",
    "The `Trace` itself doesn't contain any data, other than a reference to its\n",
    "corresponding `MainTrace` instance. In fact, multiple instances of a `Trace`\n",
    "might be created and discarded during an application of a transformation,\n",
    "whereas only a single `MainTrace` instance is created per application of a\n",
    "transformation.\n",
    "\n",
    "As for `Tracer`s themselves, each one carries an abstract value (and forwards\n",
    "infix operators to it), and the rest is up to the transformation. (The\n",
    "relationship between `Tracer`s and `AbstractValue`s is that there's one\n",
    "`Tracer` per transformation, and at least one `AbstractValue` per base type,\n",
    "like arrays.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class Tracer:\n",
    "  _trace: Trace\n",
    "\n",
    "  __array_priority__ = 1000\n",
    "\n",
    "  @property\n",
    "  def aval(self):\n",
    "    assert False  # must override\n",
    "\n",
    "  def full_lower(self):\n",
    "    return self  # default implementation\n",
    "\n",
    "  def __neg__(self): return self.aval._neg(self)\n",
    "  def __add__(self, other): return self.aval._add(self, other)\n",
    "  def __radd__(self, other): return self.aval._radd(self, other)\n",
    "  def __mul__(self, other): return self.aval._mul(self, other)\n",
    "  def __rmul__(self, other): return self.aval._rmul(self, other)\n",
    "  def __gt__(self, other): return self.aval._gt(self, other)\n",
    "  def __lt__(self, other): return self.aval._lt(self, other)\n",
    "  def __bool__(self): return self.aval._bool(self)\n",
    "  def __nonzero__(self): return self.aval._nonzero(self)\n",
    "\n",
    "  def __getattr__(self, name):\n",
    "    try:\n",
    "      return getattr(self.aval, name)\n",
    "    except AttributeError:\n",
    "      raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")\n",
    "\n",
    "def swap(f): return lambda x, y: f(y, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ShapedArray:\n",
    "  array_abstraction_level = 1\n",
    "  shape: Tuple[int, ...]\n",
    "  dtype: np.dtype\n",
    "\n",
    "  def __init__(self, shape, dtype):\n",
    "    self.shape = shape\n",
    "    self.dtype = dtype\n",
    "\n",
    "  @property\n",
    "  def ndim(self):\n",
    "    return len(self.shape)\n",
    "\n",
    "  _neg = staticmethod(neg)\n",
    "  _add = staticmethod(add)\n",
    "  _radd = staticmethod(swap(add))\n",
    "  _mul = staticmethod(mul)\n",
    "  _rmul = staticmethod(swap(mul))\n",
    "  _gt = staticmethod(greater)\n",
    "  _lt = staticmethod(less)\n",
    "\n",
    "  @staticmethod\n",
    "  def _bool(tracer):\n",
    "    raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
    "\n",
    "  @staticmethod\n",
    "  def _nonzero(tracer):\n",
    "    raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
    "\n",
    "  def str_short(self):\n",
    "    return f'{self.dtype.name}[{\",\".join(str(d) for d in self.shape)}]'\n",
    "\n",
    "  def __hash__(self):\n",
    "    return hash((self.shape, self.dtype))\n",
    "\n",
    "  def __eq__(self, other):\n",
    "    return (type(self) is type(other) and\n",
    "            self.shape == other.shape and self.dtype == other.dtype)\n",
    "\n",
    "  def __repr__(self):\n",
    "    return f\"ShapedArray(shape={self.shape}, dtype={self.dtype})\"\n",
    "\n",
    "class ConcreteArray(ShapedArray):\n",
    "  array_abstraction_level = 2\n",
    "  val: np.ndarray\n",
    "\n",
    "  def __init__(self, val):\n",
    "    self.val = val\n",
    "    self.shape = val.shape\n",
    "    self.dtype = val.dtype\n",
    "\n",
    "  @staticmethod\n",
    "  def _bool(tracer):\n",
    "    return bool(tracer.aval.val)\n",
    "\n",
    "  @staticmethod\n",
    "  def _nonzero(tracer):\n",
    "    return bool(tracer.aval.val)\n",
    "\n",
    "def get_aval(x):\n",
    "  if isinstance(x, Tracer):\n",
    "    return x.aval\n",
    "  elif type(x) in jax_types:\n",
    "    return ConcreteArray(np.asarray(x))\n",
    "  else:\n",
    "    raise TypeError(x)\n",
    "\n",
    "jax_types = {bool, int, float,\n",
    "             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}"
   ]
  },
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that we actually have two `AbstractValue`s for arrays, representing\n",
    "different levels of abstraction. A `ShapedArray` represents the set of all\n",
    "possible arrays with a given shape and dtype. A `ConcreteArray` represents a\n",
    "singleton set consisting of a single array value.\n",
    "\n",
    "Now that we've set up the interpreter stack, the Trace/Tracer API for\n",
    "interpreters, and abstract values, we can come back to implement `bind`.\n",
    "First, a quick sanity check of the two abstraction levels, sketched in the\n",
    "cell below:"
   ]
  },
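  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A hedged aside, not part of the core: a ConcreteArray remembers its value,\n",
    "# while the corresponding ShapedArray only remembers shape and dtype.\n",
    "aval = get_aval(np.ones(3))\n",
    "print(type(aval).__name__, aval.shape, aval.dtype)      # ConcreteArray (3,) float64\n",
    "print(ShapedArray(aval.shape, aval.dtype).str_short())  # float64[3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, here's `bind`:"
   ]
  },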
|
|
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bind(prim, *args, **params):\n",
    "  top_trace = find_top_trace(args)\n",
    "  tracers = [full_raise(top_trace, arg) for arg in args]\n",
    "  outs = top_trace.process_primitive(prim, tracers, params)\n",
    "  return [full_lower(out) for out in outs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main action is that we call `find_top_trace` to figure out which\n",
    "interpreter should handle this primitive application. We then call that top\n",
    "trace's `process_primitive` so that the trace can apply its interpretation\n",
    "rule. The calls to `full_raise` just ensure that the inputs are boxed in the\n",
    "top trace's `Tracer` instances, and the call to `full_lower` is an optional\n",
    "optimization so that we unbox values out of `Tracer`s as much as possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator as op\n",
    "\n",
    "def find_top_trace(xs) -> Trace:\n",
    "  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n",
    "                 default=trace_stack[0], key=op.attrgetter('level'))\n",
    "  if dynamic_trace and dynamic_trace.level > top_main.level:\n",
    "    top_main = dynamic_trace\n",
    "  return top_main.trace_type(top_main)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In words, ignoring the `dynamic_trace` step until Part 3, `find_top_trace`\n",
    "returns the highest-level interpreter associated with the `Tracer`s on its\n",
    "inputs, and otherwise returns the interpreter at the bottom of the stack\n",
    "(which is always an evaluation trace, at least for now). This is a deviation\n",
    "from the description above, where we always start by running the interpreter\n",
    "at the top of the stack and then work our way down, applying every interpreter\n",
    "in the stack. Instead, we're only applying an interpreter when the input\n",
    "arguments to a primitive bind are boxed in a `Tracer` corresponding to that\n",
    "interpreter. This optimization lets us skip irrelevant transformations, but\n",
    "bakes in an assumption that transformations mostly follow data dependence\n",
    "(except for the special bottom-of-the-stack interpreter, which interprets\n",
    "everything).\n",
    "\n",
    "An alternative would be to have every interpreter in the stack interpret every\n",
    "operation. That's worth exploring! JAX is designed around data dependence in\n",
    "large part because that's so natural for automatic differentiation, and JAX's\n",
    "roots are in autodiff. But it may be over-fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def full_lower(val: Any):\n",
    "  if isinstance(val, Tracer):\n",
    "    return val.full_lower()\n",
    "  else:\n",
    "    return val\n",
    "\n",
    "def full_raise(trace: Trace, val: Any) -> Tracer:\n",
    "  if not isinstance(val, Tracer):\n",
    "    assert type(val) in jax_types\n",
    "    return trace.pure(val)\n",
    "  level = trace.main.level\n",
    "  if val._trace.main is trace.main:\n",
    "    return val\n",
    "  elif val._trace.main.level < level:\n",
    "    return trace.lift(val)\n",
    "  elif val._trace.main.level > level:\n",
    "    raise Exception(f\"Can't lift level {val._trace.main.level} to {level}.\")\n",
    "  else:  # val._trace.level == level\n",
    "    raise Exception(f\"Different traces at same level: {val._trace}, {trace}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The logic in `full_raise` serves to box values into `Tracer`s for a particular\n",
    "`Trace`, calling different methods on the `Trace` based on context:\n",
    "`Trace.pure` is called on non-`Tracer` constants, and `Trace.lift` is called\n",
    "for values that are already `Tracer`s from a lower-level interpreter. These\n",
    "two methods could share the same implementation, but by distinguishing them in\n",
    "the core logic we can provide more information to the `Trace` subclass.\n",
    "\n",
    "That's it for the JAX core! Now we can start adding interpreters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation interpreter\n",
    "\n",
    "We'll start with the simplest interpreter: the evaluation interpreter that\n",
    "will sit at the bottom of the interpreter stack."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EvalTrace(Trace):\n",
    "  pure = lift = lambda self, x: x  # no boxing in Tracers needed\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    return impl_rules[primitive](*tracers, **params)\n",
    "\n",
    "trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack\n",
    "\n",
    "# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance\n",
    "impl_rules = {}\n",
    "\n",
    "impl_rules[add_p] = lambda x, y: [np.add(x, y)]\n",
    "impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]\n",
    "impl_rules[neg_p] = lambda x: [np.negative(x)]\n",
    "impl_rules[sin_p] = lambda x: [np.sin(x)]\n",
    "impl_rules[cos_p] = lambda x: [np.cos(x)]\n",
    "impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]\n",
    "impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]\n",
    "impl_rules[less_p] = lambda x, y: [np.less(x, y)]\n",
    "impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]\n",
    "\n",
    "def broadcast_impl(x, *, shape, axes):\n",
    "  for axis in sorted(axes):\n",
    "    x = np.expand_dims(x, axis)\n",
    "  return [np.broadcast_to(x, shape)]\n",
    "impl_rules[broadcast_p] = broadcast_impl"
   ]
  },
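  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(A quick, hedged smoke test, not part of the core: with no tracers among the\n",
    "arguments, `bind` bottoms out at the evaluation interpreter at the base of\n",
    "the stack.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# find_top_trace sees no Tracers, so this dispatches to EvalTrace's impl rule,\n",
    "# i.e. np.multiply.\n",
    "print(bind1(mul_p, 2., 3.))  # 6.0"
   ]
  },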
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this interpreter, we can evaluate user functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "print(f(3.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Woo! Like going around in a big circle. But the point of this indirection is\n",
    "that now we can add some real transformations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward-mode autodiff with `jvp`\n",
    "\n",
    "First, a few helper functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zeros_like(val):\n",
    "  aval = get_aval(val)\n",
    "  return np.zeros(aval.shape, aval.dtype)\n",
    "\n",
    "def unzip2(pairs):\n",
    "  lst1, lst2 = [], []\n",
    "  for x1, x2 in pairs:\n",
    "    lst1.append(x1)\n",
    "    lst2.append(x2)\n",
    "  return lst1, lst2\n",
    "\n",
    "map_ = map\n",
    "def map(f, *xs):\n",
    "  return list(map_(f, *xs))\n",
    "\n",
    "zip_ = zip\n",
    "def zip(*args):\n",
    "  fst, *rest = args = map(list, args)\n",
    "  n = len(fst)\n",
    "  for arg in rest:\n",
    "    assert len(arg) == n\n",
    "  return list(zip_(*args))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The\n",
    "`Trace` applies JVP rules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JVPTracer(Tracer):\n",
    "  def __init__(self, trace, primal, tangent):\n",
    "    self._trace = trace\n",
    "    self.primal = primal\n",
    "    self.tangent = tangent\n",
    "\n",
    "  @property\n",
    "  def aval(self):\n",
    "    return get_aval(self.primal)\n",
    "\n",
    "class JVPTrace(Trace):\n",
    "  pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)\n",
    "    jvp_rule = jvp_rules[primitive]\n",
    "    primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)\n",
    "    return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]\n",
    "\n",
    "jvp_rules = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice both `pure` and `lift` package a value into a `JVPTracer` with the\n",
    "minimal amount of context, which is a zero tangent value.\n",
    "\n",
    "Let's add some JVP rules for primitives:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_jvp(primals, tangents):\n",
    "  (x, y), (x_dot, y_dot) = primals, tangents\n",
    "  return [x + y], [x_dot + y_dot]\n",
    "jvp_rules[add_p] = add_jvp\n",
    "\n",
    "def mul_jvp(primals, tangents):\n",
    "  (x, y), (x_dot, y_dot) = primals, tangents\n",
    "  return [x * y], [x_dot * y + x * y_dot]\n",
    "jvp_rules[mul_p] = mul_jvp\n",
    "\n",
    "def sin_jvp(primals, tangents):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [sin(x)], [cos(x) * x_dot]\n",
    "jvp_rules[sin_p] = sin_jvp\n",
    "\n",
    "def cos_jvp(primals, tangents):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [cos(x)], [-sin(x) * x_dot]\n",
    "jvp_rules[cos_p] = cos_jvp\n",
    "\n",
    "def neg_jvp(primals, tangents):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [neg(x)], [neg(x_dot)]\n",
    "jvp_rules[neg_p] = neg_jvp\n",
    "\n",
    "def reduce_sum_jvp(primals, tangents, *, axis):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]\n",
    "jvp_rules[reduce_sum_p] = reduce_sum_jvp\n",
    "\n",
    "def greater_jvp(primals, tangents):\n",
    "  (x, y), _ = primals, tangents\n",
    "  out_primal = greater(x, y)\n",
    "  return [out_primal], [zeros_like(out_primal)]\n",
    "jvp_rules[greater_p] = greater_jvp\n",
    "\n",
    "def less_jvp(primals, tangents):\n",
    "  (x, y), _ = primals, tangents\n",
    "  out_primal = less(x, y)\n",
    "  return [out_primal], [zeros_like(out_primal)]\n",
    "jvp_rules[less_p] = less_jvp"
   ]
  },
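  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Hedged aside: outside of any trace, a JVP rule is just a function on primal\n",
    "and tangent values, so we can check the product rule directly. The numbers\n",
    "here are arbitrary.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# d(x * y) = x_dot * y + x * y_dot, checked at x=2., y=3., x_dot=1., y_dot=0.\n",
    "print(mul_jvp((2., 3.), (1., 0.)))  # ([6.0], [3.0])"
   ]
  },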
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we add a transformation API to kick off the trace:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jvp_v1(f, primals, tangents):\n",
    "  with new_main(JVPTrace) as main:\n",
    "    trace = JVPTrace(main)\n",
    "    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n",
    "    out = f(*tracers_in)\n",
    "    tracer_out = full_raise(trace, out)\n",
    "    primal_out, tangent_out = tracer_out.primal, tracer_out.tangent\n",
    "  return primal_out, tangent_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And with that, we can differentiate!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = 3.0\n",
    "y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))\n",
    "print(sin_deriv_at_3)\n",
    "print(cos(3.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "x, xdot = 3., 1.\n",
    "y, ydot = jvp_v1(f, (x,), (xdot,))\n",
    "print(y)\n",
    "print(ydot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def deriv(f):\n",
    "  return lambda x: jvp_v1(f, (x,), (1.,))[1]\n",
    "\n",
    "print(deriv(sin)(3.))\n",
    "print(deriv(deriv(sin))(3.))\n",
    "print(deriv(deriv(deriv(sin)))(3.))\n",
    "print(deriv(deriv(deriv(deriv(sin))))(3.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  if x > 0.:  # Python control flow\n",
    "    return 2. * x\n",
    "  else:\n",
    "    return x\n",
    "\n",
    "print(deriv(f)(3.))\n",
    "print(deriv(f)(-3.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pytrees and flattening user functions' inputs and outputs"
   ]
  },
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A limitation with `jvp_v1` is that it assumes the user function accepts arrays\n",
    "as positional arguments and produces a single array as output. What if it\n",
    "produced a list as output? Or accepted nested containers as inputs? It would\n",
    "be a pain to deal with all the possible containers in inputs and outputs at\n",
    "every layer of the stack. Instead, we can wrap the user function so that the\n",
    "wrapped version accepts arrays as inputs and returns a flat list of arrays as\n",
    "output. The wrapper just needs to unflatten its input, call the user function,\n",
    "and flatten the output.\n",
    "\n",
    "Here's how we'd like to write `jvp`, assuming the user always gives us\n",
    "functions that take arrays as inputs and produce a flat list of arrays as\n",
    "outputs:"
   ]
  },
|
|
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jvp_flat(f, primals, tangents):\n",
    "  with new_main(JVPTrace) as main:\n",
    "    trace = JVPTrace(main)\n",
    "    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n",
    "    outs = f(*tracers_in)\n",
    "    tracers_out = [full_raise(trace, out) for out in outs]\n",
    "    primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)\n",
    "  return primals_out, tangents_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To support user functions that have arbitrary containers in the inputs and\n",
    "outputs, here's how we'd write the user-facing `jvp` wrapper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jvp(f, primals, tangents):\n",
    "  primals_flat, in_tree = tree_flatten(primals)\n",
    "  tangents_flat, in_tree2 = tree_flatten(tangents)\n",
    "  if in_tree != in_tree2: raise TypeError\n",
    "  f, out_tree = flatten_fun(f, in_tree)\n",
    "  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)\n",
    "  primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
    "  tangents_out = tree_unflatten(out_tree(), tangents_out_flat)\n",
    "  return primals_out, tangents_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that we had to plumb the tree structure of the user function output\n",
    "back to the caller of `flatten_fun`. That information isn't available until we\n",
    "actually run the user function, so `flatten_fun` just returns a reference to a\n",
    "mutable cell, represented as a thunk. These side-effects are safe because we\n",
    "always run the user function exactly once. (This safe regime is the reason for\n",
    "the \"linear\" name in `linear_util.py`, in the sense of [linear\n",
    "types](https://en.wikipedia.org/wiki/Substructural_type_system).)\n",
    "\n",
    "All that remains is to write `tree_flatten`, `tree_unflatten`, and\n",
    "`flatten_fun`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def flatten_fun(f, in_tree):\n",
    "  store = Store()\n",
    "\n",
    "  def flat_fun(*args_flat):\n",
    "    pytree_args = tree_unflatten(in_tree, args_flat)\n",
    "    out = f(*pytree_args)\n",
    "    out_flat, out_tree = tree_flatten(out)\n",
    "    store.set_value(out_tree)\n",
    "    return out_flat\n",
    "\n",
    "  return flat_fun, store\n",
    "\n",
    "class Empty: pass\n",
    "empty = Empty()\n",
    "\n",
    "class Store:\n",
    "  val = empty\n",
    "\n",
    "  def set_value(self, val):\n",
    "    assert self.val is empty\n",
    "    self.val = val\n",
    "\n",
    "  def __call__(self):\n",
    "    return self.val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import itertools as it\n",
    "from typing import Callable, Type, Hashable, Dict, Iterable, Iterator\n",
    "\n",
    "class NodeType(NamedTuple):\n",
    "  name: str\n",
    "  to_iterable: Callable\n",
    "  from_iterable: Callable\n",
    "\n",
    "def register_pytree_node(ty: Type, to_iter: Callable, from_iter: Callable\n",
    "                         ) -> None:\n",
    "  node_types[ty] = NodeType(str(ty), to_iter, from_iter)\n",
    "\n",
    "node_types: Dict[Type, NodeType] = {}\n",
    "register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))\n",
    "register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))\n",
    "register_pytree_node(dict,\n",
    "                     lambda d: map(tuple, unzip2(sorted(d.items()))),\n",
    "                     lambda keys, vals: dict(zip(keys, vals)))\n",
    "\n",
    "class PyTreeDef(NamedTuple):\n",
    "  node_type: NodeType\n",
    "  node_metadata: Hashable\n",
    "  child_treedefs: Tuple['PyTreeDef', ...]\n",
    "\n",
    "class Leaf: pass\n",
    "leaf = Leaf()\n",
    "\n",
    "def tree_flatten(x: Any) -> Tuple[List[Any], PyTreeDef]:\n",
    "  children_iter, treedef = _tree_flatten(x)\n",
    "  return list(children_iter), treedef\n",
    "\n",
    "def _tree_flatten(x: Any) -> Tuple[Iterable, PyTreeDef]:\n",
    "  node_type = node_types.get(type(x))\n",
    "  if node_type:\n",
    "    node_metadata, children = node_type.to_iterable(x)\n",
    "    children_flat, child_trees = unzip2(map(_tree_flatten, children))\n",
    "    flattened = it.chain.from_iterable(children_flat)\n",
    "    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))\n",
    "  else:\n",
    "    return [x], leaf\n",
    "\n",
    "def tree_unflatten(treedef: PyTreeDef, xs: List[Any]) -> Any:\n",
    "  return _tree_unflatten(treedef, iter(xs))\n",
    "\n",
    "def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:\n",
    "  if treedef is leaf:\n",
    "    return next(xs)\n",
    "  else:\n",
    "    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)\n",
    "    return treedef.node_type.from_iterable(treedef.node_metadata, children)"
   ]
  },
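  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(A hedged aside, not part of the core: a round trip through\n",
    "`tree_flatten`/`tree_unflatten`, and `flatten_fun` applied to a function with\n",
    "container inputs and outputs. The example values are arbitrary.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "value = {'a': (1., [2., 3.]), 'b': 4.}\n",
    "leaves, treedef = tree_flatten(value)\n",
    "print(leaves)                                    # [1.0, 2.0, 3.0, 4.0]\n",
    "print(tree_unflatten(treedef, leaves) == value)  # True\n",
    "\n",
    "args = ({'x': 3.},)\n",
    "args_flat, in_tree = tree_flatten(args)\n",
    "flat_f, out_tree_store = flatten_fun(lambda d: {'y': d['x'] * 2.}, in_tree)\n",
    "print(flat_f(*args_flat))  # [6.0] -- a flat list of outputs\n",
    "print(out_tree_store())    # the output PyTreeDef, recorded as a side effect"
   ]
  },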
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this pytree-handling `jvp` implementation, we can now handle arbitrary\n",
    "input and output containers. That'll come in handy with future transformations\n",
    "too!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return {'hi': z, 'there': [x, y]}\n",
    "\n",
    "x, xdot = 3., 1.\n",
    "y, ydot = jvp(f, (x,), (xdot,))\n",
    "print(y)\n",
    "print(ydot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectorized batching with `vmap`\n",
    "\n",
    "First, a couple helper functions, one for producing mapped abstract values\n",
    "from unmapped ones (by removing an axis), and one for moving batch dimensions\n",
    "around:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mapped_aval(batch_dim, aval):\n",
    "  shape = list(aval.shape)\n",
    "  del shape[batch_dim]\n",
    "  return ShapedArray(tuple(shape), aval.dtype)\n",
    "\n",
    "def move_batch_axis(axis_size, src, dst, x):\n",
    "  if src is not_mapped:\n",
    "    target_shape = list(np.shape(x))\n",
    "    target_shape.insert(dst, axis_size)\n",
    "    return broadcast(x, target_shape, [dst])\n",
    "  elif src == dst:\n",
    "    return x\n",
    "  else:\n",
    "    return moveaxis(x, src, dst)\n",
    "\n",
    "def moveaxis(x, src: int, dst: int):\n",
    "  perm = [i for i in range(np.ndim(x)) if i != src]\n",
    "  perm.insert(dst, src)\n",
    "  return transpose(x, perm)"
   ]
  },
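  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Hedged sanity check, not part of the core: `moveaxis` permutes dimensions\n",
    "via the `transpose` primitive, and `mapped_aval` removes the batch axis from\n",
    "an abstract value.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.arange(6.).reshape(2, 3)\n",
    "print(moveaxis(x, 0, 1).shape)      # (3, 2)\n",
    "print(mapped_aval(0, get_aval(x)))  # ShapedArray(shape=(3,), dtype=float64)"
   ]
  },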
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Tracer` for vectorized batching carries a batched value and an optional\n",
    "integer indicating which axis (if any) is the batch axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union\n",
    "\n",
    "class NotMapped: pass\n",
    "not_mapped = NotMapped()\n",
    "\n",
    "BatchAxis = Union[NotMapped, int]\n",
    "\n",
    "class BatchTracer(Tracer):\n",
    "  def __init__(self, trace, val, batch_dim: BatchAxis):\n",
    "    self._trace = trace\n",
    "    self.val = val\n",
    "    self.batch_dim = batch_dim\n",
    "\n",
    "  @property\n",
    "  def aval(self):\n",
    "    if self.batch_dim is not_mapped:\n",
    "      return get_aval(self.val)\n",
    "    else:\n",
    "      return mapped_aval(self.batch_dim, get_aval(self.val))\n",
    "\n",
    "  def full_lower(self):\n",
    "    if self.batch_dim is not_mapped:\n",
    "      return full_lower(self.val)\n",
    "    else:\n",
    "      return self\n",
    "\n",
    "class BatchTrace(Trace):\n",
    "  pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)\n",
    "    vmap_rule = vmap_rules[primitive]\n",
    "    val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)\n",
    "    return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]\n",
    "\n",
    "  @property\n",
    "  def axis_size(self):\n",
    "    return self.main.global_data\n",
    "\n",
    "vmap_rules = {}"
   ]
  },
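  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Another hedged sanity check: `move_batch_axis` broadcasts an unbatched value\n",
    "into a fresh batch axis, or moves an existing batch axis into position.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.arange(6.).reshape(2, 3)\n",
    "print(move_batch_axis(4, not_mapped, 0, x).shape)  # (4, 2, 3) -- broadcast in\n",
    "print(move_batch_axis(2, 0, 1, x).shape)           # (3, 2) -- axis moved"
   ]
  },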
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we've implemented the optional `Tracer.full_lower` method, which lets us\n",
    "peel off a batching tracer if it's not needed because it doesn't represent a\n",
    "batched value.\n",
    "\n",
    "For `BatchTrace`, analogous to `JVPTrace`, the methods `pure` and `lift` just\n",
    "box a value in a `BatchTracer` with the minimal amount of context, which in\n",
    "this case is a `batch_dim` taking the sentinel value `not_mapped`. Notice we\n",
    "use the `MainTrace`'s interpreter-global data field to store the batch axis\n",
    "size.\n",
    "\n",
    "Next we can define batching interpreter rules for each primitive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "def binop_batching_rule(op, axis_size, vals_in, dims_in):\n",
    "  (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n",
    "  if x_bdim != y_bdim:\n",
    "    if x_bdim is not_mapped:\n",
    "      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)\n",
    "      x_bdim = y_bdim\n",
    "    else:\n",
    "      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
    "  return [op(x, y)], [x_bdim]\n",
    "vmap_rules[add_p] = partial(binop_batching_rule, add)\n",
    "vmap_rules[mul_p] = partial(binop_batching_rule, mul)\n",
    "\n",
    "def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):\n",
    "  (x,), (x_bdim,) = vals_in, dims_in\n",
    "  return [op(x)], [x_bdim]\n",
    "vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)\n",
    "vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)\n",
    "vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)\n",
    "\n",
    "def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n",
    "  (x,), (x_bdim,) = vals_in, dims_in\n",
    "  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)\n",
    "  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)\n",
    "  return [reduce_sum(x, new_axis)], [out_bdim]\n",
    "vmap_rules[reduce_sum_p] = reduce_sum_batching_rule"
   ]
  },
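  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(A hedged worked example of the `reduce_sum` index arithmetic: batching along\n",
    "axis 1 of a `(4, 3)` array while summing the logical axis 0, the summed\n",
    "physical axis stays at 0 and the batch axis slides down to 0 in the output.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.arange(12.).reshape(4, 3)  # a batch of 3 length-4 vectors, batch axis 1\n",
    "(out,), (bdim,) = reduce_sum_batching_rule(3, (xs,), (1,), axis=(0,))\n",
    "print(out, bdim)  # per-column sums [18. 22. 26.], batch dim now 0"
   ]
  },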
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we add a transformation API to kick off the trace:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vmap_flat(f, in_axes, *args):\n",
    "  axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)\n",
    "                if ax is not None}  # in_axes entries are ints or None here\n",
    "  with new_main(BatchTrace, axis_size) as main:\n",
    "    trace = BatchTrace(main)\n",
    "    tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x\n",
    "                  for x, ax in zip(args, in_axes)]\n",
    "    outs = f(*tracers_in)\n",
    "    tracers_out = [full_raise(trace, out) for out in outs]\n",
    "    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)\n",
    "  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)\n",
    "                     for val_out, bdim in zip(vals_out, bdims_out)]\n",
    "  return outs_transposed\n",
    "\n",
    "def vmap(f, in_axes):\n",
    "  def batched_f(*args):\n",
    "    args_flat, in_tree = tree_flatten(args)\n",
    "    in_axes_flat, in_tree2 = tree_flatten(in_axes)\n",
    "    if in_tree != in_tree2: raise TypeError\n",
    "    f_flat, out_tree = flatten_fun(f, in_tree)\n",
    "    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)\n",
    "    return tree_unflatten(out_tree(), outs_flat)\n",
    "  return batched_f"
   ]
  },
|
|
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_one_to_a_scalar(scalar):\n",
    "  assert np.ndim(scalar) == 0\n",
    "  return 1 + scalar\n",
    "\n",
    "vector_in = np.arange(3.)\n",
    "vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)\n",
    "\n",
    "print(vector_in)\n",
    "print(vector_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jacfwd(f, x):\n",
    "  pushfwd = lambda v: jvp(f, (x,), (v,))[1]\n",
    "  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)\n",
    "  return vmap(pushfwd, (0,))(vecs_in)\n",
    "\n",
    "def f(x):\n",
    "  return sin(x)\n",
    "\n",
    "jacfwd(f, np.arange(3.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it for `jvp` and `vmap`!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 2: Jaxprs\n",
    "\n",
    "The next transformations on the horizon are `jit` for just-in-time\n",
    "compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small\n",
    "wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to\n",
    "carry a little bit of extra context, for both `jit` and `vjp` we need much\n",
    "richer context: we need to represent _programs_. That is, we need jaxprs!\n",
    "\n",
    "Jaxprs are JAX's internal intermediate representation of programs. They are\n",
    "explicitly typed, functional, first-order, and in ANF form. We need a\n",
    "program representation for `jit` because the purpose of `jit` is to stage\n",
    "computation out of Python. For any computation we want to stage out, we need\n",
    "to be able to represent it as data, and build it up as we trace a Python\n",
    "function. Similarly, `vjp` needs a way to represent the computation for the\n",
    "backward pass of reverse-mode autodiff. We use the same jaxpr program\n",
    "representation for both needs.\n",
    "\n",
    "(Building a program representation is the most\n",
    "[free](https://en.wikipedia.org/wiki/Free_object) kind of\n",
    "trace-transformation, and so except for issues around handling native Python\n",
    "control flow, any transformation could be implemented by first tracing to a\n",
    "jaxpr and then interpreting the jaxpr.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Jaxpr data structures\n",
    "\n",
    "The jaxpr term syntax is roughly:\n",
    "\n",
    "```\n",
    "jaxpr ::=\n",
    "  { lambda <binder> , ... .\n",
    "    let <eqn>\n",
    "        ...\n",
    "    in ( <atom> , ... ) }\n",
    "\n",
    "binder ::= <var>:<array_type>\n",
    "var ::= a | b | c | ...\n",
    "atom ::= <var> | <literal>\n",
    "literal ::= <int32> | <int64> | <float32> | <float64>\n",
    "\n",
    "eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...\n",
    "```\n",
    "\n",
    "The syntax of types is:\n",
    "\n",
    "```\n",
    "jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]\n",
    "array_type ::= <dtype>[<shape>]\n",
    "dtype ::= f32 | f64 | i32 | i64\n",
    "shape ::= <int> , ...\n",
    "```\n",
    "\n",
    "How do we represent these as Python data structures? We reuse ShapedArrays to\n",
    "represent types, and we can represent the term syntax with a few Python\n",
    "structs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Set\n",
    "\n",
    "class Var:\n",
    "  aval: ShapedArray\n",
    "  def __init__(self, aval): self.aval = aval\n",
    "\n",
    "class Lit:\n",
    "  val: Any\n",
    "  aval: ShapedArray\n",
    "\n",
    "  def __init__(self, val):\n",
    "    self.aval = aval = raise_to_shaped(get_aval(val))\n",
    "    self.val = np.array(val, aval.dtype)\n",
    "\n",
    "Atom = Union[Var, Lit]\n",
    "\n",
    "class JaxprEqn(NamedTuple):\n",
    "  primitive: Primitive\n",
    "  inputs: List[Atom]\n",
    "  params: Dict[str, Any]\n",
    "  out_binders: List[Var]\n",
    "\n",
    "class Jaxpr(NamedTuple):\n",
    "  in_binders: List[Var]\n",
    "  eqns: List[JaxprEqn]\n",
    "  outs: List[Atom]\n",
    "\n",
    "  def __hash__(self): return id(self)\n",
    "  __eq__ = op.is_\n",
    "\n",
    "def raise_to_shaped(aval):\n",
    "  return ShapedArray(aval.shape, aval.dtype)"
   ]
  },
|
|
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Type-checking a jaxpr involves checking that there are no unbound variables,\n",
    "that variables are only bound once, and that for each equation the type of\n",
    "the primitive application matches the type of the output binders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JaxprType(NamedTuple):\n",
    "  in_types: List[ShapedArray]\n",
    "  out_types: List[ShapedArray]\n",
    "\n",
    "  def __repr__(self):\n",
    "    in_types = ', '.join(aval.str_short() for aval in self.in_types)\n",
    "    out_types = ', '.join(aval.str_short() for aval in self.out_types)\n",
    "    return f'({in_types}) -> ({out_types})'\n",
    "\n",
    "def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:\n",
    "  env: Set[Var] = set()\n",
    "\n",
    "  for v in jaxpr.in_binders:\n",
    "    if v in env: raise TypeError\n",
    "    env.add(v)\n",
    "\n",
    "  for eqn in jaxpr.eqns:\n",
    "    in_types = [typecheck_atom(env, x) for x in eqn.inputs]\n",
    "    out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)\n",
    "    for out_binder, out_type in zip(eqn.out_binders, out_types):\n",
    "      if not out_type == out_binder.aval: raise TypeError\n",
    "    for out_binder in eqn.out_binders:\n",
    "      if out_binder in env: raise TypeError\n",
    "      env.add(out_binder)\n",
    "\n",
    "  in_types = [v.aval for v in jaxpr.in_binders]\n",
    "  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]\n",
    "  return JaxprType(in_types, out_types)\n",
    "\n",
    "def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray:\n",
    "  if isinstance(x, Var):\n",
    "    if x not in env: raise TypeError(\"unbound variable\")\n",
    "    return x.aval\n",
    "  elif isinstance(x, Lit):\n",
    "    return raise_to_shaped(get_aval(x.val))\n",
    "  else:\n",
    "    assert False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can apply the function represented by a jaxpr to arguments with a simple\n",
    "interpreter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]:\n",
    "  env: Dict[Var, Any] = {}\n",
    "\n",
    "  def read(x: Atom) -> Any:\n",
    "    return env[x] if type(x) is Var else x.val\n",
    "\n",
    "  def write(v: Var, val: Any) -> None:\n",
    "    assert v not in env  # single-assignment\n",
    "    env[v] = val\n",
    "\n",
    "  map(write, jaxpr.in_binders, args)\n",
    "  for eqn in jaxpr.eqns:\n",
    "    in_vals = map(read, eqn.inputs)\n",
    "    outs = bind(eqn.primitive, *in_vals, **eqn.params)\n",
    "    map(write, eqn.out_binders, outs)\n",
    "  return map(read, jaxpr.outs)\n",
    "\n",
    "def jaxpr_as_fun(jaxpr: Jaxpr):\n",
    "  return lambda *args: eval_jaxpr(jaxpr, args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because this interpreter uses `bind`, it is itself traceable."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building jaxprs with tracing\n",
    "\n",
    "Now that we have jaxprs as a data structure, we need ways to produce these\n",
    "from tracing Python code. In general there are two variants of how we trace to\n",
    "a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one\n",
    "used by `jit`, which is also used by control flow primitives like `lax.cond`,\n",
    "`lax.while_loop`, and `lax.scan`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:\n",
    "  assert 0 <= n <= len(lst)\n",
    "  return lst[:n], lst[n:]\n",
    "\n",
    "def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:\n",
    "  assert len(bs) == len(l)\n",
    "  lists = lst1, lst2 = [], []\n",
    "  for b, x in zip(bs, l):\n",
    "    lists[b].append(x)\n",
    "  return lst1, lst2"
   ]
  },
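  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Hedged one-liners for the two list helpers, with arbitrary inputs.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(split_list([1, 2, 3, 4], 2))                            # ([1, 2], [3, 4])\n",
    "print(partition_list([False, True, False], ['a', 'b', 'c']))  # (['a', 'c'], ['b'])"
   ]
  },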
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NB: the analogous class in JAX is called 'DynamicJaxprTracer'\n",
|
|
"class JaxprTracer(Tracer):\n",
|
|
" __slots__ = ['aval']\n",
|
|
" aval: ShapedArray\n",
|
|
"\n",
|
|
" def __init__(self, trace, aval):\n",
|
|
" self._trace = trace\n",
|
|
" self.aval = aval\n",
|
|
"\n",
|
|
"# NB: the analogous class in JAX is called 'DynamicJaxprTrace'\n",
|
|
"class JaxprTrace(Trace):\n",
|
|
" def new_arg(self, aval: ShapedArray) -> JaxprTracer:\n",
|
|
" aval = raise_to_shaped(aval)\n",
|
|
" tracer = self.builder.new_tracer(self, aval)\n",
|
|
" self.builder.tracer_to_var[id(tracer)] = Var(aval)\n",
|
|
" return tracer\n",
|
|
"\n",
|
|
" def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:\n",
|
|
" tracer = self.builder.const_tracers.get(id(val))\n",
|
|
" if tracer is None:\n",
|
|
" tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))\n",
|
|
" self.builder.add_const(tracer, val)\n",
|
|
" return tracer\n",
|
|
" pure = lift = get_or_make_const_tracer\n",
|
|
"\n",
|
|
" def process_primitive(self, primitive, tracers, params):\n",
|
|
" avals_in = [t.aval for t in tracers]\n",
|
|
" avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
|
|
" out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]\n",
|
|
" inputs = [self.builder.getvar(t) for t in tracers]\n",
|
|
" outvars = [self.builder.add_var(t) for t in out_tracers]\n",
|
|
" self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))\n",
|
|
" return out_tracers\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def builder(self):\n",
|
|
" return self.main.global_data\n",
|
|
"\n",
|
|
"# NB: in JAX, we instead attach abstract eval rules to Primitive instances\n",
|
|
"abstract_eval_rules = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Notice that we keep as interpreter-global data a builder object, which keeps\n",
|
|
"track of variables, constants, and eqns as we build up the jaxpr."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class JaxprBuilder:\n",
|
|
" eqns: List[JaxprEqn]\n",
|
|
" tracer_to_var: Dict[int, Var]\n",
|
|
" const_tracers: Dict[int, JaxprTracer]\n",
|
|
" constvals: Dict[Var, Any]\n",
|
|
" tracers: List[JaxprTracer]\n",
|
|
"\n",
|
|
" def __init__(self):\n",
|
|
" self.eqns = []\n",
|
|
" self.tracer_to_var = {}\n",
|
|
" self.const_tracers = {}\n",
|
|
" self.constvals = {}\n",
|
|
" self.tracers = []\n",
|
|
"\n",
|
|
" def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:\n",
|
|
" tracer = JaxprTracer(trace, aval)\n",
|
|
" self.tracers.append(tracer)\n",
|
|
" return tracer\n",
|
|
"\n",
|
|
" def add_eqn(self, eqn: JaxprEqn) -> None:\n",
|
|
" self.eqns.append(eqn)\n",
|
|
"\n",
|
|
" def add_var(self, tracer: JaxprTracer) -> Var:\n",
|
|
" assert id(tracer) not in self.tracer_to_var\n",
|
|
" var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)\n",
|
|
" return var\n",
|
|
"\n",
|
|
" def getvar(self, tracer: JaxprTracer) -> Var:\n",
|
|
" var = self.tracer_to_var.get(id(tracer))\n",
|
|
" assert var is not None\n",
|
|
" return var\n",
|
|
"\n",
|
|
" def add_const(self, tracer: JaxprTracer, val: Any) -> Var:\n",
|
|
" var = self.add_var(tracer)\n",
|
|
" self.const_tracers[id(val)] = tracer\n",
|
|
" self.constvals[var] = val\n",
|
|
" return var\n",
|
|
"\n",
|
|
" def build(self, in_tracers: List[JaxprTracer], out_tracers: List[JaxprTracer]\n",
|
|
" ) -> Tuple[Jaxpr, List[Any]]:\n",
|
|
" constvars, constvals = unzip2(self.constvals.items())\n",
|
|
" t2v = lambda t: self.tracer_to_var[id(t)]\n",
|
|
" in_binders = constvars + [t2v(t) for t in in_tracers]\n",
|
|
" out_vars = [t2v(t) for t in out_tracers]\n",
|
|
" jaxpr = Jaxpr(in_binders, self.eqns, out_vars)\n",
|
|
" typecheck_jaxpr(jaxpr)\n",
|
|
" jaxpr, constvals = _inline_literals(jaxpr, constvals)\n",
|
|
" return jaxpr, constvals"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def _inline_literals(jaxpr: Jaxpr, consts: List[Any]) -> Tuple[Jaxpr, List[Any]]:\n",
|
|
" const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))\n",
|
|
" scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]\n",
|
|
" new_const_binders, lit_binders = partition_list(scalars, const_binders)\n",
|
|
" new_consts, lit_vals = partition_list(scalars, consts)\n",
|
|
" literals = dict(zip(lit_binders, map(Lit, lit_vals)))\n",
|
|
" new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],\n",
|
|
" eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]\n",
|
|
" new_outs = [literals.get(x, x) for x in jaxpr.outs]\n",
|
|
" new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)\n",
|
|
" typecheck_jaxpr(new_jaxpr)\n",
|
|
" return new_jaxpr, new_consts"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The rules we need for `JaxprTrace.process_primitive` are essentially typing\n",
|
|
"rules for primitive applications: given the primitive, its parameters, and\n",
|
|
"types for the inputs, the rule must produce a type for the output, which is\n",
|
|
"then packaged with the output `JaxprTracer`. We can use abstract evaluation\n",
|
|
"rules for this same purpose, even though they can be more general (since\n",
|
|
"abstract evaluation rules must accept ConcreteArray inputs, and since they\n",
|
|
"need only return an upper bound on the set of possible outputs, they can\n",
|
|
"produce ConcreteArray outputs as well). We'll reuse these abstract evaluation\n",
|
|
"rules for the other jaxpr-producing trace machinery, where the potential extra\n",
|
|
"generality is useful."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:\n",
|
|
" if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):\n",
|
|
" raise TypeError\n",
|
|
" if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError\n",
|
|
" return [ShapedArray(x.shape, x.dtype)]\n",
|
|
"\n",
|
|
"abstract_eval_rules[add_p] = binop_abstract_eval\n",
|
|
"abstract_eval_rules[mul_p] = binop_abstract_eval\n",
|
|
"\n",
|
|
"def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:\n",
|
|
" if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):\n",
|
|
" raise TypeError\n",
|
|
" if x.shape != y.shape: raise TypeError\n",
|
|
" return [ShapedArray(x.shape, np.dtype('bool'))]\n",
|
|
"abstract_eval_rules[greater_p] = compare_abstract_eval\n",
|
|
"abstract_eval_rules[less_p] = compare_abstract_eval\n",
|
|
"\n",
|
|
"def vectorized_unop_abstract_eval(x: ShapedArray) -> List[ShapedArray]:\n",
|
|
" return [ShapedArray(x.shape, x.dtype)]\n",
|
|
"\n",
|
|
"abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval\n",
|
|
"abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval\n",
|
|
"abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval\n",
|
|
"\n",
|
|
"def reduce_sum_abstract_eval(x: ShapedArray, *, axis: Tuple[int, ...]\n",
|
|
" ) -> List[ShapedArray]:\n",
|
|
" axis_ = set(axis)\n",
|
|
" new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]\n",
|
|
" return [ShapedArray(tuple(new_shape), x.dtype)]\n",
|
|
"abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval\n",
|
|
"\n",
|
|
"def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],\n",
|
|
" axes: Sequence[int]) -> List[ShapedArray]:\n",
|
|
" return [ShapedArray(tuple(shape), x.dtype)]\n",
|
|
"abstract_eval_rules[broadcast_p] = broadcast_abstract_eval"
|
|
]
|
|
},
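{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can apply these typing rules to avals directly\n",
"(a small sketch using only the definitions above):\n",
"\n",
"```python\n",
"x = ShapedArray((3, 4), np.dtype('float32'))\n",
"out, = binop_abstract_eval(x, x)\n",
"print(out.str_short())  # float32[3,4]\n",
"out, = reduce_sum_abstract_eval(x, axis=(0,))\n",
"print(out.str_short())  # float32[4]\n",
"```"
]
},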
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"To check our implementation of jaxprs, we can add a `make_jaxpr`\n",
|
|
"transformation and a pretty-printer:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from functools import lru_cache\n",
|
|
"\n",
|
|
"@lru_cache() # ShapedArrays are hashable\n",
|
|
"def make_jaxpr_v1(f, *avals_in):\n",
|
|
" avals_in, in_tree = tree_flatten(avals_in)\n",
|
|
" f, out_tree = flatten_fun(f, in_tree)\n",
|
|
"\n",
|
|
" builder = JaxprBuilder()\n",
|
|
" with new_main(JaxprTrace, builder) as main:\n",
|
|
" trace = JaxprTrace(main)\n",
|
|
" tracers_in = [trace.new_arg(aval) for aval in avals_in]\n",
|
|
" outs = f(*tracers_in)\n",
|
|
" tracers_out = [full_raise(trace, out) for out in outs]\n",
|
|
" jaxpr, consts = builder.build(tracers_in, tracers_out)\n",
|
|
" return jaxpr, consts, out_tree()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": [
|
|
"hide-input"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import DefaultDict\n",
|
|
"from collections import defaultdict\n",
|
|
"import string\n",
|
|
"\n",
|
|
"class PPrint:\n",
|
|
" lines: List[Tuple[int, str]]\n",
|
|
"\n",
|
|
" def __init__(self, lines):\n",
|
|
" self.lines = lines\n",
|
|
"\n",
|
|
" def indent(self, indent: int) -> 'PPrint':\n",
|
|
" return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])\n",
|
|
"\n",
|
|
" def __add__(self, rhs: 'PPrint') -> 'PPrint':\n",
|
|
" return PPrint(self.lines + rhs.lines)\n",
|
|
"\n",
|
|
" def __rshift__(self, rhs: 'PPrint') -> 'PPrint':\n",
|
|
" if not rhs.lines: return self\n",
|
|
" if not self.lines: return rhs\n",
|
|
" indent, s = self.lines[-1]\n",
|
|
" indented_block = rhs.indent(indent + len(s))\n",
|
|
" common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]\n",
|
|
" return PPrint(self.lines[:-1]\n",
|
|
" + [(indent, common_line)]\n",
|
|
" + indented_block.lines[1:])\n",
|
|
"\n",
|
|
" def __str__(self) -> str:\n",
|
|
" return '\\n'.join(' ' * indent + s for indent, s in self.lines)\n",
|
|
"\n",
|
|
"def pp(s: Any) -> PPrint:\n",
|
|
" return PPrint([(0, line) for line in str(s).splitlines()])\n",
|
|
"\n",
|
|
"def vcat(ps: List[PPrint]) -> PPrint:\n",
|
|
" return sum(ps, pp(''))\n",
|
|
"\n",
|
|
"def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:\n",
|
|
" namegen = (''.join(s) for r in it.count(1)\n",
|
|
" for s in it.permutations(string.ascii_lowercase, r))\n",
|
|
" names = defaultdict(lambda: next(namegen))\n",
|
|
" in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)\n",
|
|
" eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])\n",
|
|
" outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)\n",
|
|
" for v in jaxpr.outs)\n",
|
|
" return (pp(f'{{ lambda {in_binders} .') +\n",
|
|
" ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))\n",
|
|
"\n",
|
|
"def var_str(names: DefaultDict[Var, str], v: Var) -> str:\n",
|
|
" return f'{names[v]}:{v.aval.str_short()}'\n",
|
|
"\n",
|
|
"def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
|
" rule = pp_rules.get(eqn.primitive)\n",
|
|
" if rule:\n",
|
|
" return rule(names, eqn)\n",
|
|
" else:\n",
|
|
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
|
|
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
|
|
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
|
" for x in eqn.inputs)))\n",
|
|
" return lhs >> pp(' = ') >> rhs\n",
|
|
"\n",
|
|
"def pp_params(params: Dict[str, Any]) -> PPrint:\n",
|
|
" items = sorted(params.items())\n",
|
|
" if items:\n",
|
|
" return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')\n",
|
|
" else:\n",
|
|
" return pp(' ')\n",
|
|
"\n",
|
|
"Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))\n",
|
|
"pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))\n",
|
|
"print(jaxpr)\n",
|
|
"print(typecheck_jaxpr(jaxpr))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"But there's a limitation here: because of how `find_top_trace` operates by\n",
|
|
"data dependence, `make_jaxpr_v1` can't stage out all the primitive operations\n",
|
|
"performed by the Python callable it's given. For example:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))\n",
|
|
"print(jaxpr)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This is precisely the issue that\n",
|
|
"[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n",
|
|
"We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n",
|
|
"applied, regardless of whether any inputs to `bind` are boxed in corresponding\n",
|
|
"`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n",
|
|
"global defined in Part 1:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@contextmanager\n",
|
|
"def new_dynamic(main: MainTrace):\n",
|
|
" global dynamic_trace\n",
|
|
" prev_dynamic_trace, dynamic_trace = dynamic_trace, main\n",
|
|
" try:\n",
|
|
" yield\n",
|
|
" finally:\n",
|
|
" dynamic_trace = prev_dynamic_trace\n",
|
|
"\n",
|
|
"@lru_cache()\n",
|
|
"def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n",
|
|
" ) -> Tuple[Jaxpr, List[Any], PyTreeDef]:\n",
|
|
" avals_in, in_tree = tree_flatten(avals_in)\n",
|
|
" f, out_tree = flatten_fun(f, in_tree)\n",
|
|
"\n",
|
|
" builder = JaxprBuilder()\n",
|
|
" with new_main(JaxprTrace, builder) as main:\n",
|
|
" with new_dynamic(main):\n",
|
|
" trace = JaxprTrace(main)\n",
|
|
" tracers_in = [trace.new_arg(aval) for aval in avals_in]\n",
|
|
" outs = f(*tracers_in)\n",
|
|
" tracers_out = [full_raise(trace, out) for out in outs]\n",
|
|
" jaxpr, consts = builder.build(tracers_in, tracers_out)\n",
|
|
" return jaxpr, consts, out_tree()\n",
|
|
"\n",
|
|
"jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))\n",
|
|
"print(jaxpr)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Using `dynamic_trace` this way is conceptually the same as stashing the\n",
|
|
"current interpreter stack and starting a new one with the `JaxprTrace` at the\n",
|
|
"bottom. That is, no interpreters lower in the stack than the `dynamic_trace`\n",
|
|
"are applied (since `JaxprTrace.process_primitive` doesn't call `bind`), though\n",
|
|
"if the Python callable being traced to a jaxpr itself uses transformations\n",
|
|
"then those can be pushed onto the interpreter stack above the `JaxprTrace`.\n",
|
|
"But temporarily stashing the interpreter stack would break up the system\n",
|
|
"state. The `dynamic_trace` tag achieves the same goals while keeping the\n",
|
|
"system state simpler."
|
|
]
|
|
},
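{
"cell_type": "markdown",
"metadata": {},
"source": [
"For contrast, here's a sketch of what the stack-stashing alternative might\n",
"look like. This is hypothetical and not what we use; it assumes the\n",
"module-level `trace_stack` list from Part 1:\n",
"\n",
"```python\n",
"from contextlib import contextmanager\n",
"\n",
"# Hypothetical alternative, NOT used here: swap out the whole interpreter\n",
"# stack so a fresh JaxprTrace would sit at the bottom.\n",
"@contextmanager\n",
"def stashed_trace_stack():\n",
"  global trace_stack\n",
"  prev, trace_stack = trace_stack, []\n",
"  try:\n",
"    yield\n",
"  finally:\n",
"    trace_stack = prev\n",
"```"
]
},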
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"That's it for jaxprs! With jaxprs in hand, we can implement the remaining\n",
|
|
"major JAX features."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Part 3: `jit`, simplified\n",
|
|
"\n",
|
|
"While `jit` has a transformation-like API in that it accepts a Python callable\n",
|
|
"as an argument, under the hood it's really a higher-order primitive rather\n",
|
|
"than a transformation. A primitive is _higher-order_ when it's parameterized\n",
|
|
"by a function."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### On-the-fly (\"final style\") and staged (\"initial style\") processing\n",
|
|
"\n",
|
|
"There are two options for how to handle higher-order primitives. Each requires\n",
|
|
"a different approach to tracing and engenders different tradeoffs:\n",
|
|
"1. **On-the-fly processing, where `bind` takes a Python callable as an\n",
|
|
" argument.** We defer forming a jaxpr until as late as possible, namely\n",
|
|
" until we're running the final interpreter at the bottom of the interpreter\n",
|
|
" stack. That way we can swap a `JaxprTrace` in at the bottom of the\n",
|
|
" interpreter stack and thus stage out rather than execute all primitive\n",
|
|
" operations. With this approach, transformations in the stack get applied as\n",
|
|
" we execute the Python callable as usual. This approach can be very tricky\n",
|
|
" to implement, but it's as general as possible because it allows\n",
|
|
" higher-order primitives not to raise the abstraction level of their\n",
|
|
" arguments and thus allows data-dependent Python control flow. We refer to\n",
|
|
" this approach as using a \"final-style higher-order primitive\" employing the\n",
|
|
" discharge-at-tracing-time \"final-style transformations\" we've used so far.\n",
|
|
"2. **Staged processing, where `bind` takes a jaxpr as an argument.** Before we\n",
|
|
" call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form\n",
|
|
" a jaxpr up-front and be done with the Python callable entirely. In this\n",
|
|
" case, `make_jaxpr` puts its `JaxprTrace` at the top of the interpreter\n",
|
|
" stack, and no transformations lower in the stack, which might enter via\n",
|
|
" closed-over Tracers, are applied to the Python callable as we trace it.\n",
|
|
" (Transformations applied within the Python callable are applied as usual,\n",
|
|
" being added to the stack above the JaxprTrace.) Instead, the\n",
|
|
" transformations lower in the stack are later applied to the call primitive,\n",
|
|
" and the call primitive's rules must then transform the jaxpr itself.\n",
|
|
" Because we trace to a jaxpr up-front, this approach can't support\n",
|
|
" data-dependent Python control flow, but it is more straightforward to\n",
|
|
" implement. We refer to this kind of higher-order primitive as an\n",
|
|
" \"initial-style higher-order primitive\", and say that its jaxpr-processing\n",
|
|
" transformation rules are \"initial-style transformation rules.\"\n",
|
|
"\n",
|
|
"The latter approach fits for `jit` because we don't need to support\n",
|
|
"data-dependent Python control flow in the user-provided Python callable, as\n",
|
|
"the whole purpose of `jit` is to stage computation out of Python to be\n",
|
|
"executed by XLA. (In contrast, `custom_jvp` is a higher-order primitive in\n",
|
|
"which we want to support data-dependent Python control flow.)\n",
|
|
"\n",
|
|
"Historically, we started using the \"initial-style\" and \"final-style\"\n",
|
|
"terminology after reading the [typed tagless final\n",
|
|
"interpreters](http://okmij.org/ftp/tagless-final/index.html) paper, and\n",
|
|
"jokingly referring to JAX as an implementation of \"untyped tagful final\n",
|
|
"interpreters.\" We don't claim to carry over (or understand) any deep meaning\n",
|
|
"behind these terms; we loosely use \"initial style\" to mean \"build an AST and\n",
|
|
"then transform it\", and we use \"final style\" to mean \"transform as we trace.\"\n",
|
|
"But it's just imprecise yet sticky jargon."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"With the initial-style approach, here's the user-facing `jit` wrapper:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def jit(f):\n",
|
|
" def f_jitted(*args):\n",
|
|
" avals_in = [raise_to_shaped(get_aval(x)) for x in args]\n",
|
|
" jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)\n",
|
|
" outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))\n",
|
|
" return tree_unflatten(out_tree, outs)\n",
|
|
" return f_jitted\n",
|
|
"\n",
|
|
"xla_call_p = Primitive('xla_call')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"With any new primitive, we need to give it transformation rules, starting with\n",
|
|
"its evaluation rule. When we evaluate an application of the `xla_call`\n",
|
|
"primitive, we want to stage out out the computation to XLA. That involves\n",
|
|
"translating the jaxpr to an XLA HLO program, transferring the argument values\n",
|
|
"to the XLA device, executing the XLA program, and transferring back the\n",
|
|
"results. We'll cache the XLA HLO compilation so that for each `jit`ted\n",
|
|
"function it only needs to be performed once per argument shape and dtype\n",
|
|
"signature.\n",
|
|
"\n",
|
|
"First, some utilities."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class IDHashable:\n",
|
|
" val: Any\n",
|
|
"\n",
|
|
" def __init__(self, val):\n",
|
|
" self.val = val\n",
|
|
"\n",
|
|
" def __hash__(self) -> int:\n",
|
|
" return id(self.val)\n",
|
|
"\n",
|
|
" def __eq__(self, other):\n",
|
|
" return type(other) is IDHashable and id(self.val) == id(other.val)"
|
|
]
|
|
},
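{
"cell_type": "markdown",
"metadata": {},
"source": [
"`IDHashable` hashes and compares by object identity rather than by value, so\n",
"two structurally equal but distinct objects act as different cache keys (a\n",
"quick check):\n",
"\n",
"```python\n",
"a, b = [1., 2.], [1., 2.]\n",
"assert IDHashable(a) == IDHashable(a)\n",
"assert IDHashable(a) != IDHashable(b)  # equal contents, distinct objects\n",
"```"
]
},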
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Next, we'll define the evaluation rule for `xla_call`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax._src.lib import xla_bridge as xb\n",
|
|
"from jax._src.lib import xla_client as xc\n",
|
|
"xe = xc._xla\n",
|
|
"xops = xc._xla.ops\n",
|
|
"\n",
|
|
"def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):\n",
|
|
" consts, args = args[:num_consts], args[num_consts:]\n",
|
|
" hashable_consts = tuple(map(IDHashable, consts))\n",
|
|
" execute = xla_callable(IDHashable(jaxpr), hashable_consts)\n",
|
|
" return execute(*args)\n",
|
|
"impl_rules[xla_call_p] = xla_call_impl\n",
|
|
"\n",
|
|
"@lru_cache()\n",
|
|
"def xla_callable(hashable_jaxpr: IDHashable,\n",
|
|
" hashable_consts: Tuple[IDHashable, ...]):\n",
|
|
" jaxpr: Jaxpr = hashable_jaxpr.val\n",
|
|
" typecheck_jaxpr(jaxpr)\n",
|
|
" consts = [x.val for x in hashable_consts]\n",
|
|
" in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]\n",
|
|
" c = xc.XlaBuilder('xla_call')\n",
|
|
" xla_consts = _xla_consts(c, consts)\n",
|
|
" xla_params = _xla_params(c, in_avals)\n",
|
|
" outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)\n",
|
|
" out = xops.Tuple(c, outs)\n",
|
|
" compiled = xb.get_backend(None).compile(c.build(out))\n",
|
|
" return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n",
|
|
"\n",
|
|
"def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:\n",
|
|
" unique_consts = {id(cnst): cnst for cnst in consts}\n",
|
|
" xla_consts = {\n",
|
|
" id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}\n",
|
|
" return [xla_consts[id(cnst)] for cnst in consts]\n",
|
|
"\n",
|
|
"def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:\n",
|
|
" return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
|
|
"\n",
|
|
"def _xla_shape(aval: ShapedArray) -> xe.Shape:\n",
|
|
" return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO\n",
|
|
"program using `jaxpr_subcomp`, then returns a callable which executes the\n",
|
|
"compiled program:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]\n",
|
|
" ) -> xe.XlaOp:\n",
|
|
" env: Dict[Var, xe.XlaOp] = {}\n",
|
|
"\n",
|
|
" def read(x: Atom) -> xe.XlaOp:\n",
|
|
" return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))\n",
|
|
"\n",
|
|
" def write(v: Var, val: xe.XlaOp) -> None:\n",
|
|
" env[v] = val\n",
|
|
"\n",
|
|
" map(write, jaxpr.in_binders, args)\n",
|
|
" for eqn in jaxpr.eqns:\n",
|
|
" in_avals = [x.aval for x in eqn.inputs]\n",
|
|
" in_vals = map(read, eqn.inputs)\n",
|
|
" rule = xla_translations[eqn.primitive]\n",
|
|
" out_vals = rule(c, in_avals, in_vals, **eqn.params)\n",
|
|
" map(write, eqn.out_binders, out_vals)\n",
|
|
" return map(read, jaxpr.outs)\n",
|
|
"\n",
|
|
"def execute_compiled(compiled, out_avals, *args):\n",
|
|
" input_bufs = [input_handlers[type(x)](x) for x in args]\n",
|
|
" out_bufs = compiled.execute(input_bufs)\n",
|
|
" return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]\n",
|
|
"\n",
|
|
"default_input_handler = xb.get_backend(None).buffer_from_pyval\n",
|
|
"input_handlers = {ty: default_input_handler for ty in\n",
|
|
" [bool, int, float, np.ndarray, np.float64, np.float32]}\n",
|
|
"\n",
|
|
"def handle_result(aval: ShapedArray, buf):\n",
|
|
" del aval # Unused for now\n",
|
|
" return np.asarray(buf)\n",
|
|
"\n",
|
|
"xla_translations = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Notice that `jaxpr_subcomp` has the structure of a simple interpreter. That's\n",
|
|
"a common pattern: the way we process jaxprs is usually with an interpreter.\n",
|
|
"And as with any interpreter, we need an interpretation rule for each\n",
|
|
"primitive:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def direct_translation(op, c, in_avals, in_vals):\n",
|
|
" del c, in_avals\n",
|
|
" return [op(*in_vals)]\n",
|
|
"\n",
|
|
"xla_translations[add_p] = partial(direct_translation, xops.Add)\n",
|
|
"xla_translations[mul_p] = partial(direct_translation, xops.Mul)\n",
|
|
"xla_translations[neg_p] = partial(direct_translation, xops.Neg)\n",
|
|
"xla_translations[sin_p] = partial(direct_translation, xops.Sin)\n",
|
|
"xla_translations[cos_p] = partial(direct_translation, xops.Cos)\n",
|
|
"xla_translations[greater_p] = partial(direct_translation, xops.Gt)\n",
|
|
"xla_translations[less_p] = partial(direct_translation, xops.Lt)\n",
|
|
"\n",
|
|
"def reduce_sum_translation(c, in_avals, in_vals, *, axis):\n",
|
|
" (x_aval,), (x,) = in_avals, in_vals\n",
|
|
" zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))\n",
|
|
" subc = xc.XlaBuilder('add')\n",
|
|
" shape = _xla_shape(ShapedArray((), x_aval.dtype))\n",
|
|
" xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n",
|
|
" return [xops.Reduce(c, [x], [zero], subc.build(), axis)]\n",
|
|
"xla_translations[reduce_sum_p] = reduce_sum_translation\n",
|
|
"\n",
|
|
"def broadcast_translation(c, in_avals, in_vals, *, shape, axes):\n",
|
|
" x, = in_vals\n",
|
|
" dims_complement = [i for i in range(len(shape)) if i not in axes]\n",
|
|
" return [xops.BroadcastInDim(x, shape, dims_complement)]\n",
|
|
"xla_translations[broadcast_p] = broadcast_translation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"With that, we can now use `jit` to stage out, compile, and execute programs\n",
|
|
"with XLA!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x, y):\n",
|
|
" print('tracing!')\n",
|
|
" return sin(x) * cos(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"z = f(3., 4.) # 'tracing!' prints the first time\n",
|
|
"print(z)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit!\n",
|
|
"print(z)"
|
|
]
|
|
},
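{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cache is keyed on argument types (via the `lru_cache` on `make_jaxpr` and\n",
"`xla_callable`), so calling with a new shape or dtype signature re-traces and\n",
"re-compiles. A sketch, assuming array arguments flow through just as scalars\n",
"do:\n",
"\n",
"```python\n",
"z = f(np.array([3., 4.]), np.array([5., 6.]))  # 'tracing!' prints again: new shapes\n",
"print(z)\n",
"```"
]
},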
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" return reduce_sum(x, axis=0)\n",
|
|
"\n",
|
|
"print(f(np.array([1., 2., 3.])))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2.\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"\n",
|
|
"def deriv(f):\n",
|
|
" return lambda x: jvp(f, (x,), (1.,))[1]\n",
|
|
"\n",
|
|
"print( deriv(deriv(f))(3.))\n",
|
|
"print(jit(deriv(deriv(f)))(3.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Instead of implementing `jit` to first trace to a jaxpr and then to lower the\n",
|
|
"jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step\n",
|
|
"and just lowered to HLO while tracing. That is, perhaps we could have instead\n",
|
|
"implemented `jit` with a `Trace` and `Tracer` that appended to the XLA HLO\n",
|
|
"graph incrementally on each primitive bind. That's correct for now, but won't\n",
|
|
"be possible when we introduce compiled SPMD computations because there we must\n",
|
|
"know the number of replicas needed before compiling the program."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We haven't yet defined any transformation rules for `xla_call_p` other than\n",
|
|
"its evaluation rule. That is, we can't yet do `vmap`-of-`jit` or\n",
|
|
"`jvp`-of-`jit` or even `jit`-of`-jit`. Instead `jit` has to be at the \"top\n",
|
|
"level.\" Let's fix that!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):\n",
|
|
" del num_consts # Unused\n",
|
|
" new_jaxpr, new_consts = jvp_jaxpr(jaxpr)\n",
|
|
" outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,\n",
|
|
" num_consts=len(new_consts))\n",
|
|
" n = len(outs) // 2\n",
|
|
" primals_out, tangents_out = outs[:n], outs[n:]\n",
|
|
" return primals_out, tangents_out\n",
|
|
"jvp_rules[xla_call_p] = xla_call_jvp_rule\n",
|
|
"\n",
|
|
"@lru_cache()\n",
|
|
"def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:\n",
|
|
" def jvp_traceable(*primals_and_tangents):\n",
|
|
" n = len(primals_and_tangents) // 2\n",
|
|
" primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]\n",
|
|
" return jvp(jaxpr_as_fun(jaxpr), primals, tangents)\n",
|
|
"\n",
|
|
" in_avals = [v.aval for v in jaxpr.in_binders]\n",
|
|
" new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)\n",
|
|
" return new_jaxpr, new_consts"
|
|
]
|
|
},
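{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the `vmap` rule, a quick check of `jvp_jaxpr` (a sketch using the\n",
"pieces above): the transformed jaxpr takes primal and tangent inputs and\n",
"returns primal and tangent outputs:\n",
"\n",
"```python\n",
"jaxpr, consts, _ = make_jaxpr(lambda x: sin(x), raise_to_shaped(get_aval(3.)))\n",
"new_jaxpr, new_consts = jvp_jaxpr(jaxpr)\n",
"print(new_jaxpr)  # two inputs and two outputs now\n",
"```"
]
},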
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n",
|
|
" del num_consts # Unused\n",
|
|
" new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))\n",
|
|
" outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n",
|
|
" num_consts=len(new_consts))\n",
|
|
" return outs, [0] * len(outs)\n",
|
|
"vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
|
|
"\n",
|
|
"@lru_cache()\n",
|
|
"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis, ...]\n",
|
|
" ) -> Tuple[Jaxpr, List[Any]]:\n",
|
|
" vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
|
|
" in_avals = [unmapped_aval(axis_size, d, v.aval)\n",
|
|
" for v, d in zip(jaxpr.in_binders, bdims_in)]\n",
|
|
" new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)\n",
|
|
" return new_jaxpr, new_consts\n",
|
|
"\n",
|
|
"def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray\n",
|
|
" ) -> ShapedArray:\n",
|
|
" if batch_dim is not_mapped:\n",
|
|
" return aval\n",
|
|
" else:\n",
|
|
" shape = list(aval.shape)\n",
|
|
" shape.insert(batch_dim, axis_size)\n",
|
|
" return ShapedArray(tuple(shape), aval.dtype)"
|
|
]
|
|
},
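{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here `unmapped_aval` just reinserts the batch dimension into the aval's shape\n",
"(a quick check):\n",
"\n",
"```python\n",
"aval = ShapedArray((4,), np.dtype('float32'))\n",
"print(unmapped_aval(3, 0, aval).str_short())           # float32[3,4]\n",
"print(unmapped_aval(3, not_mapped, aval).str_short())  # float32[4]\n",
"```"
]
},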
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):\n",
|
|
" del num_consts # Unused\n",
|
|
" jaxpr_type = typecheck_jaxpr(jaxpr)\n",
|
|
" if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n",
|
|
" raise TypeError\n",
|
|
" return jaxpr_type.out_types\n",
|
|
"abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule\n",
|
|
"\n",
|
|
"def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):\n",
|
|
" del num_consts # Only used at top-level.\n",
|
|
" # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.\n",
|
|
" subc = xc.XlaBuilder('inner xla_call')\n",
|
|
" xla_params = _xla_params(subc, in_avals)\n",
|
|
" outs = jaxpr_subcomp(subc, jaxpr, xla_params)\n",
|
|
" subc = subc.build(xops.Tuple(subc, outs))\n",
|
|
" return destructure_tuple(c, xops.Call(c, subc, in_vals))\n",
|
|
"xla_translations[xla_call_p] = xla_call_translation\n",
|
|
"\n",
|
|
"def destructure_tuple(c, tup):\n",
|
|
" num_elements = len(c.get_shape(tup).tuple_shapes())\n",
|
|
" return [xops.GetTupleElement(tup, i) for i in range(num_elements)]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" print('tracing!')\n",
|
|
" y = sin(x) * 2.\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"\n",
|
|
"x, xdot = 3., 1.\n",
|
|
"y, ydot = jvp(f, (x,), (xdot,))\n",
|
|
"print(y)\n",
|
|
"print(ydot)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"ys = vmap(f, (0,))(np.arange(3.))\n",
|
|
"print(ys)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"One piece missing is device memory persistence for arrays. That is, we've\n",
|
|
"defined `handle_result` to transfer results back to CPU memory as NumPy\n",
|
|
"arrays, but it's often preferable to avoid transferring results just to\n",
|
|
"transfer them back for the next operation. We can do that by introducing a\n",
|
|
"`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type\n",
|
|
"`numpy.ndarray`s:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def handle_result(aval: ShapedArray, buf): # noqa: F811\n",
|
|
" return DeviceArray(aval, buf)\n",
|
|
"\n",
|
|
"class DeviceArray:\n",
|
|
" buf: Any\n",
|
|
" aval: ShapedArray\n",
|
|
"\n",
|
|
" def __init__(self, aval, buf):\n",
|
|
" self.aval = aval\n",
|
|
" self.buf = buf\n",
|
|
"\n",
|
|
" dtype = property(lambda self: self.aval.dtype)\n",
|
|
" shape = property(lambda self: self.aval.shape)\n",
|
|
" ndim = property(lambda self: self.aval.ndim)\n",
|
|
"\n",
|
|
" def __array__(self): return np.asarray(self.buf)\n",
|
|
" def __repr__(self): return repr(np.asarray(self.buf))\n",
|
|
" def __str__(self): return str(np.asarray(self.buf))\n",
|
|
"\n",
|
|
" _neg = staticmethod(neg)\n",
|
|
" _add = staticmethod(add)\n",
|
|
" _radd = staticmethod(add)\n",
|
|
" _mul = staticmethod(mul)\n",
|
|
" _rmul = staticmethod(mul)\n",
|
|
" _gt = staticmethod(greater)\n",
|
|
" _lt = staticmethod(less)\n",
|
|
"input_handlers[DeviceArray] = lambda x: x.buf\n",
|
|
"\n",
|
|
"jax_types.add(DeviceArray)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2.\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"\n",
|
|
"x, xdot = 3., 1.\n",
|
|
"y, ydot = jvp(f, (x,), (xdot,))\n",
|
|
"print(y)\n",
|
|
"print(ydot)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": [
|
|
"hide-input"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
|
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
|
|
" params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}\n",
|
|
" rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>\n",
|
|
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
|
" for x in eqn.inputs)))\n",
|
|
" return vcat([lhs >> pp(' = ') >> rhs,\n",
|
|
" pp_jaxpr(eqn.params['jaxpr']).indent(2)])\n",
|
|
"pp_rules[xla_call_p] = pprint_xla_call"
|
|
]
|
|
},
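{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this rule registered, nested jaxprs print readably: staging out the\n",
"jitted `f` from above shows the `xla_call` equation with its sub-jaxpr\n",
"indented beneath it (a quick check):\n",
"\n",
"```python\n",
"jaxpr, consts, _ = make_jaxpr(f, raise_to_shaped(get_aval(3.)))\n",
"print(jaxpr)\n",
"```"
]
},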
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Part 4: `linearize` and `vjp` (and `grad`!)\n",
|
|
"\n",
|
|
"The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve\n",
|
|
"jaxprs as well. That's because both involve staging out, or delaying,\n",
|
|
"computation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### `linearize`\n",
|
|
"\n",
|
|
"In the case of `linearize`, we want to stage out the linear part of a `jvp`\n",
|
|
"computation. That is, in terms of\n",
|
|
"[Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n",
|
|
"if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n",
|
|
"then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to\n",
|
|
"mean \"the tangent type of `a`\" and using the \"lollipop\" `-o` rather than the\n",
|
|
"arrow `->` to indicate a _linear_ function. We define the semantics of\n",
|
|
"`linearize` in terms of `jvp` too:\n",
|
|
"```python\n",
|
|
"y, f_lin = linearize(f, x)\n",
|
|
"y_dot = f_lin(x_dot)\n",
|
|
"```\n",
|
|
"gives the same result for `(y, y_dot)` as\n",
|
|
"```\n",
|
|
"y, y_dot = jvp(f, (x,), (x_dot,))\n",
|
|
"```\n",
|
|
"where the application of `f_lin` does not redo any of the linearization work.\n",
|
|
"We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.\n",
|
|
"\n",
|
|
"Tangentially, now that we have linear arrows `-o`, we can provide a slightly\n",
|
|
"more informative type for `jvp`:\n",
|
|
"```\n",
|
|
"jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)\n",
|
|
"```\n",
|
|
"Here we're writing `UnrestrictedUse` just to indicate that we have a special\n",
|
|
"pair where the first element can be used in an unrestricted (nonlinear) way.\n",
|
|
"In conjunction with the linear arrow, this notation is just meant to express\n",
|
|
"that the function `jvp f` uses its first input in a nonlinear way but its\n",
|
|
"second input in a linear way, producing a corresponding nonlinear output\n",
|
|
"(which can be used in a nonlinear way) paired with a linear output. This more\n",
|
|
"refined type signature encodes the data dependencies in `jvp f`, which are\n",
|
|
"useful for partial evaluation.\n",
|
|
"\n",
|
|
"To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:\n",
|
|
"we evaluate all the primal values as we trace, but stage the tangent\n",
|
|
"computations into a jaxpr. This is our second way to build jaxprs. But where\n",
|
|
"`make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim\n",
|
|
"to stage out every primitive bind, this second approach stages out only those\n",
|
|
"primitive binds with a data dependence on tangent inputs.\n",
|
|
"\n",
|
|
"First, some utilities:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:\n",
|
|
" assert not len(lst) % 2\n",
|
|
" return split_list(lst, len(lst) // 2)\n",
|
|
"\n",
|
|
"def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:\n",
|
|
" l1, l2 = iter(l1), iter(l2)\n",
|
|
" out = [next(l2) if b else next(l1) for b in which]\n",
|
|
" assert next(l1, None) is next(l2, None) is None\n",
|
|
" return out"
|
|
]
|
|
},
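{
"cell_type": "markdown",
"metadata": {},
"source": [
"These do what their names suggest (a quick check):\n",
"\n",
"```python\n",
"print(split_half([1, 2, 3, 4]))                        # ([1, 2], [3, 4])\n",
"print(merge_lists([False, True, False], [1, 3], [2]))  # [1, 2, 3]\n",
"```"
]
},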
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Next, we'll write `linearize` by combining `jvp` together with a general\n",
|
|
"partial evaluation transformation, to be added next:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def linearize_flat(f, *primals_in):\n",
|
|
" pvals_in = ([PartialVal.known(x) for x in primals_in] +\n",
|
|
" [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n",
|
|
" def f_jvp(*primals_tangents_in):\n",
|
|
" primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n",
|
|
" return [*primals_out, *tangents_out]\n",
|
|
" jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)\n",
|
|
" primal_pvals, _ = split_half(pvals_out)\n",
|
|
" assert all(pval.is_known for pval in primal_pvals)\n",
|
|
" primals_out = [pval.const for pval in primal_pvals]\n",
|
|
" f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])\n",
|
|
" return primals_out, f_lin\n",
|
|
"\n",
|
|
"def linearize(f, *primals_in):\n",
|
|
" primals_in_flat, in_tree = tree_flatten(primals_in)\n",
|
|
" f, out_tree = flatten_fun(f, in_tree)\n",
|
|
" primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)\n",
|
|
" primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
|
|
"\n",
|
|
" def f_lin(*tangents_in):\n",
|
|
" tangents_in_flat, in_tree2 = tree_flatten(tangents_in)\n",
|
|
" if in_tree != in_tree2: raise TypeError\n",
|
|
" tangents_out_flat = f_lin_flat(*tangents_in_flat)\n",
|
|
" return tree_unflatten(out_tree(), tangents_out_flat)\n",
|
|
"\n",
|
|
" return primals_out, f_lin\n",
|
|
"\n",
|
|
"def vspace(aval: ShapedArray) -> ShapedArray:\n",
|
|
" return raise_to_shaped(aval) # TODO handle integers?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now we turn to the general partial evaluation transformation. The goal is to\n",
|
|
"accept a Python callable and a list of inputs, some known and some unknown,\n",
|
|
"and to produce (1) all the outputs which can be computed from the known\n",
|
|
"inputs, together with (2) a jaxpr representing the part of the Python\n",
|
|
"callable's computation which can only be performed after the remaining inputs\n",
|
|
"are known.\n",
|
|
"\n",
|
|
"This transformation is tricky to summarize in a type signature. If we\n",
|
|
"assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where\n",
|
|
"`a1` and `a2` represent the known and unknown inputs, respectively, and where\n",
|
|
"`b1` only has a data dependency on `a1` while `b2` has some data dependency on\n",
|
|
"`a2`, then we might write\n",
|
|
"\n",
|
|
"```\n",
|
|
"partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)\n",
|
|
"```\n",
|
|
"\n",
|
|
"In words, given values for the inputs of type `a1`, `partial_eval` produces\n",
|
|
"the outputs of type `b1` along with \"residual\" values of\n",
|
|
"existentially-quantified type `r` representing the intermediates required to\n",
|
|
"complete the computation in the second stage. It also produces a function of\n",
|
|
"type `(r, a2) -> b2` which accepts the residual values as well as the\n",
|
|
"remaining inputs and produces the remaining outputs.\n",
|
|
"\n",
|
|
"We like to think of partial evaluation as \"unzipping\" one computation into\n",
|
|
"two. For example, consider this jaxpr:\n",
|
|
"```\n",
|
|
"{ lambda a:float64[] .\n",
|
|
" let b:float64[] = sin a\n",
|
|
" c:float64[] = neg b\n",
|
|
" in ( c ) }\n",
|
|
"```\n",
|
|
"A jaxpr for the JVP would look like:\n",
|
|
"```\n",
|
|
"{ lambda a:float64[] b:float64[] .\n",
|
|
" let c:float64[] = sin a\n",
|
|
" d:float64[] = cos a\n",
|
|
" e:float64[] = mul d b\n",
|
|
" f:float64[] = neg c\n",
|
|
" g:float64[] = neg e\n",
|
|
" in ( f, g ) }\n",
|
|
"```\n",
|
|
"If we imagine applying partial evaluation to this jaxpr with the first input\n",
|
|
"known and the second unknown, we end up 'unzipping' the JVP jaxpr into primal\n",
|
|
"and tangent jaxprs:\n",
|
|
"```\n",
|
|
"{ lambda a:float64[] .\n",
|
|
" let c:float64[] = sin a\n",
|
|
" d:float64[] = cos a\n",
|
|
" f:float64[] = neg c\n",
|
|
" in ( f, d ) }\n",
|
|
"```\n",
|
|
"```\n",
|
|
"{ lambda d:float64[] b:float64[] .\n",
|
|
" let e:float64[] = mul d b\n",
|
|
" g:float64[] = neg e\n",
|
|
" in ( g ) }\n",
|
|
"```\n",
|
|
"This second jaxpr represents the linear computation that we want from\n",
|
|
"`linearize`.\n",
|
|
"\n",
|
|
"However, unlike in this jaxpr example, we want the computation on known values\n",
|
|
"to occur while evaluating the input Python callable. That is, rather than\n",
|
|
"forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all\n",
|
|
"operations out of Python first before sorting out what can be evaluated now\n",
|
|
"and what must be delayed, we want only to form a jaxpr for those operations\n",
|
|
"that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
|
|
"of automatic differentiation, this is the feature that ultimately enables us\n",
|
|
"to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python\n",
|
|
"control flow works because partial evaluation keeps the primal computation in\n",
|
|
"Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly\n",
|
|
"sort out what can be evaluated and what must be staged out into a jaxpr.\n",
|
|
"\n",
|
|
"First, we start with a `PartialVal` class, which represents a value that can\n",
|
|
"be either known or unknown:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class PartialVal(NamedTuple):\n",
|
|
" aval: ShapedArray\n",
|
|
" const: Optional[Any]\n",
|
|
"\n",
|
|
" @classmethod\n",
|
|
" def known(cls, val: Any):\n",
|
|
" return PartialVal(get_aval(val), val)\n",
|
|
"\n",
|
|
" @classmethod\n",
|
|
" def unknown(cls, aval: ShapedArray):\n",
|
|
" return PartialVal(aval, None)\n",
|
|
"\n",
|
|
" is_known = property(lambda self: self.const is not None)\n",
|
|
" is_unknown = property(lambda self: self.const is None)"
|
|
]
|
|
},
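{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example (a quick check; note that known-ness is just whether `const` is\n",
"`None`):\n",
"\n",
"```python\n",
"known = PartialVal.known(2.)\n",
"unknown = PartialVal.unknown(ShapedArray((), np.dtype('float64')))\n",
"print(known.is_known, unknown.is_unknown)  # True True\n",
"```"
]
},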
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Partial evaluation will take a list of `PartialVal`s representing inputs, and\n",
|
|
"return a list of `PartialVal` outputs along with a jaxpr representing the\n",
|
|
"delayed computation:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def partial_eval_flat(f: Callable, pvals_in: List[PartialVal]\n",
|
|
" ) -> Tuple[Jaxpr, List[PartialVal], List[Any]]:\n",
|
|
" with new_main(PartialEvalTrace) as main:\n",
|
|
" trace = PartialEvalTrace(main)\n",
|
|
" tracers_in = [trace.new_arg(pval) for pval in pvals_in]\n",
|
|
" outs = f(*tracers_in)\n",
|
|
" tracers_out = [full_raise(trace, out) for out in outs]\n",
|
|
" pvals_out = [t.pval for t in tracers_out]\n",
|
|
" unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]\n",
|
|
" unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]\n",
|
|
" jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)\n",
|
|
" return jaxpr, pvals_out, consts"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This\n",
|
|
"interpreter will build a jaxpr on the fly while tracking data dependencies. To\n",
|
|
"do so, it builds a bipartite directed acyclic graph (DAG) between\n",
|
|
"`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n",
|
|
"nodes, representing formulas for how to compute some values from others. One\n",
|
|
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s\n",
|
|
"primitive application, but we also have recipe types for constants and lambda\n",
|
|
"binders:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from weakref import ref, ReferenceType\n",
|
|
"\n",
|
|
"class LambdaBindingRecipe(NamedTuple):\n",
|
|
" pass\n",
|
|
"\n",
|
|
"class ConstRecipe(NamedTuple):\n",
|
|
" val: Any\n",
|
|
"\n",
|
|
"class JaxprEqnRecipe(NamedTuple):\n",
|
|
" prim: Primitive\n",
|
|
" tracers_in: List['PartialEvalTracer']\n",
|
|
" params: Dict[str, Any]\n",
|
|
" avals_out: List[ShapedArray]\n",
|
|
" tracer_refs_out: List['ReferenceType[PartialEvalTracer]']\n",
|
|
"\n",
|
|
"JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class PartialEvalTracer(Tracer):\n",
|
|
" pval: PartialVal\n",
|
|
" recipe: Optional[JaxprRecipe]\n",
|
|
"\n",
|
|
" def __init__(self, trace, pval, recipe):\n",
|
|
" self._trace = trace\n",
|
|
" self.pval = pval\n",
|
|
" self.recipe = recipe\n",
|
|
"\n",
|
|
" aval = property(lambda self: self.pval.aval)\n",
|
|
"\n",
|
|
" def full_lower(self):\n",
|
|
" if self.pval.is_known:\n",
|
|
" return full_lower(self.pval.const)\n",
|
|
" return self"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The `PartialEvalTrace` contains the logic for constructing the graph of\n",
|
|
"`JaxprRecipe`s and `PartialEvalTracer`s. Each argument corresponds to a\n",
|
|
"`LambdaBindingRecipe` leaf node, and each constant is a `ConstRecipe` leaf\n",
|
|
"node holding a reference to the constant. All other tracers and recipes come\n",
|
|
"from `process_primitive`, which forms tracers with `JaxprEqnRecipe`s.\n",
|
|
"\n",
|
|
"For most primitives, the `process_primitive` logic is straightforward: if all\n",
|
|
"inputs are known then we can bind the primitive on the known values\n",
|
|
"(evaluating it in Python) and avoid forming tracers corresponding to the\n",
|
|
"output. If instead any input is unknown then we instead stage out into a\n",
|
|
"`JaxprEqnRecipe` representing the primitive application. To build the tracers\n",
|
|
"representing unknown outputs, we need avals, which we get from the abstract\n",
|
|
"eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n",
|
|
"`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n",
|
|
"weakrefs.)\n",
|
|
"\n",
|
|
"That `process_primitive` logic applies to most primitives, but `xla_call_p`\n",
|
|
"requires recursive treatment. So we special-case its rule in a\n",
|
|
"`partial_eval_rules` dict."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class PartialEvalTrace(Trace):\n",
|
|
" def new_arg(self, pval: PartialVal) -> Any:\n",
|
|
" return PartialEvalTracer(self, pval, LambdaBindingRecipe())\n",
|
|
"\n",
|
|
" def lift(self, val: Any) -> PartialEvalTracer:\n",
|
|
" return PartialEvalTracer(self, PartialVal.known(val), None)\n",
|
|
" pure = lift\n",
|
|
"\n",
|
|
" def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:\n",
|
|
" if tracer.pval.is_unknown:\n",
|
|
" return tracer\n",
|
|
" else:\n",
|
|
" pval = PartialVal.unknown(raise_to_shaped(tracer.aval))\n",
|
|
" return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))\n",
|
|
"\n",
|
|
" def process_primitive(self, primitive, tracers, params):\n",
|
|
" if all(t.pval.is_known for t in tracers):\n",
|
|
" return bind(primitive, *map(full_lower, tracers), **params)\n",
|
|
" rule = partial_eval_rules.get(primitive)\n",
|
|
" if rule: return rule(self, tracers, **params)\n",
|
|
" tracers_in = [self.instantiate_const(t) for t in tracers]\n",
|
|
" avals_in = [t.aval for t in tracers_in]\n",
|
|
" avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
|
|
" tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)\n",
|
|
" for aval in avals_out]\n",
|
|
" eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,\n",
|
|
" map(ref, tracers_out))\n",
|
|
" for t in tracers_out: t.recipe = eqn\n",
|
|
" return tracers_out\n",
|
|
"\n",
|
|
"partial_eval_rules = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now that we can build graph representations of jaxprs with `PartialEvalTrace`,\n",
|
|
"we need a mechanism to convert the graph representation to a standard jaxpr.\n",
|
|
"The jaxpr corresponds to a topological sort of the graph."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],\n",
|
|
" tracers_out: List[PartialEvalTracer]):\n",
|
|
" tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))\n",
|
|
" for t in tracers_in}\n",
|
|
" constvar_to_val: Dict[int, Any] = {}\n",
|
|
" constid_to_var: Dict[int, Var] = {}\n",
|
|
" processed_eqns: Set[int] = set()\n",
|
|
" eqns: List[JaxprEqn] = []\n",
|
|
" for t in toposort(tracers_out, tracer_parents):\n",
|
|
" if isinstance(t.recipe, LambdaBindingRecipe):\n",
|
|
" assert id(t) in set(map(id, tracers_in))\n",
|
|
" elif isinstance(t.recipe, ConstRecipe):\n",
|
|
" val = t.recipe.val\n",
|
|
" var = constid_to_var.get(id(val))\n",
|
|
" if var is None:\n",
|
|
" aval = raise_to_shaped(get_aval(val))\n",
|
|
" var = constid_to_var[id(val)] = Var(aval)\n",
|
|
" constvar_to_val[var] = val\n",
|
|
" tracer_to_var[id(t)] = var\n",
|
|
" elif isinstance(t.recipe, JaxprEqnRecipe):\n",
|
|
" if id(t.recipe) not in processed_eqns:\n",
|
|
" eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))\n",
|
|
" processed_eqns.add(id(t.recipe))\n",
|
|
" else:\n",
|
|
" raise TypeError(t.recipe)\n",
|
|
"\n",
|
|
" constvars, constvals = unzip2(constvar_to_val.items())\n",
|
|
" in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]\n",
|
|
" out_vars = [tracer_to_var[id(t)] for t in tracers_out]\n",
|
|
" jaxpr = Jaxpr(in_binders, eqns, out_vars)\n",
|
|
" typecheck_jaxpr(jaxpr)\n",
|
|
" return jaxpr, constvals\n",
|
|
"\n",
|
|
"def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe\n",
|
|
" ) -> JaxprEqn:\n",
|
|
" inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]\n",
|
|
" out_binders = [Var(aval) for aval in recipe.avals_out]\n",
|
|
" for t_ref, var in zip(recipe.tracer_refs_out, out_binders):\n",
|
|
" if t_ref() is not None: tracer_to_var[id(t_ref())] = var\n",
|
|
" return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)\n",
|
|
"\n",
|
|
"def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:\n",
|
|
" return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": [
|
|
"hide-input"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
|
|
" if not out_nodes: return []\n",
|
|
" out_nodes = remove_duplicates(out_nodes)\n",
|
|
"\n",
|
|
" child_counts = {}\n",
|
|
" stack = list(out_nodes)\n",
|
|
" while stack:\n",
|
|
" node = stack.pop()\n",
|
|
" if id(node) in child_counts:\n",
|
|
" child_counts[id(node)] += 1\n",
|
|
" else:\n",
|
|
" child_counts[id(node)] = 1\n",
|
|
" stack.extend(parents(node))\n",
|
|
" for node in out_nodes:\n",
|
|
" child_counts[id(node)] -= 1\n",
|
|
"\n",
|
|
" sorted_nodes = []\n",
|
|
" childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]\n",
|
|
" while childless_nodes:\n",
|
|
" node = childless_nodes.pop()\n",
|
|
" sorted_nodes.append(node)\n",
|
|
" for parent in parents(node):\n",
|
|
" if child_counts[id(parent)] == 1:\n",
|
|
" childless_nodes.append(parent)\n",
|
|
" else:\n",
|
|
" child_counts[id(parent)] -= 1\n",
|
|
"\n",
|
|
" sorted_nodes = sorted_nodes[::-1]\n",
|
|
" check_toposort(sorted_nodes, parents)\n",
|
|
" return sorted_nodes\n",
|
|
"\n",
|
|
"def remove_duplicates(lst):\n",
|
|
" seen = set()\n",
|
|
" return [x for x in lst if id(x) not in seen and not seen.add(id(x))]\n",
|
|
"\n",
|
|
"def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
|
|
" seen = set()\n",
|
|
" for node in nodes:\n",
|
|
" assert all(id(parent) in seen for parent in parents(node))\n",
|
|
" seen.add(id(node))"
|
|
]
|
|
},
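{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `toposort` helper works on any DAG given a `parents` function, using\n",
"object identity so that unhashable nodes are fine. A tiny illustrative check\n",
"on a diamond-shaped graph of dicts:\n",
"\n",
"```python\n",
"d = dict(name='d', parents=[])\n",
"b = dict(name='b', parents=[d])\n",
"c = dict(name='c', parents=[d])\n",
"a = dict(name='a', parents=[b, c])\n",
"order = toposort([a], lambda n: n['parents'])\n",
"print([n['name'] for n in order])  # parents before children, e.g. ['d', 'b', 'c', 'a']\n",
"```"
]
},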
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now we can linearize!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"y, sin_lin = linearize(sin, 3.)\n",
|
|
"print(y, sin(3.))\n",
|
|
"print(sin_lin(1.), cos(3.))"
|
|
]
|
|
},
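{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also exercise the partial evaluation machinery directly (a small\n",
"sketch using only the pieces above). Marking the first input known and the\n",
"second unknown stages out just the multiply, closing over the known value as\n",
"a constant:\n",
"\n",
"```python\n",
"pvals_in = [PartialVal.known(2.),\n",
"            PartialVal.unknown(ShapedArray((), np.dtype('float64')))]\n",
"jaxpr, pvals_out, consts = partial_eval_flat(lambda x, y: [mul(x, y)], pvals_in)\n",
"print(jaxpr)   # one mul equation, with the known 2. passed in as a constant\n",
"print(consts)  # [2.0]\n",
"```"
]
},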
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"To handle `linearize`-of-`jit`, we still need to write a partial evaluation\n",
|
|
"rule for `xla_call_p`. Other than tracer bookkeeping, the main task is to\n",
|
|
"perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.\n",
|
|
"\n",
|
|
"There are actually two rules to write: one for trace-time partial evaluation,\n",
|
|
"which we'll call `xla_call_partial_eval`, and one for partial evaluation of\n",
|
|
"jaxprs, which we'll call `xla_call_peval_eqn`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):\n",
|
|
" del num_consts # Unused\n",
|
|
" in_unknowns = [not t.pval.is_known for t in tracers]\n",
|
|
" jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)\n",
|
|
" known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)\n",
|
|
" known_vals = [t.pval.const for t in known_tracers]\n",
|
|
" outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)\n",
|
|
" outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)\n",
|
|
" res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]\n",
|
|
" outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)\n",
|
|
" for v in jaxpr2.outs]\n",
|
|
" eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,\n",
|
|
" dict(jaxpr=jaxpr2, num_consts=0),\n",
|
|
" [v.aval for v in jaxpr2.outs], map(ref, outs2))\n",
|
|
" for t in outs2: t.recipe = eqn\n",
|
|
" return merge_lists(out_unknowns, outs1, outs2)\n",
|
|
"partial_eval_rules[xla_call_p] = xla_call_partial_eval\n",
|
|
"\n",
|
|
"def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],\n",
|
|
" instantiate: Optional[List[bool]] = None,\n",
|
|
" ) -> Tuple[Jaxpr, Jaxpr, List[bool], int]:\n",
|
|
" env: Dict[Var, bool] = {}\n",
|
|
" residuals: Set[Var] = set()\n",
|
|
"\n",
|
|
" def read(x: Atom) -> bool:\n",
|
|
" return type(x) is Var and env[x]\n",
|
|
"\n",
|
|
" def write(unk: bool, v: Var) -> None:\n",
|
|
" env[v] = unk\n",
|
|
"\n",
|
|
" def new_res(x: Atom) -> Atom:\n",
|
|
" if type(x) is Var: residuals.add(x)\n",
|
|
" return x\n",
|
|
"\n",
|
|
" eqns1, eqns2 = [], []\n",
|
|
" map(write, in_unknowns, jaxpr.in_binders)\n",
|
|
" for eqn in jaxpr.eqns:\n",
|
|
" unks_in = map(read, eqn.inputs)\n",
|
|
" rule = partial_eval_jaxpr_rules.get(eqn.primitive)\n",
|
|
" if rule:\n",
|
|
" eqn1, eqn2, unks_out, res = rule(unks_in, eqn)\n",
|
|
" eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)\n",
|
|
" map(write, unks_out, eqn.out_binders)\n",
|
|
" elif any(unks_in):\n",
|
|
" inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]\n",
|
|
" eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))\n",
|
|
" map(partial(write, True), eqn.out_binders)\n",
|
|
" else:\n",
|
|
" eqns1.append(eqn)\n",
|
|
" map(partial(write, False), eqn.out_binders)\n",
|
|
" out_unknowns = map(read, jaxpr.outs)\n",
|
|
" if instantiate is not None:\n",
|
|
" for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):\n",
|
|
" if inst and not uk: new_res(v)\n",
|
|
" out_unknowns = map(op.or_, out_unknowns, instantiate)\n",
|
|
"\n",
|
|
" residuals, num_res = list(residuals), len(residuals)\n",
|
|
" assert all(type(v) is Var for v in residuals), residuals\n",
|
|
"\n",
|
|
" ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)\n",
|
|
" outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)\n",
|
|
"\n",
|
|
" jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)\n",
|
|
" jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)\n",
|
|
" typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)\n",
|
|
"\n",
|
|
" return jaxpr1, jaxpr2, out_unknowns, num_res\n",
|
|
"\n",
|
|
"def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):\n",
|
|
" jaxprty = typecheck_jaxpr(jaxpr) # (a1, a2) -> (b1, b2 )\n",
|
|
" jaxpr1ty = typecheck_jaxpr(jaxpr1) # a1 -> (b1, res)\n",
|
|
" jaxpr2ty = typecheck_jaxpr(jaxpr2) # (res, a2) -> b2\n",
|
|
"\n",
|
|
" a1, a2 = partition_list(unks_in, jaxprty.in_types)\n",
|
|
" b1, b2 = partition_list(unks_out, jaxprty.out_types)\n",
|
|
" b1_, res = split_list(jaxpr1ty.out_types, len(b1))\n",
|
|
" res_, a2_ = split_list(jaxpr2ty.in_types, len(res))\n",
|
|
" b2_ = jaxpr2ty.out_types\n",
|
|
"\n",
|
|
" if jaxpr1ty.in_types != a1: raise TypeError\n",
|
|
" if jaxpr2ty.out_types != b2: raise TypeError\n",
|
|
" if b1 != b1_: raise TypeError\n",
|
|
" if res != res_: raise TypeError\n",
|
|
" if a2 != a2_: raise TypeError\n",
|
|
" if b2 != b2_: raise TypeError\n",
|
|
"\n",
|
|
"partial_eval_jaxpr_rules = {}\n",
|
|
"\n",
|
|
"def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,\n",
|
|
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:\n",
|
|
" jaxpr = eqn.params['jaxpr']\n",
|
|
" jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)\n",
|
|
" ins1, ins2 = partition_list(unks_in, eqn.inputs)\n",
|
|
" out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)\n",
|
|
" residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]\n",
|
|
" eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),\n",
|
|
" out_binders1 + residuals)\n",
|
|
" eqn2 = JaxprEqn(xla_call_p, residuals + ins2,\n",
|
|
" dict(jaxpr=jaxpr2, num_consts=0), out_binders2)\n",
|
|
" return eqn1, eqn2, unks_out, residuals\n",
|
|
"partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn"
|
|
]
|
|
},
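{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's `partial_eval_jaxpr` in action on a tiny jaxpr (a quick check using\n",
"the helpers above). With the first input known and the second unknown, the\n",
"`sin` lands in the first jaxpr and its result becomes a residual feeding the\n",
"second:\n",
"\n",
"```python\n",
"jaxpr, consts, _ = make_jaxpr(lambda a, b: add(sin(a), b),\n",
"                              raise_to_shaped(get_aval(1.)),\n",
"                              raise_to_shaped(get_aval(1.)))\n",
"jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, [False, True])\n",
"print(jaxpr1)  # computes sin(a) and returns it as a residual\n",
"print(jaxpr2)  # consumes the residual and adds b\n",
"```"
]
},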
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"With that, we can compose `linearize` and `jit` however we like:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2.\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"\n",
|
|
"y, f_lin = linearize(f, 3.)\n",
|
|
"y_dot = f_lin(1.)\n",
|
|
"print(y, y_dot)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2.\n",
|
|
" z = g(x, y)\n",
|
|
" return z\n",
|
|
"\n",
|
|
"@jit\n",
|
|
"def g(x, y):\n",
|
|
" return cos(x) + y\n",
|
|
"\n",
|
|
"y, f_lin = linearize(f, 3.)\n",
|
|
"y_dot = f_lin(1.)\n",
|
|
"print(y, y_dot)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### `vjp` and `grad`\n",
|
|
"\n",
|
|
"The `vjp` transformation works a lot like linearize. Its type signature is\n",
|
|
"analogous:\n",
|
|
"\n",
|
|
"```\n",
|
|
"linearize : (a -> b) -> a -> (b, T a -o T b)\n",
|
|
"vjp : (a -> b) -> a -> (b, T b -o T a)\n",
|
|
"```\n",
|
|
"\n",
|
|
"The only difference is that we transpose the linear part of the computation\n",
|
|
"before returning it, so that it goes from type `T a -o T b` to type `T b -o T\n",
|
|
"a`. That is, we'll implement `vjp` as, essentially,\n",
|
|
"\n",
|
|
"```\n",
|
|
"def vjp(f, x):\n",
|
|
" y, f_lin = linearize(f, x)\n",
|
|
" f_vjp = lambda y_bar: transpose(f_lin)(y_bar)\n",
|
|
" return y, f_vjp\n",
|
|
"```\n",
|
|
"\n",
|
|
"Since we have the linear computation as a jaxpr, not just a Python callable,\n",
|
|
"we can implement the transpose transformation as a jaxpr interpreter."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def vjp_flat(f, *primals_in):\n",
|
|
" pvals_in = ([PartialVal.known(x) for x in primals_in] +\n",
|
|
" [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n",
|
|
" primal_pvals_in, tangent_pvals_in = split_half(pvals_in)\n",
|
|
" def f_jvp(*primals_tangents_in):\n",
|
|
" primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n",
|
|
" return [*primals_out, *tangents_out]\n",
|
|
" jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) # linearize\n",
|
|
" primal_pvals, _ = split_half(pvals_out)\n",
|
|
" assert all(pval.is_known for pval in primal_pvals)\n",
|
|
" primals_out = [pval.const for pval in primal_pvals]\n",
|
|
" transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]\n",
|
|
" f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)\n",
|
|
" return primals_out, f_vjp\n",
|
|
"\n",
|
|
"def vjp(f, *primals_in):\n",
|
|
" primals_in_flat, in_tree = tree_flatten(primals_in)\n",
|
|
" f, out_tree = flatten_fun(f, in_tree)\n",
|
|
" primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)\n",
|
|
" primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
|
|
"\n",
|
|
" def f_vjp(*cotangents_out):\n",
|
|
" cotangents_out_flat, _ = tree_flatten(cotangents_out)\n",
|
|
" cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)\n",
|
|
" return tree_unflatten(in_tree, cotangents_in_flat)\n",
|
|
"\n",
|
|
" return primals_out, f_vjp\n",
|
|
"\n",
|
|
"class UndefPrimal(NamedTuple):\n",
|
|
" aval: ShapedArray\n",
|
|
"\n",
|
|
"register_pytree_node(UndefPrimal,\n",
|
|
" lambda u: (u.aval, ()),\n",
|
|
" lambda aval, _: UndefPrimal(aval))"
|
|
]
|
|
},
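{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the pytree registration just above: flattening prunes\n",
"`UndefPrimal` placeholders out of the leaves, carrying them in the treedef\n",
"instead. We rely on this pruning behavior below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"leaves, treedef = tree_flatten([UndefPrimal(ShapedArray((), np.dtype('float64'))), 2.])\n",
"print(leaves)  # only the 2. survives as a leaf"
]
},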
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We use `UndefPrimal` instances to indicate which arguments with respect to\n",
|
|
"which we want to transpose. These arise because in general, being explicit\n",
|
|
"about closed-over values, we want to transpose functions of type\n",
|
|
"`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the\n",
|
|
"inputs with respect to which the function is linear could be scattered through\n",
|
|
"the argument list. So we indicate the linear positions using `UndefPrimal`.\n",
|
|
"We register `UndefPrimal` as a pytree node because the pytree mechanism gives\n",
|
|
"a handy way to prune these placeholders out of argument lists.\n",
|
|
"\n",
|
|
"Next, we can write `eval_jaxpr_transposed`, along with transpose rules for\n",
|
|
"all primitives which can be linear in at least one argument:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NB: the analogous function in JAX is called 'backward_pass'\n",
|
|
"def eval_jaxpr_transposed(jaxpr: Jaxpr, args: List[Any], cotangents: List[Any]\n",
|
|
" ) -> List[Any]:\n",
|
|
" primal_env: Dict[Var, Any] = {}\n",
|
|
" ct_env: Dict[Var, Any] = {}\n",
|
|
"\n",
|
|
" def read_primal(x: Atom) -> Any:\n",
|
|
" return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val\n",
|
|
"\n",
|
|
" def write_primal(v: Var, val: Any) -> None:\n",
|
|
" if type(val) is not UndefPrimal:\n",
|
|
" primal_env[v] = val\n",
|
|
"\n",
|
|
" def read_cotangent(v: Var) -> Any:\n",
|
|
" return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))\n",
|
|
"\n",
|
|
" def write_cotangent(x: Atom, val: Any):\n",
|
|
" if type(x) is Var and val is not None:\n",
|
|
" ct_env[x] = add(ct_env[x], val) if x in ct_env else val\n",
|
|
"\n",
|
|
" map(write_primal, jaxpr.in_binders, args)\n",
|
|
" map(write_cotangent, jaxpr.outs, cotangents)\n",
|
|
" for eqn in jaxpr.eqns[::-1]:\n",
|
|
" primals_in = map(read_primal, eqn.inputs)\n",
|
|
" cts_in = map(read_cotangent, eqn.out_binders)\n",
|
|
" rule = transpose_rules[eqn.primitive]\n",
|
|
" cts_out = rule(cts_in, *primals_in, **eqn.params)\n",
|
|
" map(write_cotangent, eqn.inputs, cts_out)\n",
|
|
"\n",
|
|
" return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)\n",
|
|
" if type(x) is UndefPrimal]\n",
|
|
"\n",
|
|
"transpose_rules = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def mul_transpose_rule(cts, x, y):\n",
|
|
" z_bar, = cts\n",
|
|
" assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)\n",
|
|
" return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]\n",
|
|
"transpose_rules[mul_p] = mul_transpose_rule\n",
|
|
"\n",
|
|
"def neg_transpose_rule(cts, x):\n",
|
|
" ybar, = cts\n",
|
|
" assert type(x) is UndefPrimal\n",
|
|
" return [neg(ybar)]\n",
|
|
"transpose_rules[neg_p] = neg_transpose_rule\n",
|
|
"\n",
|
|
"def add_transpose_rule(cts, x, y):\n",
|
|
" z_bar, = cts\n",
|
|
" return [z_bar, z_bar]\n",
|
|
"transpose_rules[add_p] = add_transpose_rule\n",
|
|
"\n",
|
|
"def reduce_sum_transpose_rule(cts, x, *, axis):\n",
|
|
" y_bar, = cts\n",
|
|
" return [broadcast(y_bar, x.aval.shape, axis)]\n",
|
|
"transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule\n",
|
|
"\n",
|
|
"def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):\n",
|
|
" del num_consts # Unused\n",
|
|
" undef_primals = [type(x) is UndefPrimal for x in invals]\n",
|
|
" transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))\n",
|
|
" residuals, _ = partition_list(undef_primals, invals)\n",
|
|
" outs = bind(xla_call_p, *new_consts, *residuals, *cts,\n",
|
|
" jaxpr=transposed_jaxpr, num_consts=len(new_consts))\n",
|
|
" outs = iter(outs)\n",
|
|
" return [next(outs) if undef else None for undef in undef_primals]\n",
|
|
"transpose_rules[xla_call_p] = xla_call_transpose_rule\n",
|
|
"\n",
|
|
"@lru_cache()\n",
|
|
"def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...]\n",
|
|
" ) -> Tuple[Jaxpr, List[Any]]:\n",
|
|
" avals_in, avals_out = typecheck_jaxpr(jaxpr)\n",
|
|
" traceable = partial(eval_jaxpr_transposed, jaxpr)\n",
|
|
" args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]\n",
|
|
" trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))\n",
|
|
" typecheck_jaxpr(trans_jaxpr)\n",
|
|
" return trans_jaxpr, consts"
|
|
]
|
|
},
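{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small sanity check, here's a sketch using only the machinery defined\n",
"above: we stage out `x * 2.`, mark `x` as undefined, and push a cotangent of\n",
"`1.` backward. Transposing multiplication by `2.` should give `2.`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"aval = raise_to_shaped(get_aval(3.))\n",
"jaxpr, consts, _ = make_jaxpr(lambda x: mul(x, 2.), aval)\n",
"# mark x as the linear input and feed a cotangent of 1. to the output\n",
"x_bar, = eval_jaxpr_transposed(jaxpr, [*consts, UndefPrimal(aval)], [1.])\n",
"print(x_bar)  # 2.0"
]
},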
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now that we can linearize and transpose, we can finally write `grad`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def grad(f):\n",
|
|
" def gradfun(x, *xs):\n",
|
|
" y, f_vjp = vjp(f, x, *xs)\n",
|
|
" if np.shape(y) != (): raise TypeError\n",
|
|
" x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))\n",
|
|
" return x_bar\n",
|
|
" return gradfun"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"y, f_vjp = vjp(sin, 3.)\n",
|
|
"print(f_vjp(1.), cos(3.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2.\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"\n",
|
|
"print(grad(f)(3.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" y = x * 2.\n",
|
|
" z = g(y)\n",
|
|
" return z\n",
|
|
"\n",
|
|
"@jit\n",
|
|
"def g(x):\n",
|
|
" return cos(x) * 2.\n",
|
|
"\n",
|
|
"print(grad(f)(3.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Here's something of a compositionality stress test:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# from core_test.py fun_with_nested_calls_2\n",
|
|
"def foo(x):\n",
|
|
" @jit\n",
|
|
" def bar(y):\n",
|
|
" def baz(w):\n",
|
|
" q = jit(lambda x: y)(x)\n",
|
|
" q = q + jit(lambda: y)()\n",
|
|
" q = q + jit(lambda y: w + y)(y)\n",
|
|
" q = jit(lambda w: jit(sin)(x) * y)(1.0) + q\n",
|
|
" return q\n",
|
|
" p, t = jvp(baz, (x + 1.0,), (y,))\n",
|
|
" return t + (x * p)\n",
|
|
" return bar(x)\n",
|
|
"\n",
|
|
"def assert_allclose(*vals):\n",
|
|
" for v1, v2 in zip(vals[:-1], vals[1:]):\n",
|
|
" np.testing.assert_allclose(v1, v2)\n",
|
|
"\n",
|
|
"ans1 = f(3.)\n",
|
|
"ans2 = jit(f)(3.)\n",
|
|
"ans3, _ = jvp(f, (3.,), (5.,))\n",
|
|
"ans4, _ = jvp(jit(f), (3.,), (5.,))\n",
|
|
"assert_allclose(ans1, ans2, ans3, ans4)\n",
|
|
"\n",
|
|
"deriv1 = grad(f)(3.)\n",
|
|
"deriv2 = grad(jit(f))(3.)\n",
|
|
"deriv3 = jit(grad(jit(f)))(3.)\n",
|
|
"_, deriv4 = jvp(f, (3.,), (1.,))\n",
|
|
"_, deriv5 = jvp(jit(f), (3.,), (1.,))\n",
|
|
"assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)\n",
|
|
"\n",
|
|
"hess1 = grad(grad(f))(3.)\n",
|
|
"hess2 = grad(grad(jit(f)))(3.)\n",
|
|
"hess3 = grad(jit(grad(f)))(3.)\n",
|
|
"hess4 = jit(grad(grad(f)))(3.)\n",
|
|
"_, hess5 = jvp(grad(f), (3.,), (1.,))\n",
|
|
"_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))\n",
|
|
"_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))\n",
|
|
"assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Part 5: the control flow primitives `cond`\n",
|
|
"\n",
|
|
"Next we'll add higher-order primitives for staged-out control flow. These\n",
|
|
"resemble `jit` from Part 3, another higher-order primitive, but differ in that\n",
|
|
"they are parameterized by multiple callables rather than just one."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Adding `cond`\n",
|
|
"\n",
|
|
"We introduce a `cond` primitive to represent conditional application of one\n",
|
|
"function or another inside a jaxpr. We write the type of `cond` as\n",
|
|
"`Bool -> (a -> b) -> (a -> b) -> a -> b`. In words, `cond` takes a boolean\n",
|
|
"representing the predicate and two functions of equal types. Depending on the\n",
|
|
"value of the predicate, it applies one function or the other to its final\n",
|
|
"argument.\n",
|
|
"\n",
|
|
"In Python, we represent it as a function which itself takes two functions as\n",
|
|
"arguments. As with `jit`, the first step is to call `make_jaxpr` on its\n",
|
|
"callable arguments to turn them into jaxprs:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond(pred, true_fn, false_fn, *operands):\n",
|
|
" avals_in = [raise_to_shaped(get_aval(x)) for x in operands]\n",
|
|
" true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)\n",
|
|
" false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)\n",
|
|
" if out_tree != out_tree_: raise TypeError\n",
|
|
" true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
|
|
" true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
|
|
" if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):\n",
|
|
" raise TypeError\n",
|
|
" outs = bind_cond(pred, *true_consts, *false_consts, *operands,\n",
|
|
" true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
|
|
" return tree_unflatten(out_tree, outs)\n",
|
|
"cond_p = Primitive('cond')\n",
|
|
"\n",
|
|
"def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int\n",
|
|
" ) -> Tuple[Jaxpr, Jaxpr]:\n",
|
|
" jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)\n",
|
|
" assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]\n",
|
|
" consts1, rest1 = split_list(jaxpr1.in_binders, n1)\n",
|
|
" consts2, rest2 = split_list(jaxpr2.in_binders, n2)\n",
|
|
" new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)\n",
|
|
" new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)\n",
|
|
" return new_jaxpr1, new_jaxpr2\n",
|
|
"\n",
|
|
"def bind_cond(pred, *args, true_jaxpr, false_jaxpr):\n",
|
|
" assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)\n",
|
|
" return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We require `true_jaxpr` and `false_jaxpr` to have the same type, but because\n",
|
|
"they might close over different constants (and because jaxprs can only\n",
|
|
"represent closed terms, i.e. can't have free variables and are instead\n",
|
|
"closure-converted) we need to use the helper `_join_jaxpr_consts` to make\n",
|
|
"consistent the input binder lists of the two jaxprs. (To be more economical we\n",
|
|
"could try to identify pairs of constants with the same shapes, but instead we\n",
|
|
"just concatenate the lists of constants.)\n",
|
|
"\n",
|
|
"Next we can turn to adding interpreter rules for `cond`. Its evaluation rule\n",
|
|
"is simple:"
|
|
]
|
|
},
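{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"jaxpr1  : (c1, a) -> b         jaxpr2  : (c2, a) -> b\n",
"  =>\n",
"jaxpr1' : (c1, c2, a) -> b     jaxpr2' : (c1, c2, a) -> b\n",
"```\n",
"\n",
"Each joined jaxpr simply ignores the constant binders that belong to the\n",
"other.\n",
"\n",
"Next we can turn to adding interpreter rules for `cond`. Its evaluation rule\n",
"is simple:"
]
},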
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):\n",
|
|
" if pred:\n",
|
|
" return eval_jaxpr(true_jaxpr, operands)\n",
|
|
" else:\n",
|
|
" return eval_jaxpr(false_jaxpr, operands)\n",
|
|
"impl_rules[cond_p] = cond_impl"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"out = cond(True, lambda: 3, lambda: 4)\n",
|
|
"print(out)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"For its JVP and vmap rules, we only need to call the same `jvp_jaxpr` and\n",
|
|
"`vmap_jaxpr` utilities we created for `jit`, followed by another pass of\n",
|
|
"`_join_jaxpr_consts`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):\n",
|
|
" pred, *primals = primals\n",
|
|
" _ , *tangents = tangents\n",
|
|
" true_jaxpr , true_consts = jvp_jaxpr(true_jaxpr)\n",
|
|
" false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)\n",
|
|
" true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
|
|
" true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
|
|
" assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)\n",
|
|
" outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,\n",
|
|
" true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
|
|
" primals_out, tangents_out = split_half(outs)\n",
|
|
" return primals_out, tangents_out\n",
|
|
"jvp_rules[cond_p] = cond_jvp_rule"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))\n",
|
|
"print(out_tan)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):\n",
|
|
" pred , *vals_in = vals_in\n",
|
|
" pred_dim, *dims_in = dims_in\n",
|
|
" if pred_dim is not not_mapped: raise NotImplementedError # TODO\n",
|
|
" true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))\n",
|
|
" false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))\n",
|
|
" true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
|
|
" true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
|
|
" assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)\n",
|
|
" outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,\n",
|
|
" true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
|
|
" return outs, [0] * len(outs)\n",
|
|
"vmap_rules[cond_p] = cond_vmap_rule"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"xs = np.array([1., 2., 3])\n",
|
|
"out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)\n",
|
|
"print(out)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Notice that we're not currently supporting the case where the predicate value\n",
|
|
"itself is batched. In mainline JAX, we handle this case by transforming the\n",
|
|
"conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n",
|
|
"That transformation is semantically correct so long as `true_fun` and\n",
|
|
"`false_fun` do not involve any side-effecting primitives.\n",
|
|
"\n",
|
|
"Another thing not represented here, but present in the mainline JAX, is that\n",
|
|
"applying transformations to two jaxprs of equal type might result in jaxprs of\n",
|
|
"different types. For example, applying the mainline JAX version of\n",
|
|
"`vmap_jaxpr` to the identity-function jaxpr\n",
|
|
"\n",
|
|
"```\n",
|
|
"{ lambda a:float32[] .\n",
|
|
" let\n",
|
|
" in ( a ) }\n",
|
|
"```\n",
|
|
"\n",
|
|
"would result in a jaxpr with a batched output, of type\n",
|
|
"`[float32[10]] -> [float32[10]]` if the batch size were 10, while applying it\n",
|
|
"to the zero-function jaxpr\n",
|
|
"\n",
|
|
"```\n",
|
|
"{ lambda a:float32[] .\n",
|
|
" let\n",
|
|
" in ( 0. ) }\n",
|
|
"```\n",
|
|
"\n",
|
|
"would result in a jaxpr with an unbatched output, of type\n",
|
|
"`[float32[10]] -> [float32[]]`. This is an optimization, aimed at not batching\n",
|
|
"values unnecessarily. But it means that in `cond` we'd need an extra step of\n",
|
|
"joining the two transformed jaxprs to have consistent output types. We don't\n",
|
|
"need this step here because we chose `vmap_jaxpr` always to batch all outputs\n",
|
|
"over the leading axis."
|
|
]
|
|
},
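{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, here's a sketch of that batched-predicate strategy (not part of\n",
"our implementation, and with `np.where` standing in for the select\n",
"primitive): evaluate both branches, then select elementwise on the predicate:\n",
"\n",
"```python\n",
"def batched_pred_cond_sketch(pred, true_jaxpr, false_jaxpr, *operands):\n",
"  true_outs = eval_jaxpr(true_jaxpr, list(operands))\n",
"  false_outs = eval_jaxpr(false_jaxpr, list(operands))\n",
"  # broadcast the batched predicate against each leading-axis-batched output\n",
"  return [np.where(np.reshape(pred, np.shape(pred) + (1,) * (np.ndim(t) - np.ndim(pred))), t, f)\n",
"          for t, f in zip(true_outs, false_outs)]\n",
"```"
]
},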
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Next we can turn to abstract evaluation and XLA lowering rules:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):\n",
|
|
" if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError\n",
|
|
" jaxpr_type = typecheck_jaxpr(true_jaxpr)\n",
|
|
" if jaxpr_type != typecheck_jaxpr(false_jaxpr):\n",
|
|
" raise TypeError\n",
|
|
" if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n",
|
|
" raise TypeError\n",
|
|
" return jaxpr_type.out_types\n",
|
|
"abstract_eval_rules[cond_p] = cond_abstract_eval\n",
|
|
"\n",
|
|
"def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):\n",
|
|
" del in_avals # Unused\n",
|
|
" pred, *in_vals = in_vals\n",
|
|
" flat_vals, in_tree = tree_flatten(in_vals)\n",
|
|
" operand = xops.Tuple(c, flat_vals)\n",
|
|
" operand_shape = c.get_shape(operand)\n",
|
|
"\n",
|
|
" def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n",
|
|
" c = xc.XlaBuilder(name)\n",
|
|
" operand = xops.Parameter(c, 0, operand_shape)\n",
|
|
" operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n",
|
|
" outs = jaxpr_subcomp(c, jaxpr, operands)\n",
|
|
" return c.build(xops.Tuple(c, outs))\n",
|
|
"\n",
|
|
" true_comp = make_comp('true_fn', true_jaxpr)\n",
|
|
" false_comp = make_comp('false_fn', false_jaxpr)\n",
|
|
"\n",
|
|
" int_etype = xc.dtype_to_etype(np.dtype('int32'))\n",
|
|
" out = xops.Conditional(xops.ConvertElementType(pred, int_etype),\n",
|
|
" [false_comp, true_comp], [operand] * 2)\n",
|
|
" return destructure_tuple(c, out)\n",
|
|
"xla_translations[cond_p] = cond_translation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"out = jit(lambda: cond(False, lambda: 1, lambda: 2))()\n",
|
|
"print(out)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Finally, to support reverse-mode automatic differentiation, we need partial\n",
|
|
"evaluation and transposition rules. For partial evaluation, we need to\n",
|
|
"introduce another jaxpr-munging utility, `_join_jaxpr_res`, to handle the fact\n",
|
|
"that applying partial evaluation to `true_fun` and `false_fun` will in general\n",
|
|
"result in distinct residuals. We use `_join_jaxpr_res` to make the output\n",
|
|
"types of the transformed jaxprs consistent (while `_join_jaxpr_consts` dealt\n",
|
|
"with input types)."
|
|
]
|
|
},
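{
"cell_type": "markdown",
"metadata": {},
"source": [
"Schematically, writing `0 : res` for zero outputs of the residual types,\n",
"`_join_jaxpr_res` maps\n",
"\n",
"```\n",
"jaxpr1  : a -> (b, res1)             jaxpr2  : a -> (b, res2)\n",
"  =>\n",
"jaxpr1' : a -> (b, res1, 0 : res2)   jaxpr2' : a -> (b, 0 : res1, res2)\n",
"```\n",
"\n",
"so that the two output types agree."
]
},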
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):\n",
|
|
" pred_tracer, *tracers = tracers\n",
|
|
" assert pred_tracer.pval.is_known\n",
|
|
" pred = pred_tracer.pval.const\n",
|
|
" in_uks = [not t.pval.is_known for t in tracers]\n",
|
|
"\n",
|
|
" *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)\n",
|
|
" t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs\n",
|
|
"\n",
|
|
" known_tracers, unknown_tracers = partition_list(in_uks, tracers)\n",
|
|
" known_vals = [t.pval.const for t in known_tracers]\n",
|
|
" outs1_res = bind_cond(pred, *known_vals,\n",
|
|
" true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)\n",
|
|
" outs1, res = split_list(outs1_res, len(outs1_res) - num_res)\n",
|
|
" pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))\n",
|
|
" res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]\n",
|
|
" outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)\n",
|
|
" for v in t_jaxpr2.outs]\n",
|
|
" eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],\n",
|
|
" dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),\n",
|
|
" [v.aval for v in t_jaxpr2.outs], map(ref, outs2))\n",
|
|
" for t in outs2: t.recipe = eqn\n",
|
|
" return merge_lists(out_uks, outs1, outs2)\n",
|
|
"partial_eval_rules[cond_p] = cond_partial_eval\n",
|
|
"\n",
|
|
"def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: List[bool]\n",
|
|
" ) -> Tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, List[bool], int]:\n",
|
|
" _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)\n",
|
|
" _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)\n",
|
|
" out_uks = map(op.or_, t_out_uks, f_out_uks)\n",
|
|
"\n",
|
|
" t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)\n",
|
|
" f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)\n",
|
|
"\n",
|
|
" t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)\n",
|
|
" t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)\n",
|
|
" assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)\n",
|
|
" assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)\n",
|
|
" num_res = t_nres + f_nres\n",
|
|
"\n",
|
|
" return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res\n",
|
|
"\n",
|
|
"def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int\n",
|
|
" ) -> Tuple[Jaxpr, Jaxpr]:\n",
|
|
" jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)\n",
|
|
" out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)\n",
|
|
" out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)\n",
|
|
" assert out_types1 == out_types2\n",
|
|
" outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)\n",
|
|
" outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)\n",
|
|
" zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]\n",
|
|
" zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]\n",
|
|
" new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)\n",
|
|
" new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)\n",
|
|
" return new_jaxpr1, new_jaxpr2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)\n",
|
|
"out = f_lin(3.14)\n",
|
|
"print(out)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,\n",
|
|
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:\n",
|
|
" pred_unk, *unks_in = unks_in\n",
|
|
" assert not pred_unk\n",
|
|
" true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']\n",
|
|
" *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)\n",
|
|
" t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs\n",
|
|
" ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])\n",
|
|
" outs1, outs2 = partition_list(unks_out, eqn.out_binders)\n",
|
|
" residuals, _ = split_list(t_jaxpr2.in_binders, num_res)\n",
|
|
" eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],\n",
|
|
" dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),\n",
|
|
" outs1 + residuals)\n",
|
|
" eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],\n",
|
|
" dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),\n",
|
|
" outs2)\n",
|
|
" res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals\n",
|
|
" return eqn1, eqn2, unks_out, res\n",
|
|
"partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)\n",
|
|
"out = f_lin(3.14)\n",
|
|
"print(out)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Transposition is a fairly straightforward application of `transpose_jaxpr`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):\n",
|
|
" undef_primals = tuple(type(x) is UndefPrimal for x in invals)\n",
|
|
" true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)\n",
|
|
" false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)\n",
|
|
" true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
|
|
" true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
|
|
" res = [x for x in invals if type(x) is not UndefPrimal]\n",
|
|
" outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,\n",
|
|
" true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
|
|
" outs = iter(outs)\n",
|
|
" return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]\n",
|
|
"transpose_rules[cond_p] = cond_transpose_rule"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)\n",
|
|
"print(out)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": [
|
|
"hide-input"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
|
" true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']\n",
|
|
" new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}\n",
|
|
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
|
|
" rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>\n",
|
|
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
|
" for x in eqn.inputs)))\n",
|
|
" return vcat([lhs >> pp(' = ') >> rhs,\n",
|
|
" pp_jaxpr(true_jaxpr).indent(2),\n",
|
|
" pp_jaxpr(false_jaxpr).indent(2)])\n",
|
|
"pp_rules[cond_p] = pprint_cond"
|
|
]
|
|
},
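{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the new rule (a sketch reusing `make_jaxpr` and\n",
"`pp_jaxpr` from earlier), we can stage out a `cond` and pretty-print the\n",
"resulting jaxpr:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"jaxpr, consts, _ = make_jaxpr(lambda x: cond(True, lambda: x * x, lambda: 0.),\n",
"                              raise_to_shaped(get_aval(3.)))\n",
"print(pp_jaxpr(jaxpr))"
]
}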
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"formats": "ipynb,md:myst,py",
|
|
"main_language": "python"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|