{
"cells": [
{
"cell_type": "raw",
"id": "surrounded-agent",
"metadata": {},
"source": [
"---\n",
"Copyright 2021 Google LLC\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License.\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Autodidax: JAX core from scratch\n",
"\n",
"Ever want to learn how JAX works, but the implementation seemed too\n",
"impenetrable? Well, you're in luck! By reading this tutorial, you'll learn\n",
"every big idea in JAX's core system. You'll even get clued into our weird\n",
"jargon!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 1: Transformations as interpreters: standard evaluation, `jvp`, and `vmap`\n",
"\n",
"We want to transform functions that look like this:\n",
"\n",
"```python\n",
"def f(x):\n",
" y = sin(x) * 2\n",
" z = - y + x\n",
" return z\n",
"```\n",
"\n",
"Think of functions like `sin` and the arithmetic operations underlying the\n",
"infix operators (`mul`, `add`, and `neg`) as primitive operations, meaning\n",
"atomic units of processing rather than compositions.\n",
"\n",
"\"Transform\" means \"interpret differently.\" Instead of standard interpretation\n",
"where we apply primitive functions to numerical inputs to produce numerical\n",
"outputs, we want to override primitive application and let different values\n",
"flow through our program. For example, we might want to replace the\n",
"application of every primitive with an application of [its JVP\n",
"rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n",
"and let primal-tangent pairs flow through our program. Moreover, we want to\n",
"apply a composition of multiple transformations, leading to stacks of\n",
"interpreters."
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 2
},
"source": [
"### JAX core machinery\n",
"\n",
"We can implement stacks of interpreters and even have them all discharge on\n",
"the fly as we execute the Python function to be transformed. To start, let's\n",
"define these primitives so that we can intercept their application:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import NamedTuple\n",
"\n",
"class Primitive(NamedTuple):\n",
" name: str\n",
"\n",
"add_p = Primitive('add')\n",
"mul_p = Primitive('mul')\n",
"neg_p = Primitive(\"neg\")\n",
"sin_p = Primitive(\"sin\")\n",
"cos_p = Primitive(\"cos\")\n",
"reduce_sum_p = Primitive(\"reduce_sum\")\n",
"greater_p = Primitive(\"greater\")\n",
"\n",
"def add(x, y): return bind(add_p, x, y)\n",
"def mul(x, y): return bind(mul_p, x, y)\n",
"def neg(x): return bind(neg_p, x)\n",
"def sin(x): return bind(sin_p, x)\n",
"def cos(x): return bind(cos_p, x)\n",
"def reduce_sum(x, axis=None): return bind(reduce_sum_p, x, axis=axis)\n",
"def greater(x, y): return bind(greater_p, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll set up array data types and infix operator methods in a moment.\n",
"\n",
"A `Primitive` is just an object with a name, to which we attach our\n",
"interpretation rules (one for each transformation). The `bind` function is our\n",
"interception point: it'll figure out which transformation rule to apply, based\n",
"on how the arguments are boxed in tracers and what interpreters are active.\n",
"\n",
"The functions that user code calls, like `add` and `sin`, are just wrappers\n",
"around calls to `bind`. These wrappers let us control how arguments are passed\n",
"to `bind`, and in particular we follow a handy internal convention: when we\n",
"call `bind`, we pass values representing array data as positional arguments,\n",
"and we pass metadata like the `axis` argument to `sum_p` via keyword. This\n",
"calling convention simplifies some core logic (since e.g. instances of the\n",
"`Tracer` class to be defined below can only occurr in positional arguments to\n",
"`bind`). The wrappers can also provide docstrings!\n",
"\n",
"We represent active interpreters as a stack. The stack is just a simple\n",
"`list`, and each element is a container with an integer level (corresponding\n",
"to the element's height in the stack), an interpreter type (which we'll call a\n",
"`trace_type`), and an optional field for any global data the interpreter\n",
"needs. We call each element a `MainTrace`, though maybe \"Interpreter\" would be\n",
"more descriptive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from contextlib import contextmanager\n",
"from typing import Type, List, Optional, Any\n",
"\n",
"class MainTrace(NamedTuple):\n",
" level: int\n",
" trace_type: Type['Trace']\n",
" global_data: Optional[Any]\n",
"\n",
"trace_stack: List[MainTrace] = []\n",
"\n",
"@contextmanager\n",
"def new_main(trace_type: Type['Trace'], global_data=None):\n",
" level = len(trace_stack)\n",
" main = MainTrace(level, trace_type, global_data)\n",
" trace_stack.append(main)\n",
"\n",
" try:\n",
" yield main\n",
" finally:\n",
" trace_stack.pop()"
]
},
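{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's a quick sketch of `new_main`'s push/pop behavior (a hypothetical\n",
"direct use, just for illustration: the stack is still empty at this point,\n",
"and since no `Trace` subclass exists yet we pass `None` as a placeholder\n",
"trace type, which `new_main` merely stores):\n",
"\n",
"```python\n",
"with new_main(None) as m1:\n",
"  with new_main(None) as m2:\n",
"    assert (m1.level, m2.level) == (0, 1)  # inner interpreter sits higher\n",
"assert not trace_stack                     # both popped on exit\n",
"```"
]
},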
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we're about to apply a transformed function, we'll push another\n",
"interpreter onto the stack using `new_main`. Then, as we apply primitives in\n",
"the function, we can think of the `bind` first being interpreted by the trace\n",
"at the top of the stack (i.e. with the highest level). If that first\n",
"interpreter itself binds other primitives in its interpretation rule for the\n",
"primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,\n",
"then those `bind` calls will be handled by the interpreter at the next level\n",
"down.\n",
"\n",
"What goes at the bottom of the interpreter stack? At the bottom, we know all\n",
"the transformation interpreters are finished, and we just want to do standard\n",
"evaluation. So at the bottom we'll put an evaluation interpreter.\n",
"\n",
"Let's sketch out the interface for interpreters, which is based on the `Trace`\n",
"and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps\n",
"carrying some extra context data used by the interpreter. A `Trace` handles\n",
"boxing up vales into `Tracers` and also handles primitive application."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Trace:\n",
" main: MainTrace\n",
"\n",
" def __init__(self, main: MainTrace) -> None:\n",
" self.main = main\n",
"\n",
" def pure(self, val): assert False # must override\n",
" def lift(self, val): assert False # must override\n",
"\n",
" def process_primitive(self, primitive, tracers, params):\n",
" assert False # must override"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first two methods are about boxing up values in `Tracer`s, which are the\n",
"objects that flow through the Python programs we transform. The last method is\n",
"the callback we'll use to interpret primitive application.\n",
"\n",
"The `Trace` itself doesn't contain any data, other than a reference to its\n",
"corresponding `MainTrace` instance. In fact, multiple instances of a `Trace`\n",
"might be created and discarded during an application of a transformation,\n",
"whereas only a single `MainTrace` instance is created per application of a\n",
"transformation.\n",
"\n",
"As for `Tracer`s themselves, each one carries an abstract value (and forwards\n",
"infix operators to it), and the rest is up to the transformation. (The\n",
"relationship between `Tracer`s and `AbstractValue`s is that there's one\n",
"`Tracer` per transformation, and at least one `AbstractValue` per base type,\n",
"like arrays.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from typing import Tuple\n",
"\n",
"class Tracer:\n",
" _trace: Trace\n",
"\n",
" __array_priority__ = 1000\n",
"\n",
" @property\n",
" def aval(self):\n",
" assert False # must override\n",
"\n",
" def full_lower(self):\n",
" return self # default implementation\n",
"\n",
" def __neg__(self): return self.aval._neg(self)\n",
" def __add__(self, other): return self.aval._add(self, other)\n",
" def __radd__(self, other): return self.aval._radd(self, other)\n",
" def __mul__(self, other): return self.aval._mul(self, other)\n",
" def __rmul__(self, other): return self.aval._rmul(self, other)\n",
" def __gt__(self, other): return self.aval._gt(self, other)\n",
" def __bool__(self): return self.aval._bool(self)\n",
" def __nonzero__(self): return self.aval._nonzero(self)\n",
"\n",
" def __getattr__(self, name):\n",
" try:\n",
" return getattr(self.aval, name)\n",
" except AttributeError:\n",
" raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")\n",
"\n",
"class ShapedArray:\n",
" array_abstraction_level = 1\n",
" shape: Tuple[int]\n",
" dtype: np.dtype\n",
"\n",
" def __init__(self, shape, dtype):\n",
" self.shape = shape\n",
" self.dtype = dtype\n",
"\n",
" @property\n",
" def ndim(self):\n",
" return len(self.shape)\n",
"\n",
" _neg = staticmethod(neg)\n",
" _add = staticmethod(add)\n",
" _radd = staticmethod(add)\n",
" _mul = staticmethod(mul)\n",
" _rmul = staticmethod(mul)\n",
" _gt = staticmethod(greater)\n",
"\n",
" @staticmethod\n",
" def _bool(tracer):\n",
" raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
"\n",
" @staticmethod\n",
" def _nonzero(tracer):\n",
" raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
"\n",
" def str_short(self):\n",
" return f'{self.dtype.name}[{\",\".join(str(d) for d in self.shape)}]'\n",
"\n",
"class ConcreteArray(ShapedArray):\n",
" array_abstraction_level = 2\n",
" val: np.ndarray\n",
"\n",
" def __init__(self, val):\n",
" self.val = val\n",
" self.shape = val.shape\n",
" self.dtype = val.dtype\n",
"\n",
" @staticmethod\n",
" def _bool(tracer):\n",
" return bool(tracer.aval.val)\n",
"\n",
" @staticmethod\n",
" def _nonzero(tracer):\n",
" return bool(tracer.aval.val)\n",
"\n",
"def get_aval(x):\n",
" if isinstance(x, Tracer):\n",
" return x.aval\n",
" else:\n",
" return ConcreteArray(np.asarray(x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that we actually have two `AbstractValue`s for arrays, representing\n",
"different levels of abstraction. A `ShapedArray` represents the set of all\n",
"possible arrays with a given shape and dtype. A `ConcreteArray` represents a\n",
"singleton set consisting of a single array value.\n",
"\n",
"Now that we've set up the trace stack, the Trace/Tracer API for interpreters,\n",
"and abstract values, we can come back to implement `bind`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def bind(prim, *args, **params):\n",
" top_trace = find_top_trace(args)\n",
" tracers = [full_raise(top_trace, arg) for arg in args]\n",
" out = top_trace.process_primitive(prim, tracers, params)\n",
" return full_lower(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The main action is that we call `find_top_trace` to figure out which\n",
"interpreter should handle this primitive application as a function of the\n",
"arguments and the active traces on the trace stack. We then call that top\n",
"trace's `process_primitive` so that the trace can apply its interpretation\n",
"rule. The calls to `full_raise` just ensure that the inputs are boxed in the\n",
"top trace's `Tracer` instances, and the call to `full_lower` is an optional\n",
"optimization so that we unbox values out of `Tracer`s as much as possible."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from operator import attrgetter\n",
"\n",
"def find_top_trace(xs) -> Trace:\n",
" top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n",
" default=trace_stack[0], key=attrgetter('level'))\n",
" return top_main.trace_type(top_main)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In words, `find_top_trace` returns the highest-level interpreter associated\n",
"with the `Tracer`s on its inputs, and otherwise returns the interpreter at the\n",
"bottom of the stack (which is always an evaluation trace, at least for now).\n",
"This corresponds to JAX transformations mostly working by data dependence\n",
"_except_ for the special bottom-of-the-stack interpreter, which interprets\n",
"everything."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def full_lower(val):\n",
" if isinstance(val, Tracer):\n",
" return val.full_lower()\n",
" else:\n",
" return val\n",
"\n",
"def full_raise(trace, val) -> Tracer:\n",
" if not isinstance(val, Tracer):\n",
" return trace.pure(val)\n",
" level = trace.main.level\n",
" if val._trace.main is trace.main:\n",
" return val\n",
" elif val._trace.main.level < level:\n",
" return trace.lift(val)\n",
" elif val._trace.main.level > level:\n",
" raise Exception(f\"Can't lift level {val._trace.main.level} to {level}.\")\n",
" else: # val._trace.level == level\n",
" raise Exception(f\"Different traces at same level: {val._trace}, {trace}.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The logic in `full_raise` serves to box values into `Tracer`s for a particular\n",
"`Trace`, calling different methods on the `Trace` based on context:\n",
"`Trace.pure` is called on non-`Tracer` constants, and `Trace.lift` is called\n",
"for values that are already `Tracer`s from a lower-level interpreter. These\n",
"two methods could share the same implementation, but by distinguishing them in\n",
"the core logic we can provide more information to the `Trace` subclass.\n",
"\n",
"That's it for the JAX core! Now we can start adding interpreters."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation interpreter\n",
"\n",
"We'll start with the simplest interpreter: the evaluation interpreter that\n",
"will sit at the bottom of the interpreter stack."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class EvalTrace(Trace):\n",
" pure = lift = lambda self, x: x # no boxing in Tracers needed\n",
"\n",
" def process_primitive(self, primitive, tracers, params):\n",
" return impl_rules[primitive](*tracers, **params)\n",
"\n",
"trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack\n",
"\n",
"impl_rules = {}\n",
"impl_rules[add_p] = np.add\n",
"impl_rules[mul_p] = np.multiply\n",
"impl_rules[neg_p] = np.negative\n",
"impl_rules[sin_p] = np.sin\n",
"impl_rules[cos_p] = np.cos\n",
"impl_rules[reduce_sum_p] = np.sum\n",
"impl_rules[greater_p] = np.greater"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this interpreter, we can evaluate user functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7177599838802657\n"
]
}
],
"source": [
"def f(x):\n",
" y = sin(x) * 2\n",
" z = - y + x\n",
" return z\n",
"\n",
"print(f(3.0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Woo! Like going around in a big circle. But the point of this indirection is\n",
"that now we can add some real transformations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Forward-mode autodiff with `jvp`\n",
"\n",
"First, a couple of helper functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def zeros_like(val):\n",
" return np.zeros_like(val)\n",
"\n",
"def unzip2(pairs):\n",
" lst1, lst2 = [], []\n",
" for x1, x2 in pairs:\n",
" lst1.append(x1)\n",
" lst2.append(x2)\n",
" return lst1, lst2"
]
},
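{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, `unzip2` splits a sequence of pairs into a pair of lists:\n",
"\n",
"```python\n",
"primals, tangents = unzip2([(3., 1.), (4., 0.)])\n",
"assert primals == [3., 4.] and tangents == [1., 0.]\n",
"```"
]
},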
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The\n",
"`Trace` applies JVP rules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class JVPTracer(Tracer):\n",
" def __init__(self, trace, primal, tangent):\n",
" self._trace = trace\n",
" self.primal = primal\n",
" self.tangent = tangent\n",
"\n",
" @property\n",
" def aval(self):\n",
" return get_aval(self.primal)\n",
"\n",
"class JVPTrace(Trace):\n",
" pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))\n",
"\n",
" def process_primitive(self, primitive, tracers, params):\n",
" primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)\n",
" jvp_rule = jvp_rules[primitive]\n",
" primal_out, tangent_out = jvp_rule(primals_in, tangents_in, **params)\n",
" return JVPTracer(self, primal_out, tangent_out)\n",
"\n",
"jvp_rules = {}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice both `lift` and `sublift` package a value into a `JVPTracer` with the\n",
"minimal amount of context, which is a zero tangent value."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's add some JVP rules for primitives:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def add_jvp(primals, tangents):\n",
" (x, y), (x_dot, y_dot) = primals, tangents\n",
" return x + y, x_dot + y_dot\n",
"jvp_rules[add_p] = add_jvp\n",
"\n",
"def mul_jvp(primals, tangents):\n",
" (x, y), (x_dot, y_dot) = primals, tangents\n",
" return x * y, x_dot * y + x * y_dot\n",
"jvp_rules[mul_p] = mul_jvp\n",
"\n",
"def sin_jvp(primals, tangents):\n",
" (x,), (x_dot,) = primals, tangents\n",
" return sin(x), cos(x) * x_dot\n",
"jvp_rules[sin_p] = sin_jvp\n",
"\n",
"def cos_jvp(primals, tangents):\n",
" (x,), (x_dot,) = primals, tangents\n",
" return cos(x), -sin(x) * x_dot\n",
"jvp_rules[cos_p] = cos_jvp\n",
"\n",
"def neg_jvp(primals, tangents):\n",
" (x,), (x_dot,) = primals, tangents\n",
" return neg(x), neg(x_dot)\n",
"jvp_rules[neg_p] = neg_jvp\n",
"\n",
"def reduce_sum_jvp(primals, tangents, *, axis):\n",
" (x,), (x_dot,) = primals, tangents\n",
" return reduce_sum(x, axis), reduce_sum(x_dot, axis)\n",
"jvp_rules[reduce_sum_p] = reduce_sum_jvp\n",
"\n",
"def greater_jvp(primals, tangents):\n",
" (x, y), _ = primals, tangents\n",
" out_primal = greater(x, y)\n",
" return out_primal, zeros_like(out_primal)\n",
"jvp_rules[greater_p] = greater_jvp"
]
},
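{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that JVP rules are ordinary functions: applied to concrete values, they\n",
"compute a primal-tangent pair directly via the evaluation interpreter. A\n",
"quick sketch:\n",
"\n",
"```python\n",
"y, y_dot = sin_jvp((3.,), (1.,))\n",
"print(y, y_dot)  # sin(3.) and cos(3.) * 1.0\n",
"```"
]
},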
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we add a transformation API to kick off the trace:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def jvp(f, primals, tangents):\n",
" with new_main(JVPTrace) as main:\n",
" trace = JVPTrace(main)\n",
" tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n",
" out = f(*tracers_in)\n",
" tracer_out = full_raise(trace, out)\n",
" primal_out, tangent_out = tracer_out.primal, tracer_out.tangent\n",
" return primal_out, tangent_out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And with that, we can differentiate!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.9899924966004454\n",
"-0.9899924966004454\n"
]
}
],
"source": [
"x = 3.0\n",
"y, sin_deriv_at_3 = jvp(sin, (x,), (1.0,))\n",
"print(sin_deriv_at_3)\n",
"print(cos(3.0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7177599838802657\n",
"2.979984993200891\n"
]
}
],
"source": [
"def f(x):\n",
" y = sin(x) * 2\n",
" z = - y + x\n",
" return z\n",
"\n",
"x, xdot = 3., 1.\n",
"y, ydot = jvp(f, (x,), (xdot,))\n",
"print(y)\n",
"print(ydot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.9899924966004454\n",
"-0.1411200080598672\n",
"0.9899924966004454\n",
"0.1411200080598672\n"
]
}
],
"source": [
"def deriv(f):\n",
" return lambda x: jvp(f, (x,), (1.,))[1]\n",
"\n",
"print(deriv(sin)(3.))\n",
"print(deriv(deriv(sin))(3.))\n",
"print(deriv(deriv(deriv(sin)))(3.))\n",
"print(deriv(deriv(deriv(deriv(sin))))(3.))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0\n",
"1.0\n"
]
}
],
"source": [
"def f(x):\n",
" if x > 0.: # Python control flow\n",
" return 2. * x\n",
" else:\n",
" return x\n",
"\n",
"print(deriv(f)(3.))\n",
"print(deriv(f)(-3.))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vectorized batching with `vmap`\n",
"\n",
"First, a couple helper functions, one for producing mapped abstract values\n",
"from unmapped ones (by removing an axis), and one for moving batch dimensions\n",
"around:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mapped_aval(batch_dim, aval):\n",
" shape = list(aval.shape)\n",
" del shape[batch_dim]\n",
" return ShapedArray(tuple(shape), aval.dtype)\n",
"\n",
"def move_batch_axis(axis_size, src, dst, x):\n",
" if src is not_mapped:\n",
" target_shape = list(np.shape(x))\n",
" target_shape.insert(dst, axis_size)\n",
" return np.broadcast_to(np.expand_dims(x, dst), target_shape)\n",
" else:\n",
" return np.moveaxis(x, src, dst)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Tracer` for vectorized batching carries a batched value and an optional\n",
"integer indicating which axis (if any) is the batch axis."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Union\n",
"\n",
"class NotMapped: pass\n",
"not_mapped = NotMapped()\n",
"\n",
"class BatchTracer(Tracer):\n",
" def __init__(self, trace, val, batch_dim: Union[NotMapped, int]):\n",
" self._trace = trace\n",
" self.val = val\n",
" self.batch_dim = batch_dim\n",
"\n",
" @property\n",
" def aval(self):\n",
" if self.batch_dim is not_mapped:\n",
" return get_aval(self.val)\n",
" else:\n",
" return mapped_aval(self.batch_dim, get_aval(self.val))\n",
"\n",
" def full_lower(self):\n",
" if self.batch_dim is not_mapped:\n",
" return full_lower(self.val)\n",
" else:\n",
" return self\n",
"\n",
"class BatchTrace(Trace):\n",
" pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)\n",
"\n",
" def process_primitive(self, primitive, tracers, params):\n",
" vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)\n",
" vmap_rule = vmap_rules[primitive]\n",
" val_out, bdim_out = vmap_rule(self.axis_size, vals_in, bdims_in, **params)\n",
" return BatchTracer(self, val_out, bdim_out)\n",
"\n",
" @property\n",
" def axis_size(self):\n",
" return self.main.global_data\n",
"\n",
"vmap_rules = {}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we've implemented the optional `Tracer.full_lower` method, which lets us\n",
"peel off a batching tracer if it's not needed because it doesn't represent a\n",
"batched value.\n",
"\n",
"For `BatchTrace`, analogous to `JVPTrace`, the methods `pure` and `lift` just\n",
"box a value in a `BatchTracer` with the minimal amount of context, which in\n",
"this case is a `batch_dim` taking the sentinel value `not_mapped`. Notice we\n",
"use the `MainTrace`'s interpreter-global data field to store the batch axis\n",
"size.\n",
"\n",
"Next we can define batching interpreter rules for each primitive:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):\n",
" (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n",
" if x_bdim != y_bdim:\n",
" y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
" return op(x, y), x_bdim\n",
"vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)\n",
"vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)\n",
"\n",
"def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):\n",
" (x,), (x_bdim,) = vals_in, dims_in\n",
" return op(x), x_bdim\n",
"vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)\n",
"vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)\n",
"vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)\n",
"\n",
"def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n",
" (x,), (x_bdim,) = vals_in, dims_in\n",
" new_axis = axis + (x_bdim <= axis)\n",
" out_bdim = x_bdim - (new_axis < x_bdim)\n",
" return reduce_sum(x, new_axis), out_bdim\n",
"vmap_rules[reduce_sum_p] = reduce_sum_batching_rule"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we add a transformation API to kick off the trace:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def vmap(f, in_axes, out_axis):\n",
" def batched_f(*args):\n",
" axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)\n",
" if ax is not None}\n",
" with new_main(BatchTrace, axis_size) as main:\n",
" trace = BatchTrace(main)\n",
" tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x\n",
" for x, ax in zip(args, in_axes)]\n",
" out = f(*tracers_in)\n",
" tracer_out = full_raise(trace, out)\n",
" val_out, batch_dim_out = tracer_out.val, tracer_out.batch_dim\n",
" return move_batch_axis(axis_size, batch_dim_out, out_axis, val_out)\n",
" return batched_f"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 1. 2.]\n",
"[1. 2. 3.]\n"
]
}
],
"source": [
"def add_one_to_a_scalar(scalar):\n",
" assert np.ndim(scalar) == 0\n",
" return 1 + scalar\n",
"\n",
"vector_in = np.arange(3.)\n",
"vector_out = vmap(add_one_to_a_scalar, (0,), 0)(vector_in)\n",
"\n",
"print(vector_in)\n",
"print(vector_out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1. , 0. , -0. ],\n",
" [ 0. , 0.54030231, -0. ],\n",
" [ 0. , 0. , -0.41614684]])"
]
},
"execution_count": 172,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def jacfwd(f, x):\n",
" pushfwd = lambda v: jvp(f, (x,), (v,))[1]\n",
" vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)\n",
" return vmap(pushfwd, (0,), 0)(vecs_in)\n",
"\n",
"def f(x):\n",
" return sin(x)\n",
"\n",
"jacfwd(f, np.arange(3.))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's it for `jvp` and `vmap`! Before moving on, let's highlight a few\n",
"simplifications in what we've seen so far compared to the full JAX\n",
"implementation:\n",
"1. **Fewer, simpler primitives.** More primitives means more interpretation\n",
"rules, and for more complex primitives (like for convolution or advanced\n",
"indexing) each rule is harder to write. But the overarching design is no\n",
"different.\n",
"1. **Transformations expect arrays in, single array out.**\n",
"2. **No symbolic zeros in autodiff.**\n",
"3. **No special call primitives yet.** The core machinery needs to be\n",
" generalized to handle the most flexible kind of higher-order primitive,\n",
" used by `jax.custom_jvp` and `jax.custom_vjp`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 2: Jaxprs, for `jit` and `vjp`\n",
"\n",
"The next transformations are the horizon are `jit` for just-in-time\n",
"compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small\n",
"wrapper around `vjp`.) For `jvp` and `vmap` we only needed each `Tracer` to\n",
"carry a little bit of extra context, but for both `jit` and `vjp` we need\n",
"much richer context: we need to represent _programs_. That is, we need jaxprs!\n",
"\n",
"Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are\n",
"an explicitly typed, functional, first-order language. We need a program\n",
"representation for `jit` because the purpose of `jit` is to stage computation\n",
"out of Python. For any computation we want to stage out, we need to be able to\n",
"represent it as data, and build it up as we trace a Python function.\n",
"Similarly, `vjp` needs a way to represent the computation for the backward\n",
"pass of reverse-mode autodiff. We use the same jaxpr program representation\n",
"for both needs.\n",
"\n",
"(Building a program representation is the most\n",
"[free](https://en.wikipedia.org/wiki/Free_object) kind of\n",
"trace- transformation, and so except for issues around handling native Python\n",
"control flow, any transformation could be implemented by first tracing to a\n",
"jaxpr and then interpreting the jaxpr.)\n",
"\n",
"The jaxpr term syntax is roughly:\n",
"\n",
"```\n",
"jaxpr ::=\n",
" { lambda <binder> , ... .\n",
" let <eqn>\n",
" ...\n",
" in <atom> }\n",
"\n",
"binder ::= <var>:<array_type>\n",
"var ::= a | b | c | ...\n",
"atom ::= <var> | <literal>\n",
"literal ::= <int32> | <float32>\n",
"\n",
"eqn ::= <binder> = <primitive> [ <params> ] <atom> , ...\n",
"```\n",
"\n",
"The syntax of types is:\n",
"\n",
"```\n",
"jaxpr_type ::= [<array_type>, ...] -> [<array_type>, ...]\n",
"array_type ::= <dtype>[<shape>]\n",
"dtype ::= f32 | f64 | i32 | i64\n",
"shape ::= <int> , ...\n",
"```\n",
"\n",
"How do we represent these as Python data structures? We reuse ShapedArrays to\n",
"represent types, and we can represent the term syntax with a few Python\n",
"structs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, Set\n",
"\n",
"class Var:\n",
" aval: ShapedArray\n",
" def __init__(self, aval): self.aval = aval\n",
"\n",
"class Lit:\n",
" val: Any\n",
" aval: ShapedArray\n",
"\n",
" def __init__(self, val):\n",
" self.val = val\n",
" self.aval = raise_to_shaped(get_aval(self.val))\n",
"\n",
"Atom = Union[Var, Lit]\n",
"\n",
"class JaxprEqn(NamedTuple):\n",
" primitive: Primitive\n",
" inputs: List[Atom]\n",
" params: Dict[str, Any]\n",
" out_binder: Var\n",
"\n",
"class Jaxpr(NamedTuple):\n",
" in_binders: List[Var]\n",
" eqns: List[JaxprEqn]\n",
" out: Atom\n",
"\n",
"\n",
"def raise_to_shaped(aval):\n",
" return ShapedArray(aval.shape, aval.dtype)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "composite-dinner",
"metadata": {},
"outputs": [],
"source": [
"class JaxprType:\n",
" in_types: List[ShapedArray]\n",
" out_type: ShapedArray\n",
"\n",
" def __init__(self, in_types, out_type):\n",
" self.in_types = in_types\n",
" self.out_type = out_type\n",
"\n",
" def __repr__(self):\n",
" in_types = ', '.join(aval.str_short() for aval in self.in_types)\n",
" out_type = self.out_type.str_short()\n",
" return f'({in_types}) -> {out_type}'\n",
"\n",
"\n",
"def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:\n",
" env: Set[Var] = set()\n",
"\n",
" for v in jaxpr.in_binders:\n",
" env.add(v)\n",
"\n",
" for eqn in jaxpr.eqns:\n",
" in_types = [typecheck_atom(env, x) for x in eqn.inputs]\n",
" out_type = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)\n",
" if not types_equal(out_type, eqn.out_binder.aval): raise TypeError\n",
" env.add(eqn.out_binder)\n",
"\n",
" out_type = typecheck_atom(env, jaxpr.out)\n",
" return JaxprType([v.aval for v in jaxpr.in_binders], out_type)\n",
"\n",
"def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray:\n",
" if isinstance(x, Var):\n",
" if x not in env: raise TypeError(\"unbound variable\")\n",
" return x.aval\n",
" elif isinstance(x, Lit):\n",
" return raise_to_shaped(get_aval(x.val))\n",
" else:\n",
" assert False\n",
"\n",
"def types_equal(a: ShapedArray, b: ShapedArray) -> bool:\n",
" return a.shape == b.shape and a.dtype == b.dtype"
]
},
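{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small sanity check (a sketch using only the pieces above), we can build\n",
"a trivial jaxpr by hand and typecheck it:\n",
"\n",
"```python\n",
"a = Var(ShapedArray((), np.dtype('float32')))\n",
"identity_jaxpr = Jaxpr(in_binders=[a], eqns=[], out=a)\n",
"print(typecheck_jaxpr(identity_jaxpr))  # (float32[]) -> float32[]\n",
"```"
]
},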
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have jaxprs as a data structure, we need ways to produce these\n",
"from tracing Python code. In general there are two variants of how we trace to\n",
"a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one\n",
"used by `jit`, which is also used by control flow primitives like\n",
"`lax.cond`, `lax.while_loop`, and `lax.scan`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NB: the analogous class in JAX is called 'DynamicJaxprTracer'\n",
"class JaxprTracer(Tracer):\n",
" __slots__ = ['aval']\n",
" aval: ShapedArray\n",
"\n",
" def __init__(self, trace, aval):\n",
" self._trace = trace\n",
" self.aval = aval\n",
"\n",
"# NB: the analogous class in JAX is called 'DynamicJaxprTrace'\n",
"class JaxprTrace(Trace):\n",
" def new_arg(self, aval: ShapedArray) -> JaxprTracer:\n",
" aval = raise_to_shaped(aval)\n",
" tracer = JaxprTracer(self, aval)\n",
" self.builder.tracer_to_var[id(tracer)] = Var(aval)\n",
" return tracer\n",
"\n",
" def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:\n",
" tracer = self.builder.const_tracers.get(id(val))\n",
" if tracer is None:\n",
" tracer = JaxprTracer(self, raise_to_shaped(get_aval(val)))\n",
" self.builder.add_const(tracer, val)\n",
" return tracer\n",
" pure = lift = get_or_make_const_tracer\n",
"\n",
" def process_primitive(self, primitive, tracers, params):\n",
" avals_in = [t.aval for t in tracers]\n",
" aval_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
" out_tracer = JaxprTracer(self, aval_out)\n",
" inputs = [self.builder.getvar(t) for t in tracers]\n",
" outvar = self.builder.add_var(out_tracer)\n",
" self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvar))\n",
" return out_tracer\n",
"\n",
" @property\n",
" def builder(self):\n",
" return self.main.global_data\n",
"\n",
"# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance\n",
"abstract_eval_rules = {}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that we keep as interpreter-global data a builder object, which keeps\n",
"track of variables, constants, and eqns as we build up the jaxpr."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class JaxprBuilder:\n",
" eqns: List[JaxprEqn]\n",
" tracer_to_var: Dict[int, Var]\n",
" const_tracers: Dict[int, JaxprTracer]\n",
" constvals: Dict[Var, Any]\n",
"\n",
" def __init__(self):\n",
" self.eqns = []\n",
" self.tracer_to_var = {}\n",
" self.const_tracers = {}\n",
" self.constvals = {}\n",
"\n",
" def add_eqn(self, eqn: JaxprEqn) -> None:\n",
" self.eqns.append(eqn)\n",
"\n",
" def add_var(self, tracer: JaxprTracer) -> Var:\n",
" var = self.tracer_to_var.get(id(tracer))\n",
" assert var is None\n",
" var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)\n",
" return var\n",
"\n",
" def getvar(self, tracer: JaxprTracer) -> Var:\n",
" var = self.tracer_to_var.get(id(tracer))\n",
" assert var is not None\n",
" return var\n",
"\n",
" def add_const(self, tracer: JaxprTracer, val: Any) -> Var:\n",
" var = self.add_var(tracer)\n",
" self.const_tracers[id(val)] = tracer\n",
" self.constvals[var] = val\n",
" return var\n",
"\n",
" def build(self, in_tracers: List[JaxprTracer], out_tracer: JaxprTracer\n",
" ) -> Tuple[Jaxpr, List[Any]]:\n",
" constvars, constvals = unzip2(self.constvals.items())\n",
" t2v = lambda t: self.tracer_to_var[id(t)]\n",
" in_binders = constvars + [t2v(t) for t in in_tracers]\n",
" jaxpr = Jaxpr(in_binders, self.eqns, t2v(out_tracer))\n",
" typecheck_jaxpr(jaxpr)\n",
" return jaxpr, constvals"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rules we need for `JaxprTrace.process_primitive` are essentially typing\n",
"rules for primitive applications: given the primitive, its parameters, and\n",
"types for the inputs, the rule must produce a type for the output, which is\n",
"then packaged with the output `JaxprTracer`. We can use abstract evaluation\n",
"rules for this same purpose, even though they can be more general (since\n",
"abstract evaluation rules need to work on ConcreteArray inputs as well). We'll\n",
"reuse these abstract evaluation rules for the other jaxpr-producing trace\n",
"machinery, where the potential extra generality is useful."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def broadcast_shapes(*shapes):\n",
" assert len(shapes) > 1\n",
" for sizes in zip(*shapes):\n",
" sizes = [d for d in sizes if d != 1]\n",
" if sizes[:-1] != sizes[1:]:\n",
" raise Exception\n",
" return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))"
]
},
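{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sketch of `broadcast_shapes` in action (note it zips shapes\n",
"positionally, so it expects inputs of equal rank):\n",
"\n",
"```python\n",
"assert broadcast_shapes((3, 1), (1, 4)) == (3, 4)\n",
"assert broadcast_shapes((2, 3), (2, 3)) == (2, 3)\n",
"```"
]
},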
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def broadcasting_binop_abstract_eval_rule(*avals_in):\n",
" out_dtype = np.result_type(*map(np.result_type, avals_in))\n",
" out_shape = broadcast_shapes(*map(np.shape, avals_in))\n",
" return ShapedArray(out_shape, out_dtype)\n",
"\n",
"abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule\n",
"abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule\n",
"\n",
"def vectorized_unop_abstract_eval_rule(aval_in):\n",
" return ShapedArray(np.shape(aval_in), np.result_type(aval_in))\n",
"\n",
"abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule\n",
"abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule\n",
"abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule\n",
"\n",
"def reduce_sum_abstract_eval_rule(aval_in, *, axis):\n",
" new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]\n",
" return ShapedArray(tuple(new_shape), aval_in.dtype)\n",
"abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check our implementation, we can add a `make_jaxpr` transformation and\n",
"first pretty-printer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "defensive-ownership",
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def make_jaxpr(f, avals_in):\n",
" builder = JaxprBuilder()\n",
" with new_main(JaxprTrace, builder) as main:\n",
" trace = JaxprTrace(main)\n",
" tracers_in = [trace.new_arg(aval) for aval in avals_in]\n",
" out = f(*tracers_in)\n",
" tracer_out = full_raise(trace, out)\n",
" return builder.build(tracers_in, tracer_out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adopted-month",
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import itertools as it\n",
"import string\n",
"\n",
"class PPrint:\n",
" lines: List[Tuple[int, str]]\n",
"\n",
" def __init__(self, lines):\n",
" self.lines = lines\n",
"\n",
" def indent(self, indent: int) -> 'PPrint':\n",
" return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])\n",
"\n",
" def __add__(self, rhs: 'PPrint') -> 'PPrint':\n",
" return PPrint(self.lines + rhs.lines)\n",
"\n",
" def __rshift__(self, rhs: 'PPrint') -> 'PPrint':\n",
" if not rhs.lines: return self\n",
" if not self.lines: return rhs\n",
" indent, s = self.lines[-1]\n",
" indented_block = rhs.indent(indent + len(s))\n",
" common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]\n",
" return PPrint(self.lines[:-1]\n",
" + [(indent, common_line)]\n",
" + indented_block.lines[1:])\n",
"\n",
" def __str__(self) -> str:\n",
" return '\\n'.join(' ' * indent + s for indent, s in self.lines)\n",
"\n",
"def pp(s: Any) -> PPrint:\n",
" return PPrint([(0, line) for line in str(s).splitlines()])\n",
"\n",
"def vcat(ps: List[PPrint]) -> PPrint:\n",
" return sum(ps, pp(''))\n",
"\n",
"def pp_jaxpr(jaxpr: Jaxpr):\n",
" namegen = (''.join(s) for r in it.count(1)\n",
" for s in it.permutations(string.ascii_lowercase, r))\n",
" names = defaultdict(lambda: next(namegen))\n",
" in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)\n",
" eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])\n",
" out = names[jaxpr.out] if isinstance(jaxpr.out, Var) else str(jaxpr.out.val)\n",
" return (pp(f'{{ lambda {in_binders} .') +\n",
" ((pp('let ') >> eqns) + pp(f'in {out} }}')).indent(2))\n",
"\n",
"def var_str(names: Dict[Var, str], v: Var) -> str:\n",
" return f'{names[v]}:{v.aval.str_short()}'\n",
"\n",
"def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" lhs = pp(var_str(names, eqn.out_binder))\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return lhs >> pp(' = ') >> rhs\n",
"\n",
"def pp_params(params: Dict[str, Any]) -> PPrint:\n",
" items = sorted(params.items())\n",
" if items:\n",
" return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')\n",
" else:\n",
" return pp(' ')"
]
},
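{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `>>` combinator on `PPrint` makes hanging indents: the right-hand block\n",
"starts on the last line of the left-hand one, with its remaining lines\n",
"indented to match. A quick sketch:\n",
"\n",
"```python\n",
"print(pp('let ') >> (pp('x = 1') + pp('y = 2')))\n",
"# let x = 1\n",
"#     y = 2\n",
"```"
]
},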
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"jaxpr, consts = make_jaxpr(lambda x: 2. * x, [raise_to_shaped(get_aval(3.))])\n",
"print(pp_jaxpr(jaxpr))\n",
"print(typecheck_jaxpr(jaxpr))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
"kind": "private"
},
"name": "Autodidax: JAX core from scratch",
"provenance": [],
"toc_visible": true
},
"jupytext": {
"formats": "ipynb,md:myst,py",
"main_language": "python"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}