{
|
|
"cells": [
|
|
{
|
|
"cell_type": "raw",
|
|
"id": "surrounded-agent",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"Copyright 2021 Google LLC\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
|
"you may not use this file except in compliance with the License.\n",
|
|
"You may obtain a copy of the License at\n",
|
|
"\n",
|
|
" https://www.apache.org/licenses/LICENSE-2.0\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software\n",
|
|
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
|
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
|
"See the License for the specific language governing permissions and\n",
|
|
"limitations under the License.\n",
|
|
"---"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Autodidax: JAX core from scratch\n",
|
|
"\n",
|
|
"Ever want to learn how JAX works, but the implementation seemed too\n",
|
|
"impenetrable? Well, you're in luck! By reading this tutorial, you'll learn\n",
|
|
"every big idea in JAX's core system. You'll even get clued into our weird\n",
|
|
"jargon!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Part 1: Transformations as interpreters: standard evaluation, `jvp`, and `vmap`\n",
|
|
"\n",
|
|
"We want to transform functions that look like this:\n",
|
|
"\n",
|
|
"```python\n",
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"```\n",
|
|
"\n",
|
|
"Think of functions like `sin` and the arithmetic operations underlying the\n",
|
|
"infix operators (`mul`, `add`, and `neg`) as primitive operations, meaning\n",
|
|
"atomic units of processing rather than compositions.\n",
|
|
"\n",
|
|
"\"Transform\" means \"interpret differently.\" Instead of standard interpretation\n",
|
|
"where we apply primitive functions to numerical inputs to produce numerical\n",
|
|
"outputs, we want to override primitive application and let different values\n",
|
|
"flow through our program. For example, we might want to replace the\n",
|
|
"application of every primitive with an application of [its JVP\n",
|
|
"rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n",
|
|
"and let primal-tangent pairs flow through our program. Moreover, we want to\n",
|
|
"apply a composition of multiple transformations, leading to stacks of\n",
|
|
"interpreters."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"source": [
|
|
"### JAX core machinery\n",
|
|
"\n",
|
|
"We can implement stacks of interpreters and even have them all discharge on\n",
|
|
"the fly as we execute the Python function to be transformed. To start, let's\n",
|
|
"define these primitives so that we can intercept their application:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import NamedTuple\n",
|
|
"\n",
|
|
"class Primitive(NamedTuple):\n",
|
|
" name: str\n",
|
|
"\n",
|
|
"add_p = Primitive('add')\n",
|
|
"mul_p = Primitive('mul')\n",
|
|
"neg_p = Primitive(\"neg\")\n",
|
|
"sin_p = Primitive(\"sin\")\n",
|
|
"cos_p = Primitive(\"cos\")\n",
|
|
"reduce_sum_p = Primitive(\"reduce_sum\")\n",
|
|
"greater_p = Primitive(\"greater\")\n",
|
|
"\n",
|
|
"def add(x, y): return bind(add_p, x, y)\n",
|
|
"def mul(x, y): return bind(mul_p, x, y)\n",
|
|
"def neg(x): return bind(neg_p, x)\n",
|
|
"def sin(x): return bind(sin_p, x)\n",
|
|
"def cos(x): return bind(cos_p, x)\n",
|
|
"def reduce_sum(x, axis=None): return bind(reduce_sum_p, x, axis=axis)\n",
|
|
"def greater(x, y): return bind(greater_p, x, y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We'll set up array data types and infix operator methods in a moment.\n",
|
|
"\n",
|
|
"A `Primitive` is just an object with a name, to which we attach our\n",
|
|
"interpretation rules (one for each transformation). The `bind` function is our\n",
|
|
"interception point: it'll figure out which transformation rule to apply, based\n",
|
|
"on how the arguments are boxed in tracers and what interpreters are active.\n",
|
|
"\n",
|
|
"The functions that user code calls, like `add` and `sin`, are just wrappers\n",
|
|
"around calls to `bind`. These wrappers let us control how arguments are passed\n",
|
|
"to `bind`, and in particular we follow a handy internal convention: when we\n",
|
|
"call `bind`, we pass values representing array data as positional arguments,\n",
|
|
"and we pass metadata like the `axis` argument to `sum_p` via keyword. This\n",
|
|
"calling convention simplifies some core logic (since e.g. instances of the\n",
|
|
"`Tracer` class to be defined below can only occurr in positional arguments to\n",
|
|
"`bind`). The wrappers can also provide docstrings!\n",
|
|
"\n",
|
|
"We represent active interpreters as a stack. The stack is just a simple\n",
|
|
"`list`, and each element is a container with an integer level (corresponding\n",
|
|
"to the element's height in the stack), an interpreter type (which we'll call a\n",
|
|
"`trace_type`), and an optional field for any global data the interpreter\n",
|
|
"needs. We call each element a `MainTrace`, though maybe \"Interpreter\" would be\n",
|
|
"more descriptive."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from contextlib import contextmanager\n",
|
|
"from typing import Type, List, Optional, Any\n",
|
|
"\n",
|
|
"class MainTrace(NamedTuple):\n",
|
|
" level: int\n",
|
|
" trace_type: Type['Trace']\n",
|
|
" global_data: Optional[Any]\n",
|
|
"\n",
|
|
"trace_stack: List[MainTrace] = []\n",
|
|
"\n",
|
|
"@contextmanager\n",
|
|
"def new_main(trace_type: Type['Trace'], global_data=None):\n",
|
|
" level = len(trace_stack)\n",
|
|
" main = MainTrace(level, trace_type, global_data)\n",
|
|
" trace_stack.append(main)\n",
|
|
"\n",
|
|
" try:\n",
|
|
" yield main\n",
|
|
" finally:\n",
|
|
" trace_stack.pop()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"When we're about to apply a transformed function, we'll push another\n",
|
|
"interpreter onto the stack using `new_main`. Then, as we apply primitives in\n",
|
|
"the function, we can think of the `bind` first being interpreted by the trace\n",
|
|
"at the top of the stack (i.e. with the highest level). If that first\n",
|
|
"interpreter itself binds other primitives in its interpretation rule for the\n",
|
|
"primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,\n",
|
|
"then those `bind` calls will be handled by the interpreter at the next level\n",
|
|
"down.\n",
|
|
"\n",
|
|
"What goes at the bottom of the interpreter stack? At the bottom, we know all\n",
|
|
"the transformation interpreters are finished, and we just want to do standard\n",
|
|
"evaluation. So at the bottom we'll put an evaluation interpreter.\n",
|
|
"\n",
|
|
"Let's sketch out the interface for interpreters, which is based on the `Trace`\n",
|
|
"and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps\n",
|
|
"carrying some extra context data used by the interpreter. A `Trace` handles\n",
|
|
"boxing up vales into `Tracers` and also handles primitive application."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Trace:\n",
|
|
" main: MainTrace\n",
|
|
"\n",
|
|
" def __init__(self, main: MainTrace) -> None:\n",
|
|
" self.main = main\n",
|
|
"\n",
|
|
" def pure(self, val): assert False # must override\n",
|
|
" def lift(self, val): assert False # must override\n",
|
|
"\n",
|
|
" def process_primitive(self, primitive, tracers, params):\n",
|
|
" assert False # must override"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The first two methods are about boxing up values in `Tracer`s, which are the\n",
|
|
"objects that flow through the Python programs we transform. The last method is\n",
|
|
"the callback we'll use to interpret primitive application.\n",
|
|
"\n",
|
|
"The `Trace` itself doesn't contain any data, other than a reference to its\n",
|
|
"corresponding `MainTrace` instance. In fact, multiple instances of a `Trace`\n",
|
|
"might be created and discarded during an application of a transformation,\n",
|
|
"whereas only a single `MainTrace` instance is created per application of a\n",
|
|
"transformation.\n",
|
|
"\n",
|
|
"As for `Tracer`s themselves, each one carries an abstract value (and forwards\n",
|
|
"infix operators to it), and the rest is up to the transformation. (The\n",
|
|
"relationship between `Tracer`s and `AbstractValue`s is that there's one\n",
|
|
"`Tracer` per transformation, and at least one `AbstractValue` per base type,\n",
|
|
"like arrays.)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"from typing import Tuple\n",
|
|
"\n",
|
|
"class Tracer:\n",
|
|
" _trace: Trace\n",
|
|
"\n",
|
|
" __array_priority__ = 1000\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def aval(self):\n",
|
|
" assert False # must override\n",
|
|
"\n",
|
|
" def full_lower(self):\n",
|
|
" return self # default implementation\n",
|
|
"\n",
|
|
" def __neg__(self): return self.aval._neg(self)\n",
|
|
" def __add__(self, other): return self.aval._add(self, other)\n",
|
|
" def __radd__(self, other): return self.aval._radd(self, other)\n",
|
|
" def __mul__(self, other): return self.aval._mul(self, other)\n",
|
|
" def __rmul__(self, other): return self.aval._rmul(self, other)\n",
|
|
" def __gt__(self, other): return self.aval._gt(self, other)\n",
|
|
" def __bool__(self): return self.aval._bool(self)\n",
|
|
" def __nonzero__(self): return self.aval._nonzero(self)\n",
|
|
"\n",
|
|
" def __getattr__(self, name):\n",
|
|
" try:\n",
|
|
" return getattr(self.aval, name)\n",
|
|
" except AttributeError:\n",
|
|
" raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")\n",
|
|
"\n",
|
|
"class ShapedArray:\n",
|
|
" array_abstraction_level = 1\n",
|
|
" shape: Tuple[int]\n",
|
|
" dtype: np.dtype\n",
|
|
"\n",
|
|
" def __init__(self, shape, dtype):\n",
|
|
" self.shape = shape\n",
|
|
" self.dtype = dtype\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def ndim(self):\n",
|
|
" return len(self.shape)\n",
|
|
"\n",
|
|
" _neg = staticmethod(neg)\n",
|
|
" _add = staticmethod(add)\n",
|
|
" _radd = staticmethod(add)\n",
|
|
" _mul = staticmethod(mul)\n",
|
|
" _rmul = staticmethod(mul)\n",
|
|
" _gt = staticmethod(greater)\n",
|
|
"\n",
|
|
" @staticmethod\n",
|
|
" def _bool(tracer):\n",
|
|
" raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
|
|
"\n",
|
|
" @staticmethod\n",
|
|
" def _nonzero(tracer):\n",
|
|
" raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
|
|
"\n",
|
|
" def str_short(self):\n",
|
|
" return f'{self.dtype.name}[{\",\".join(str(d) for d in self.shape)}]'\n",
|
|
"\n",
|
|
"class ConcreteArray(ShapedArray):\n",
|
|
" array_abstraction_level = 2\n",
|
|
" val: np.ndarray\n",
|
|
"\n",
|
|
" def __init__(self, val):\n",
|
|
" self.val = val\n",
|
|
" self.shape = val.shape\n",
|
|
" self.dtype = val.dtype\n",
|
|
"\n",
|
|
" @staticmethod\n",
|
|
" def _bool(tracer):\n",
|
|
" return bool(tracer.aval.val)\n",
|
|
"\n",
|
|
" @staticmethod\n",
|
|
" def _nonzero(tracer):\n",
|
|
" return bool(tracer.aval.val)\n",
|
|
"\n",
|
|
"def get_aval(x):\n",
|
|
" if isinstance(x, Tracer):\n",
|
|
" return x.aval\n",
|
|
" else:\n",
|
|
" return ConcreteArray(np.asarray(x))"
|
|
]
|
|
},
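{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (just a sketch, not part of the core machinery), we can\n",
"see how `get_aval` assigns abstract values to ordinary inputs:\n",
"\n",
"```python\n",
"aval = get_aval(np.ones((2, 3)))\n",
"assert isinstance(aval, ConcreteArray) and isinstance(aval, ShapedArray)\n",
"print(aval.str_short())  # we'd expect something like 'float64[2,3]'\n",
"```"
]
},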
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Notice that we actually have two `AbstractValue`s for arrays, representing\n",
|
|
"different levels of abstraction. A `ShapedArray` represents the set of all\n",
|
|
"possible arrays with a given shape and dtype. A `ConcreteArray` represents a\n",
|
|
"singleton set consisting of a single array value.\n",
|
|
"\n",
|
|
"Now that we've set up the trace stack, the Trace/Tracer API for interpreters,\n",
|
|
"and abstract values, we can come back to implement `bind`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def bind(prim, *args, **params):\n",
|
|
" top_trace = find_top_trace(args)\n",
|
|
" tracers = [full_raise(top_trace, arg) for arg in args]\n",
|
|
" out = top_trace.process_primitive(prim, tracers, params)\n",
|
|
" return full_lower(out)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The main action is that we call `find_top_trace` to figure out which\n",
|
|
"interpreter should handle this primitive application as a function of the\n",
|
|
"arguments and the active traces on the trace stack. We then call that top\n",
|
|
"trace's `process_primitive` so that the trace can apply its interpretation\n",
|
|
"rule. The calls to `full_raise` just ensure that the inputs are boxed in the\n",
|
|
"top trace's `Tracer` instances, and the call to `full_lower` is an optional\n",
|
|
"optimization so that we unbox values out of `Tracer`s as much as possible."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from operator import attrgetter\n",
|
|
"\n",
|
|
"def find_top_trace(xs) -> Trace:\n",
|
|
" top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n",
|
|
" default=trace_stack[0], key=attrgetter('level'))\n",
|
|
" return top_main.trace_type(top_main)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"In words, `find_top_trace` returns the highest-level interpreter associated\n",
|
|
"with the `Tracer`s on its inputs, and otherwise returns the interpreter at the\n",
|
|
"bottom of the stack (which is always an evaluation trace, at least for now).\n",
|
|
"This corresponds to JAX transformations mostly working by data dependence\n",
|
|
"_except_ for the special bottom-of-the-stack interpreter, which interprets\n",
|
|
"everything."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def full_lower(val):\n",
|
|
" if isinstance(val, Tracer):\n",
|
|
" return val.full_lower()\n",
|
|
" else:\n",
|
|
" return val\n",
|
|
"\n",
|
|
"def full_raise(trace, val) -> Tracer:\n",
|
|
" if not isinstance(val, Tracer):\n",
|
|
" return trace.pure(val)\n",
|
|
" level = trace.main.level\n",
|
|
" if val._trace.main is trace.main:\n",
|
|
" return val\n",
|
|
" elif val._trace.main.level < level:\n",
|
|
" return trace.lift(val)\n",
|
|
" elif val._trace.main.level > level:\n",
|
|
" raise Exception(f\"Can't lift level {val._trace.main.level} to {level}.\")\n",
|
|
" else: # val._trace.level == level\n",
|
|
" raise Exception(f\"Different traces at same level: {val._trace}, {trace}.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The logic in `full_raise` serves to box values into `Tracer`s for a particular\n",
|
|
"`Trace`, calling different methods on the `Trace` based on context:\n",
|
|
"`Trace.pure` is called on non-`Tracer` constants, and `Trace.lift` is called\n",
|
|
"for values that are already `Tracer`s from a lower-level interpreter. These\n",
|
|
"two methods could share the same implementation, but by distinguishing them in\n",
|
|
"the core logic we can provide more information to the `Trace` subclass.\n",
|
|
"\n",
|
|
"That's it for the JAX core! Now we can start adding interpreters."
|
|
]
|
|
},
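{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the `pure`/`lift` distinction concrete, here's a small sketch of how\n",
"`full_raise` behaves once we have a real interpreter to play with. It uses the\n",
"`JVPTrace` and `JVPTracer` defined later in this part, so treat it as a\n",
"preview:\n",
"\n",
"```python\n",
"with new_main(JVPTrace) as main:\n",
"  trace = JVPTrace(main)\n",
"  x = JVPTracer(trace, 3.0, 1.0)     # already boxed for this trace\n",
"  assert full_raise(trace, x) is x   # same MainTrace: returned unchanged\n",
"  two = full_raise(trace, 2.0)       # a plain constant goes through Trace.pure\n",
"  assert two.tangent == 0.0          # ...and gets boxed with a zero tangent\n",
"```"
]
},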
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Evaluation interpreter\n",
|
|
"\n",
|
|
"We'll start with the simplest interpreter: the evaluation interpreter that\n",
|
|
"will sit at the bottom of the interpreter stack."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class EvalTrace(Trace):\n",
|
|
" pure = lift = lambda self, x: x # no boxing in Tracers needed\n",
|
|
"\n",
|
|
" def process_primitive(self, primitive, tracers, params):\n",
|
|
" return impl_rules[primitive](*tracers, **params)\n",
|
|
"\n",
|
|
"trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack\n",
|
|
"\n",
|
|
"impl_rules = {}\n",
|
|
"impl_rules[add_p] = np.add\n",
|
|
"impl_rules[mul_p] = np.multiply\n",
|
|
"impl_rules[neg_p] = np.negative\n",
|
|
"impl_rules[sin_p] = np.sin\n",
|
|
"impl_rules[cos_p] = np.cos\n",
|
|
"impl_rules[reduce_sum_p] = np.sum\n",
|
|
"impl_rules[greater_p] = np.greater"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"With this interpreter, we can evaluate user functions:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.7177599838802657\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"\n",
|
|
"print(f(3.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Woo! Like going around in a big circle. But the point of this indirection is\n",
|
|
"that now we can add some real transformations."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Forward-mode autodiff with `jvp`\n",
|
|
"\n",
|
|
"First, a couple of helper functions:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def zeros_like(val):\n",
|
|
" return np.zeros_like(val)\n",
|
|
"\n",
|
|
"def unzip2(pairs):\n",
|
|
" lst1, lst2 = [], []\n",
|
|
" for x1, x2 in pairs:\n",
|
|
" lst1.append(x1)\n",
|
|
" lst2.append(x2)\n",
|
|
" return lst1, lst2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The\n",
|
|
"`Trace` applies JVP rules."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class JVPTracer(Tracer):\n",
|
|
" def __init__(self, trace, primal, tangent):\n",
|
|
" self._trace = trace\n",
|
|
" self.primal = primal\n",
|
|
" self.tangent = tangent\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def aval(self):\n",
|
|
" return get_aval(self.primal)\n",
|
|
"\n",
|
|
"class JVPTrace(Trace):\n",
|
|
" pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))\n",
|
|
"\n",
|
|
" def process_primitive(self, primitive, tracers, params):\n",
|
|
" primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)\n",
|
|
" jvp_rule = jvp_rules[primitive]\n",
|
|
" primal_out, tangent_out = jvp_rule(primals_in, tangents_in, **params)\n",
|
|
" return JVPTracer(self, primal_out, tangent_out)\n",
|
|
"\n",
|
|
"jvp_rules = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Notice both `lift` and `sublift` package a value into a `JVPTracer` with the\n",
|
|
"minimal amount of context, which is a zero tangent value."
|
|
]
|
|
},
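{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example (a sketch, using the `jvp` entry point defined a couple of cells\n",
"below), a Python constant that enters the trace this way contributes nothing\n",
"to the derivative:\n",
"\n",
"```python\n",
"y, y_dot = jvp(lambda x: x * 3.0, (2.0,), (1.0,))\n",
"print(y, y_dot)  # we'd expect 6.0 and 3.0: the constant 3.0 carries a zero tangent\n",
"```"
]
},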
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's add some JVP rules for primitives:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def add_jvp(primals, tangents):\n",
|
|
" (x, y), (x_dot, y_dot) = primals, tangents\n",
|
|
" return x + y, x_dot + y_dot\n",
|
|
"jvp_rules[add_p] = add_jvp\n",
|
|
"\n",
|
|
"def mul_jvp(primals, tangents):\n",
|
|
" (x, y), (x_dot, y_dot) = primals, tangents\n",
|
|
" return x * y, x_dot * y + x * y_dot\n",
|
|
"jvp_rules[mul_p] = mul_jvp\n",
|
|
"\n",
|
|
"def sin_jvp(primals, tangents):\n",
|
|
" (x,), (x_dot,) = primals, tangents\n",
|
|
" return sin(x), cos(x) * x_dot\n",
|
|
"jvp_rules[sin_p] = sin_jvp\n",
|
|
"\n",
|
|
"def cos_jvp(primals, tangents):\n",
|
|
" (x,), (x_dot,) = primals, tangents\n",
|
|
" return cos(x), -sin(x) * x_dot\n",
|
|
"jvp_rules[cos_p] = cos_jvp\n",
|
|
"\n",
|
|
"def neg_jvp(primals, tangents):\n",
|
|
" (x,), (x_dot,) = primals, tangents\n",
|
|
" return neg(x), neg(x_dot)\n",
|
|
"jvp_rules[neg_p] = neg_jvp\n",
|
|
"\n",
|
|
"def reduce_sum_jvp(primals, tangents, *, axis):\n",
|
|
" (x,), (x_dot,) = primals, tangents\n",
|
|
" return reduce_sum(x, axis), reduce_sum(x_dot, axis)\n",
|
|
"jvp_rules[reduce_sum_p] = reduce_sum_jvp\n",
|
|
"\n",
|
|
"def greater_jvp(primals, tangents):\n",
|
|
" (x, y), _ = primals, tangents\n",
|
|
" out_primal = greater(x, y)\n",
|
|
" return out_primal, zeros_like(out_primal)\n",
|
|
"jvp_rules[greater_p] = greater_jvp"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Finally, we add a transformation API to kick off the trace:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def jvp(f, primals, tangents):\n",
|
|
" with new_main(JVPTrace) as main:\n",
|
|
" trace = JVPTrace(main)\n",
|
|
" tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n",
|
|
" out = f(*tracers_in)\n",
|
|
" tracer_out = full_raise(trace, out)\n",
|
|
" primal_out, tangent_out = tracer_out.primal, tracer_out.tangent\n",
|
|
" return primal_out, tangent_out"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And with that, we can differentiate!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-0.9899924966004454\n",
|
|
"-0.9899924966004454\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"x = 3.0\n",
|
|
"y, sin_deriv_at_3 = jvp(sin, (x,), (1.0,))\n",
|
|
"print(sin_deriv_at_3)\n",
|
|
"print(cos(3.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.7177599838802657\n",
|
|
"2.979984993200891\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" y = sin(x) * 2\n",
|
|
" z = - y + x\n",
|
|
" return z\n",
|
|
"\n",
|
|
"x, xdot = 3., 1.\n",
|
|
"y, ydot = jvp(f, (x,), (xdot,))\n",
|
|
"print(y)\n",
|
|
"print(ydot)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-0.9899924966004454\n",
|
|
"-0.1411200080598672\n",
|
|
"0.9899924966004454\n",
|
|
"0.1411200080598672\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def deriv(f):\n",
|
|
" return lambda x: jvp(f, (x,), (1.,))[1]\n",
|
|
"\n",
|
|
"print(deriv(sin)(3.))\n",
|
|
"print(deriv(deriv(sin))(3.))\n",
|
|
"print(deriv(deriv(deriv(sin)))(3.))\n",
|
|
"print(deriv(deriv(deriv(deriv(sin))))(3.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.0\n",
|
|
"1.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" if x > 0.: # Python control flow\n",
|
|
" return 2. * x\n",
|
|
" else:\n",
|
|
" return x\n",
|
|
"\n",
|
|
"print(deriv(f)(3.))\n",
|
|
"print(deriv(f)(-3.))"
|
|
]
|
|
},
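{
"cell_type": "markdown",
"metadata": {},
"source": [
"A brief note on why the Python `if` above works under tracing: `x > 0.` calls\n",
"our `greater` primitive, whose JVP rule computes a concrete boolean primal (via\n",
"the evaluation interpreter at the bottom of the stack) together with a zero\n",
"tangent. The resulting tracer's `aval` is therefore a `ConcreteArray`, whose\n",
"`_bool` method can convert it to an ordinary Python `bool`, so the branch is\n",
"resolved eagerly as we trace."
]
},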
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Vectorized batching with `vmap`\n",
|
|
"\n",
|
|
"First, a couple helper functions, one for producing mapped abstract values\n",
|
|
"from unmapped ones (by removing an axis), and one for moving batch dimensions\n",
|
|
"around:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def mapped_aval(batch_dim, aval):\n",
|
|
" shape = list(aval.shape)\n",
|
|
" del shape[batch_dim]\n",
|
|
" return ShapedArray(tuple(shape), aval.dtype)\n",
|
|
"\n",
|
|
"def move_batch_axis(axis_size, src, dst, x):\n",
|
|
" if src is not_mapped:\n",
|
|
" target_shape = list(np.shape(x))\n",
|
|
" target_shape.insert(dst, axis_size)\n",
|
|
" return np.broadcast_to(np.expand_dims(x, dst), target_shape)\n",
|
|
" else:\n",
|
|
" return np.moveaxis(x, src, dst)"
|
|
]
|
|
},
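{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sketch of `move_batch_axis` (using the `not_mapped` sentinel defined in\n",
"the next cell): a mapped value has its batch axis moved, while an unmapped\n",
"value is broadcast to create one:\n",
"\n",
"```python\n",
"x = np.arange(6.).reshape(2, 3)                        # batch axis 0, size 2\n",
"assert move_batch_axis(2, 0, 1, x).shape == (3, 2)     # move it to axis 1\n",
"assert move_batch_axis(2, not_mapped, 0, 1.0).shape == (2,)  # broadcast a scalar\n",
"```"
]
},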
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The `Tracer` for vectorized batching carries a batched value and an optional\n",
|
|
"integer indicating which axis (if any) is the batch axis."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import Union\n",
|
|
"\n",
|
|
"class NotMapped: pass\n",
|
|
"not_mapped = NotMapped()\n",
|
|
"\n",
|
|
"class BatchTracer(Tracer):\n",
|
|
" def __init__(self, trace, val, batch_dim: Union[NotMapped, int]):\n",
|
|
" self._trace = trace\n",
|
|
" self.val = val\n",
|
|
" self.batch_dim = batch_dim\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def aval(self):\n",
|
|
" if self.batch_dim is not_mapped:\n",
|
|
" return get_aval(self.val)\n",
|
|
" else:\n",
|
|
" return mapped_aval(self.batch_dim, get_aval(self.val))\n",
|
|
"\n",
|
|
" def full_lower(self):\n",
|
|
" if self.batch_dim is not_mapped:\n",
|
|
" return full_lower(self.val)\n",
|
|
" else:\n",
|
|
" return self\n",
|
|
"\n",
|
|
"class BatchTrace(Trace):\n",
|
|
" pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)\n",
|
|
"\n",
|
|
" def process_primitive(self, primitive, tracers, params):\n",
|
|
" vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)\n",
|
|
" vmap_rule = vmap_rules[primitive]\n",
|
|
" val_out, bdim_out = vmap_rule(self.axis_size, vals_in, bdims_in, **params)\n",
|
|
" return BatchTracer(self, val_out, bdim_out)\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def axis_size(self):\n",
|
|
" return self.main.global_data\n",
|
|
"\n",
|
|
"vmap_rules = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Here we've implemented the optional `Tracer.full_lower` method, which lets us\n",
|
|
"peel off a batching tracer if it's not needed because it doesn't represent a\n",
|
|
"batched value.\n",
|
|
"\n",
|
|
"For `BatchTrace`, analogous to `JVPTrace`, the methods `pure` and `lift` just\n",
|
|
"box a value in a `BatchTracer` with the minimal amount of context, which in\n",
|
|
"this case is a `batch_dim` taking the sentinel value `not_mapped`. Notice we\n",
|
|
"use the `MainTrace`'s interpreter-global data field to store the batch axis\n",
|
|
"size.\n",
|
|
"\n",
|
|
"Next we can define batching interpreter rules for each primitive:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from functools import partial\n",
|
|
"\n",
|
|
"def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):\n",
|
|
" (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n",
|
|
" if x_bdim != y_bdim:\n",
|
|
" y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
|
|
" return op(x, y), x_bdim\n",
|
|
"vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)\n",
|
|
"vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)\n",
|
|
"\n",
|
|
"def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):\n",
|
|
" (x,), (x_bdim,) = vals_in, dims_in\n",
|
|
" return op(x), x_bdim\n",
|
|
"vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)\n",
|
|
"vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)\n",
|
|
"vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)\n",
|
|
"\n",
|
|
"def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n",
|
|
" (x,), (x_bdim,) = vals_in, dims_in\n",
|
|
" new_axis = axis + (x_bdim <= axis)\n",
|
|
" out_bdim = x_bdim - (new_axis < x_bdim)\n",
|
|
" return reduce_sum(x, new_axis), out_bdim\n",
|
|
"vmap_rules[reduce_sum_p] = reduce_sum_batching_rule"
|
|
]
|
|
},
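{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `reduce_sum` rule has the trickiest bookkeeping: the reduction axis must be\n",
"shifted past the batch axis, and the output batch dimension adjusted when the\n",
"reduced axis sat in front of it. A small sketch of the first case:\n",
"\n",
"```python\n",
"xs = np.arange(6.).reshape(2, 3)  # batch axis 0 (size 2), data axis of size 3\n",
"val_out, bdim_out = reduce_sum_batching_rule(2, (xs,), (0,), axis=0)\n",
"assert val_out.shape == (2,) and bdim_out == 0  # reduced axis 0 became axis 1\n",
"```"
]
},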
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Finally, we add a transformation API to kick off the trace:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def vmap(f, in_axes, out_axis):\n",
|
|
" def batched_f(*args):\n",
|
|
" axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)\n",
|
|
" if ax is not None}\n",
|
|
" with new_main(BatchTrace, axis_size) as main:\n",
|
|
" trace = BatchTrace(main)\n",
|
|
" tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x\n",
|
|
" for x, ax in zip(args, in_axes)]\n",
|
|
" out = f(*tracers_in)\n",
|
|
" tracer_out = full_raise(trace, out)\n",
|
|
" val_out, batch_dim_out = tracer_out.val, tracer_out.batch_dim\n",
|
|
" return move_batch_axis(axis_size, batch_dim_out, out_axis, val_out)\n",
|
|
" return batched_f"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0. 1. 2.]\n",
|
|
"[1. 2. 3.]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def add_one_to_a_scalar(scalar):\n",
|
|
" assert np.ndim(scalar) == 0\n",
|
|
" return 1 + scalar\n",
|
|
"\n",
|
|
"vector_in = np.arange(3.)\n",
|
|
"vector_out = vmap(add_one_to_a_scalar, (0,), 0)(vector_in)\n",
|
|
"\n",
|
|
"print(vector_in)\n",
|
|
"print(vector_out)"
|
|
]
|
|
},
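{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `in_axes` and `out_axis` arguments let us map along any axis, not just the\n",
"leading one. For instance (a small sketch):\n",
"\n",
"```python\n",
"xs = np.arange(6.).reshape(2, 3)              # map over axis 1, of size 3\n",
"ys = vmap(lambda x: sin(x) + x, (1,), 0)(xs)\n",
"assert ys.shape == (3, 2)\n",
"assert np.allclose(ys, (np.sin(xs) + xs).T)\n",
"```"
]
},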
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[ 1. , 0. , -0. ],\n",
|
|
" [ 0. , 0.54030231, -0. ],\n",
|
|
" [ 0. , 0. , -0.41614684]])"
|
|
]
|
|
},
|
|
"execution_count": 172,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def jacfwd(f, x):\n",
|
|
" pushfwd = lambda v: jvp(f, (x,), (v,))[1]\n",
|
|
" vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)\n",
|
|
" return vmap(pushfwd, (0,), 0)(vecs_in)\n",
|
|
"\n",
|
|
"def f(x):\n",
|
|
" return sin(x)\n",
|
|
"\n",
|
|
"jacfwd(f, np.arange(3.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"That's it for `jvp` and `vmap`! Before moving on, let's highlight a few\n",
|
|
"simplifications in what we've seen so far compared to the full JAX\n",
|
|
"implementation:\n",
|
|
"1. **Fewer, simpler primitives.** More primitives means more interpretation\n",
|
|
"rules, and for more complex primitives (like for convolution or advanced\n",
|
|
"indexing) each rule is harder to write. But the overarching design is no\n",
|
|
"different.\n",
|
|
"1. **Transformations expect arrays in, single array out.**\n",
|
|
"2. **No symbolic zeros in autodiff.**\n",
|
|
"3. **No special call primitives yet.** The core machinery needs to be\n",
|
|
" generalized to handle the most flexible kind of higher-order primitive,\n",
|
|
" used by `jax.custom_jvp` and `jax.custom_vjp`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Part 2: Jaxprs, for `jit` and `vjp`\n",
|
|
"\n",
|
|
"The next transformations are the horizon are `jit` for just-in-time\n",
|
|
"compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small\n",
|
|
"wrapper around `vjp`.) For `jvp` and `vmap` we only needed each `Tracer` to\n",
|
|
"carry a little bit of extra context, but for both `jit` and `vjp` we need\n",
|
|
"much richer context: we need to represent _programs_. That is, we need jaxprs!\n",
|
|
"\n",
|
|
"Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are\n",
|
|
"an explicitly typed, functional, first-order language. We need a program\n",
|
|
"representation for `jit` because the purpose of `jit` is to stage computation\n",
|
|
"out of Python. For any computation we want to stage out, we need to be able to\n",
|
|
"represent it as data, and build it up as we trace a Python function.\n",
|
|
"Similarly, `vjp` needs a way to represent the computation for the backward\n",
|
|
"pass of reverse-mode autodiff. We use the same jaxpr program representation\n",
|
|
"for both needs.\n",
|
|
"\n",
|
|
"(Building a program representation is the most\n",
|
|
"[free](https://en.wikipedia.org/wiki/Free_object) kind of\n",
|
|
"trace- transformation, and so except for issues around handling native Python\n",
|
|
"control flow, any transformation could be implemented by first tracing to a\n",
|
|
"jaxpr and then interpreting the jaxpr.)\n",
|
|
"\n",
|
|
"The jaxpr term syntax is roughly:\n",
|
|
"\n",
|
|
"```\n",
|
|
"jaxpr ::=\n",
|
|
" { lambda <binder> , ... .\n",
|
|
" let <eqn>\n",
|
|
" ...\n",
|
|
" in <atom> }\n",
|
|
"\n",
|
|
"binder ::= <var>:<array_type>\n",
|
|
"var ::= a | b | c | ...\n",
|
|
"atom ::= <var> | <literal>\n",
|
|
"literal ::= <int32> | <float32>\n",
|
|
"\n",
|
|
"eqn ::= <binder> = <primitive> [ <params> ] <atom> , ...\n",
|
|
"```\n",
|
|
"\n",
|
|
"The syntax of types is:\n",
|
|
"\n",
|
|
"```\n",
|
|
"jaxpr_type ::= [<array_type>, ...] -> [<array_type>, ...]\n",
|
|
"array_type ::= <dtype>[<shape>]\n",
|
|
"dtype ::= f32 | f64 | i32 | i64\n",
|
|
"shape ::= <int> , ...\n",
|
|
"```\n",
|
|
"\n",
|
|
"How do we represent these as Python data structures? We reuse ShapedArrays to\n",
|
|
"represent types, and we can represent the term syntax with a few Python\n",
|
|
"structs:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import Dict, Set\n",
|
|
"\n",
|
|
"class Var:\n",
|
|
" aval: ShapedArray\n",
|
|
" def __init__(self, aval): self.aval = aval\n",
|
|
"\n",
|
|
"class Lit:\n",
|
|
" val: Any\n",
|
|
" aval: ShapedArray\n",
|
|
"\n",
|
|
" def __init__(self, val):\n",
|
|
" self.val = val\n",
|
|
" self.aval = raise_to_shaped(get_aval(self.val))\n",
|
|
"\n",
|
|
"Atom = Union[Var, Lit]\n",
|
|
"\n",
|
|
"class JaxprEqn(NamedTuple):\n",
|
|
" primitive: Primitive\n",
|
|
" inputs: List[Atom]\n",
|
|
" params: Dict[str, Any]\n",
|
|
" out_binder: Var\n",
|
|
"\n",
|
|
"class Jaxpr(NamedTuple):\n",
|
|
" in_binders: List[Var]\n",
|
|
" eqns: List[JaxprEqn]\n",
|
|
" out: Atom\n",
|
|
"\n",
|
|
"\n",
|
|
"def raise_to_shaped(aval):\n",
|
|
" return ShapedArray(aval.shape, aval.dtype)"
|
|
]
|
|
},
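{
"cell_type": "markdown",
"metadata": {},
"source": [
"To connect these structures to the grammar above, here is a hand-built jaxpr\n",
"for `lambda x: 2. * x` (just a sketch; the tracing machinery below constructs\n",
"these for us):\n",
"\n",
"```python\n",
"a = Var(ShapedArray((), np.dtype('float64')))\n",
"b = Var(ShapedArray((), np.dtype('float64')))\n",
"double_jaxpr = Jaxpr(in_binders=[a],\n",
"                     eqns=[JaxprEqn(mul_p, [Lit(2.), a], {}, b)],\n",
"                     out=b)\n",
"```\n",
"\n",
"In the textual syntax this reads roughly as\n",
"`{ lambda a:f64[] . let b:f64[] = mul 2.0 a in b }`."
]
},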
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "composite-dinner",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class JaxprType:\n",
|
|
" in_types: List[ShapedArray]\n",
|
|
" out_type: ShapedArray\n",
|
|
"\n",
|
|
" def __init__(self, in_types, out_type):\n",
|
|
" self.in_types = in_types\n",
|
|
" self.out_type = out_type\n",
|
|
"\n",
|
|
" def __repr__(self):\n",
|
|
" in_types = ', '.join(aval.str_short() for aval in self.in_types)\n",
|
|
" out_type = self.out_type.str_short()\n",
|
|
" return f'({in_types}) -> {out_type}'\n",
|
|
"\n",
|
|
"\n",
|
|
"def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:\n",
|
|
" env: Set[Var] = set()\n",
|
|
"\n",
|
|
" for v in jaxpr.in_binders:\n",
|
|
" env.add(v)\n",
|
|
"\n",
|
|
" for eqn in jaxpr.eqns:\n",
|
|
" in_types = [typecheck_atom(env, x) for x in eqn.inputs]\n",
|
|
" out_type = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)\n",
|
|
" if not types_equal(out_type, eqn.out_binder.aval): raise TypeError\n",
|
|
" env.add(eqn.out_binder)\n",
|
|
"\n",
|
|
" out_type = typecheck_atom(env, jaxpr.out)\n",
|
|
" return JaxprType([v.aval for v in jaxpr.in_binders], out_type)\n",
|
|
"\n",
|
|
"def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray:\n",
|
|
" if isinstance(x, Var):\n",
|
|
" if x not in env: raise TypeError(\"unbound variable\")\n",
|
|
" return x.aval\n",
|
|
" elif isinstance(x, Lit):\n",
|
|
" return raise_to_shaped(get_aval(x.val))\n",
|
|
" else:\n",
|
|
" assert False\n",
|
|
"\n",
|
|
"def types_equal(a: ShapedArray, b: ShapedArray) -> bool:\n",
|
|
" return a.shape == b.shape and a.dtype == b.dtype"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now that we have jaxprs as a data structure, we need ways to produce these\n",
|
|
"from tracing Python code. In general there are two variants of how we trace to\n",
|
|
"a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one\n",
|
|
"used by `jit`, which is also used by control flow primitives like\n",
|
|
"`lax.cond`, `lax.while_loop`, and `lax.scan`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NB: the analogous class in JAX is called 'DynamicJaxprTracer'\n",
|
|
"class JaxprTracer(Tracer):\n",
|
|
" __slots__ = ['aval']\n",
|
|
" aval: ShapedArray\n",
|
|
"\n",
|
|
" def __init__(self, trace, aval):\n",
|
|
" self._trace = trace\n",
|
|
" self.aval = aval\n",
|
|
"\n",
|
|
"# NB: the analogous class in JAX is called 'DynamicJaxprTrace'\n",
|
|
"class JaxprTrace(Trace):\n",
|
|
" def new_arg(self, aval: ShapedArray) -> JaxprTracer:\n",
|
|
" aval = raise_to_shaped(aval)\n",
|
|
" tracer = JaxprTracer(self, aval)\n",
|
|
" self.builder.tracer_to_var[id(tracer)] = Var(aval)\n",
|
|
" return tracer\n",
|
|
"\n",
|
|
" def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:\n",
|
|
" tracer = self.builder.const_tracers.get(id(val))\n",
|
|
" if tracer is None:\n",
|
|
" tracer = JaxprTracer(self, raise_to_shaped(get_aval(val)))\n",
|
|
" self.builder.add_const(tracer, val)\n",
|
|
" return tracer\n",
|
|
" pure = lift = get_or_make_const_tracer\n",
|
|
"\n",
|
|
" def process_primitive(self, primitive, tracers, params):\n",
|
|
" avals_in = [t.aval for t in tracers]\n",
|
|
" aval_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
|
|
" out_tracer = JaxprTracer(self, aval_out)\n",
|
|
" inputs = [self.builder.getvar(t) for t in tracers]\n",
|
|
" outvar = self.builder.add_var(out_tracer)\n",
|
|
" self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvar))\n",
|
|
" return out_tracer\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def builder(self):\n",
|
|
" return self.main.global_data\n",
|
|
"\n",
|
|
"# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance\n",
|
|
"abstract_eval_rules = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Notice that we keep as interpreter-global data a builder object, which keeps\n",
|
|
"track of variables, constants, and eqns as we build up the jaxpr."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class JaxprBuilder:\n",
|
|
" eqns: List[JaxprEqn]\n",
|
|
" tracer_to_var: Dict[int, Var]\n",
|
|
" const_tracers: Dict[int, JaxprTracer]\n",
|
|
" constvals: Dict[Var, Any]\n",
|
|
"\n",
|
|
" def __init__(self):\n",
|
|
" self.eqns = []\n",
|
|
" self.tracer_to_var = {}\n",
|
|
" self.const_tracers = {}\n",
|
|
" self.constvals = {}\n",
|
|
"\n",
|
|
" def add_eqn(self, eqn: JaxprEqn) -> None:\n",
|
|
" self.eqns.append(eqn)\n",
|
|
"\n",
|
|
" def add_var(self, tracer: JaxprTracer) -> Var:\n",
|
|
" var = self.tracer_to_var.get(id(tracer))\n",
|
|
" assert var is None\n",
|
|
" var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)\n",
|
|
" return var\n",
|
|
"\n",
|
|
" def getvar(self, tracer: JaxprTracer) -> Var:\n",
|
|
" var = self.tracer_to_var.get(id(tracer))\n",
|
|
" assert var is not None\n",
|
|
" return var\n",
|
|
"\n",
|
|
" def add_const(self, tracer: JaxprTracer, val: Any) -> Var:\n",
|
|
" var = self.add_var(tracer)\n",
|
|
" self.const_tracers[id(val)] = tracer\n",
|
|
" self.constvals[var] = val\n",
|
|
" return var\n",
|
|
"\n",
|
|
" def build(self, in_tracers: List[JaxprTracer], out_tracer: JaxprTracer\n",
|
|
" ) -> Tuple[Jaxpr, List[Any]]:\n",
|
|
" constvars, constvals = unzip2(self.constvals.items())\n",
|
|
" t2v = lambda t: self.tracer_to_var[id(t)]\n",
|
|
" in_binders = constvars + [t2v(t) for t in in_tracers]\n",
|
|
" jaxpr = Jaxpr(in_binders, self.eqns, t2v(out_tracer))\n",
|
|
" typecheck_jaxpr(jaxpr)\n",
|
|
" return jaxpr, constvals"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The rules we need for `JaxprTrace.process_primitive` are essentially typing\n",
|
|
"rules for primitive applications: given the primitive, its parameters, and\n",
|
|
"types for the inputs, the rule must produce a type for the output, which is\n",
|
|
"then packaged with the output `JaxprTracer`. We can use abstract evaluation\n",
|
|
"rules for this same purpose, even though they can be more general (since\n",
|
|
"abstract evaluation rules need to work on ConcreteArray inputs as well). We'll\n",
|
|
"reuse these abstract evaluation rules for the other jaxpr-producing trace\n",
|
|
"machinery, where the potential extra generality is useful."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def broadcast_shapes(*shapes):\n",
|
|
" assert len(shapes) > 1\n",
|
|
" for sizes in zip(*shapes):\n",
|
|
" sizes = [d for d in sizes if d != 1]\n",
|
|
" if sizes[:-1] != sizes[1:]:\n",
|
|
" raise Exception\n",
|
|
" return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))"
|
|
]
|
|
},
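{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of this helper (a sketch; it mirrors NumPy broadcasting for\n",
"equal-rank shapes):\n",
"\n",
"```python\n",
"assert broadcast_shapes((3, 1), (1, 4)) == (3, 4)\n",
"assert broadcast_shapes((2, 3), (2, 3)) == (2, 3)\n",
"```"
]
},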
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def broadcasting_binop_abstract_eval_rule(*avals_in):\n",
|
|
" out_dtype = np.result_type(*map(np.result_type, avals_in))\n",
|
|
" out_shape = broadcast_shapes(*map(np.shape, avals_in))\n",
|
|
" return ShapedArray(out_shape, out_dtype)\n",
|
|
"\n",
|
|
"abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule\n",
|
|
"abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule\n",
|
|
"\n",
|
|
"def vectorized_unop_abstract_eval_rule(aval_in):\n",
|
|
" return ShapedArray(np.shape(aval_in), np.result_type(aval_in))\n",
|
|
"\n",
|
|
"abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule\n",
|
|
"abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule\n",
|
|
"abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule\n",
|
|
"\n",
|
|
"def reduce_sum_abstract_eval_rule(aval_in, *, axis):\n",
|
|
" new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]\n",
|
|
" return ShapedArray(tuple(new_shape), aval_in.dtype)\n",
|
|
"abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"To check our implementation, we can add a `make_jaxpr` transformation and\n",
|
|
"first pretty-printer:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "defensive-ownership",
|
|
"metadata": {
|
|
"lines_to_next_cell": 1
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def make_jaxpr(f, avals_in):\n",
|
|
" builder = JaxprBuilder()\n",
|
|
" with new_main(JaxprTrace, builder) as main:\n",
|
|
" trace = JaxprTrace(main)\n",
|
|
" tracers_in = [trace.new_arg(aval) for aval in avals_in]\n",
|
|
" out = f(*tracers_in)\n",
|
|
" tracer_out = full_raise(trace, out)\n",
|
|
" return builder.build(tracers_in, tracer_out)"
|
|
]
|
|
},
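{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even before pretty-printing, we can check that tracing a function produces a\n",
"well-typed jaxpr (a quick sketch):\n",
"\n",
"```python\n",
"jaxpr, consts = make_jaxpr(lambda x, y: neg(x) + y,\n",
"                           [raise_to_shaped(get_aval(3.)),\n",
"                            raise_to_shaped(get_aval(4.))])\n",
"print(typecheck_jaxpr(jaxpr))  # we'd expect (float64[], float64[]) -> float64[]\n",
"```"
]
},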
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "adopted-month",
|
|
"metadata": {
|
|
"lines_to_next_cell": 1
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from collections import defaultdict\n",
|
|
"import itertools as it\n",
|
|
"import string\n",
|
|
"\n",
|
|
"class PPrint:\n",
|
|
" lines: List[Tuple[int, str]]\n",
|
|
"\n",
|
|
" def __init__(self, lines):\n",
|
|
" self.lines = lines\n",
|
|
"\n",
|
|
" def indent(self, indent: int) -> 'PPrint':\n",
|
|
" return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])\n",
|
|
"\n",
|
|
" def __add__(self, rhs: 'PPrint') -> 'PPrint':\n",
|
|
" return PPrint(self.lines + rhs.lines)\n",
|
|
"\n",
|
|
" def __rshift__(self, rhs: 'PPrint') -> 'PPrint':\n",
|
|
" if not rhs.lines: return self\n",
|
|
" if not self.lines: return rhs\n",
|
|
" indent, s = self.lines[-1]\n",
|
|
" indented_block = rhs.indent(indent + len(s))\n",
|
|
" common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]\n",
|
|
" return PPrint(self.lines[:-1]\n",
|
|
" + [(indent, common_line)]\n",
|
|
" + indented_block.lines[1:])\n",
|
|
"\n",
|
|
" def __str__(self) -> str:\n",
|
|
" return '\\n'.join(' ' * indent + s for indent, s in self.lines)\n",
|
|
"\n",
|
|
"def pp(s: Any) -> PPrint:\n",
|
|
" return PPrint([(0, line) for line in str(s).splitlines()])\n",
|
|
"\n",
|
|
"def vcat(ps: List[PPrint]) -> PPrint:\n",
|
|
" return sum(ps, pp(''))\n",
|
|
"\n",
|
|
"def pp_jaxpr(jaxpr: Jaxpr):\n",
|
|
" namegen = (''.join(s) for r in it.count(1)\n",
|
|
" for s in it.permutations(string.ascii_lowercase, r))\n",
|
|
" names = defaultdict(lambda: next(namegen))\n",
|
|
" in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)\n",
|
|
" eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])\n",
|
|
" out = names[jaxpr.out] if isinstance(jaxpr.out, Var) else str(jaxpr.out.val)\n",
|
|
" return (pp(f'{{ lambda {in_binders} .') +\n",
|
|
" ((pp('let ') >> eqns) + pp(f'in {out} }}')).indent(2))\n",
|
|
"\n",
|
|
"def var_str(names: Dict[Var, str], v: Var) -> str:\n",
|
|
" return f'{names[v]}:{v.aval.str_short()}'\n",
|
|
"\n",
|
|
"def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
|
" lhs = pp(var_str(names, eqn.out_binder))\n",
|
|
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
|
|
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
|
" for x in eqn.inputs)))\n",
|
|
" return lhs >> pp(' = ') >> rhs\n",
|
|
"\n",
|
|
"def pp_params(params: Dict[str, Any]) -> PPrint:\n",
|
|
" items = sorted(params.items())\n",
|
|
" if items:\n",
|
|
" return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')\n",
|
|
" else:\n",
|
|
" return pp(' ')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"lines_to_next_cell": 1
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"jaxpr, consts = make_jaxpr(lambda x: 2. * x, [raise_to_shaped(get_aval(3.))])\n",
|
|
"print(pp_jaxpr(jaxpr))\n",
|
|
"print(typecheck_jaxpr(jaxpr))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"last_runtime": {
|
|
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
|
|
"kind": "private"
|
|
},
|
|
"name": "Autodidax: JAX core from scratch",
|
|
"provenance": [],
|
|
"toc_visible": true
|
|
},
|
|
"jupytext": {
|
|
"formats": "ipynb,md:myst,py",
|
|
"main_language": "python"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|