From 9145366f6f2bd9a25a6c9b435dc6043fd327e62f Mon Sep 17 00:00:00 2001
From: Dougal
Date: Sun, 26 Jan 2025 22:10:45 -0500
Subject: [PATCH] Part 1 of a new autodidax based on "stackless"

---
 docs/autodidax2_part1.ipynb | 1082 +++++++++++++++++++++++++++++++++++
 docs/autodidax2_part1.md    |  547 ++++++++++++++++++
 docs/autodidax2_part1.py    |  491 ++++++++++++++++
 docs/conf.py                |    1 +
 docs/contributor_guide.rst  |    1 +
 5 files changed, 2122 insertions(+)
 create mode 100644 docs/autodidax2_part1.ipynb
 create mode 100644 docs/autodidax2_part1.md
 create mode 100644 docs/autodidax2_part1.py

diff --git a/docs/autodidax2_part1.ipynb b/docs/autodidax2_part1.ipynb
new file mode 100644
index 000000000..0a5a89c8e
--- /dev/null
+++ b/docs/autodidax2_part1.ipynb
@@ -0,0 +1,1082 @@
+{
+ "cells": [
+  {
+   "cell_type": "raw",
+   "id": "e515a630",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "Copyright 2025 The JAX Authors.\n",
+    "\n",
+    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "you may not use this file except in compliance with the License.\n",
+    "You may obtain a copy of the License at\n",
+    "\n",
+    "    https://www.apache.org/licenses/LICENSE-2.0\n",
+    "\n",
+    "Unless required by applicable law or agreed to in writing, software\n",
+    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "See the License for the specific language governing permissions and\n",
+    "limitations under the License.\n",
+    "\n",
+    "---"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "25ac15cc",
+   "metadata": {},
+   "source": [
+    "# Autodidax2, part 1: JAX from scratch, again"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6999020a",
+   "metadata": {},
+   "source": [
+    "If you want to understand how JAX works you could try reading the code. But\n",
+    "the code is complicated, often for no good reason. This notebook presents a\n",
+    "stripped-back version without the cruft. It's a minimal version of JAX from\n",
+    "first principles. Enjoy!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "62dde49f",
+   "metadata": {},
+   "source": [
+    "## Main idea: context-sensitive interpretation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d13d5272",
+   "metadata": {},
+   "source": [
+    "JAX is two things:\n",
+    " 1. a set of primitive operations (roughly the NumPy API)\n",
+    " 2. a set of interpreters over those primitives (compilation, AD, etc.)\n",
+    "\n",
+    "In this minimal version of JAX we'll start with just two primitive operations,\n",
+    "addition and multiplication, and we'll add interpreters one by one. Suppose we\n",
+    "have a user-defined function like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9f179429",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.809256Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.808149Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.827374Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.826143Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def foo(x):\n",
+    "  return mul(x, add(x, 3.0))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "809d67a6",
+   "metadata": {},
+   "source": [
+    "We want to be able to interpret `foo` in different ways without changing its\n",
+    "implementation: we want to evaluate it on concrete values, differentiate it,\n",
+    "stage it out to an IR, compile it and so on."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a4235f52",
+   "metadata": {},
+   "source": [
+    "Here's how we'll do it. For each of these interpretations we'll define an\n",
+    "`Interpreter` object with a rule for handling each primitive operation. We'll\n",
+    "keep track of the *current* interpreter using a global context variable. The\n",
+    "user-facing functions `add` and `mul` will dispatch to the current\n",
+    "interpreter. At the beginning of the program the current interpreter will be\n",
+    "the \"evaluating\" interpreter which just evaluates the operations on ordinary\n",
+    "concrete data. Here's what this all looks like so far."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "65b26bdc",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.830948Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.830603Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.842672Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.841832Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from enum import Enum, auto\n",
+    "from contextlib import contextmanager\n",
+    "from typing import Any\n",
+    "\n",
+    "# The full (closed) set of primitive operations\n",
+    "class Op(Enum):\n",
+    "  add = auto()  # addition on floats\n",
+    "  mul = auto()  # multiplication on floats\n",
+    "\n",
+    "# Interpreters have rules for handling each primitive operation.\n",
+    "class Interpreter:\n",
+    "  def interpret_op(self, op: Op, args: tuple[Any, ...]):\n",
+    "    assert False, \"subclass should implement this\"\n",
+    "\n",
+    "# Our first interpreter is the \"evaluating interpreter\" which performs ordinary\n",
+    "# concrete evaluation.\n",
+    "class EvalInterpreter(Interpreter):\n",
+    "  def interpret_op(self, op, args):\n",
+    "    assert all(isinstance(arg, float) for arg in args)\n",
+    "    match op:\n",
+    "      case Op.add:\n",
+    "        x, y = args\n",
+    "        return x + y\n",
+    "      case Op.mul:\n",
+    "        x, y = args\n",
+    "        return x * y\n",
+    "      case _:\n",
+    "        raise ValueError(f\"Unrecognized primitive op: {op}\")\n",
+    "\n",
+    "# The current interpreter is initially the evaluating interpreter.\n",
+    "current_interpreter = EvalInterpreter()\n",
+    "\n",
+    "# A context manager for temporarily changing the current interpreter\n",
+    "@contextmanager\n",
+    "def set_interpreter(new_interpreter):\n",
+    "  global current_interpreter\n",
+    "  prev_interpreter = current_interpreter\n",
+    "  try:\n",
+    "    current_interpreter = new_interpreter\n",
+    "    yield\n",
+    "  finally:\n",
+    "    current_interpreter = prev_interpreter\n",
+    "\n",
+    "# The user-facing functions `mul` and `add` dispatch to the current interpreter.\n",
+    "def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))\n",
+    "def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e7d4ff3a",
+   "metadata": {},
+   "source": [
+    "At this point we can call `foo` with ordinary concrete inputs and see the\n",
+    "results:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "5aa8511c",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.846387Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.846085Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.850202Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.849420Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "10.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(foo(2.0))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "175587ca",
+   "metadata": {},
+   "source": [
+    "## Aside: forward-mode automatic differentiation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e003ba3f",
+   "metadata": {},
+   "source": [
+    "For our second interpreter we're going to try forward-mode automatic\n",
+    "differentiation (AD). Here's a quick introduction to forward-mode AD in case\n",
+    "this is the first time you've come across it. Otherwise skip ahead to the\n",
+    "\"JVP Interpreter\" section."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3bb1dde9",
+   "metadata": {},
+   "source": [
+    "Suppose we're interested in the derivative of `foo(x)` evaluated at `x=2.0`.\n",
+    "We could approximate it with finite differences:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "9151fcd4",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.852192Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.852015Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.855275Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.854676Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "7.000009999913458\n"
+     ]
+    }
+   ],
+   "source": [
+    "print((foo(2.00001) - foo(2.0)) / 0.00001)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9c3ce8ae",
+   "metadata": {},
+   "source": [
+    "The answer is close to 7.0 as expected. But computing it this way required two\n",
+    "evaluations of the function (not to mention the roundoff error and truncation\n",
+    "error). Here's a funny thing though. We can almost get the answer with a\n",
+    "single evaluation:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "cba962a2",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.857141Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.856974Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.859864Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.859432Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "10.0000700001\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(foo(2.00001))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a256c797",
+   "metadata": {},
+   "source": [
+    "The answer we're looking for, 7.0, is right there in the insignificant digits!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "171a1aab",
+   "metadata": {},
+   "source": [
+    "Here's one way to think about what's happening. The initial argument to `foo`,\n",
+    "`2.00001`, carries two pieces of data: a \"primal\" value, 2.0, and a \"tangent\"\n",
+    "value, `1.0`. The representation of this primal-tangent pair, `2.00001`, is\n",
+    "the sum of the two, with the tangent scaled by a small fixed epsilon, `1e-5`.\n",
+    "Ordinary evaluation of `foo(2.00001)` propagates this primal-tangent pair,\n",
+    "producing `10.0000700001` as the result. The primal and tangent components are\n",
+    "well separated in scale so we can visually interpret the result as the\n",
+    "primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at\n",
+    "the end."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "47420177",
+   "metadata": {},
+   "source": [
+    "The idea with forward-mode differentiation is to do the same thing but exactly\n",
+    "and explicitly (eyeballing floats doesn't really scale). We'll represent the\n",
+    "primal-tangent pair as an actual pair instead of folding them both into a\n",
+    "single floating point number. For each primitive operation we'll have a rule\n",
+    "that describes how to propagate these primal-tangent pairs. Let's work out the\n",
+    "rules for our two primitives."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "309dc70d",
+   "metadata": {},
+   "source": [
+    "Addition is easy. Consider `x + y` where `x = xp + xt * eps` and `y = yp + yt * eps`\n",
+    "(\"p\" for \"primal\", \"t\" for \"tangent\"):\n",
+    "\n",
+    "    x + y = (xp + xt * eps) + (yp + yt * eps)\n",
+    "          = (xp + yp)             # primal component\n",
+    "          + (xt + yt) * eps       # tangent component\n",
+    "\n",
+    "The result is a first-order polynomial in `eps` and we can read off the\n",
+    "primal-tangent pair as (xp + yp, xt + yt)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "59302b21",
+   "metadata": {},
+   "source": [
+    "Multiplication is more interesting:\n",
+    "\n",
+    "    x * y = (xp + xt * eps) * (yp + yt * eps)\n",
+    "          = (xp * yp)                    # primal component\n",
+    "          + (xp * yt + xt * yp) * eps    # tangent component\n",
+    "          + (xt * yt) * eps * eps        # quadratic component, vanishes in the eps->0 limit\n",
+    "\n",
+    "Now we have a second-order polynomial. But as epsilon goes to zero the\n",
+    "quadratic term vanishes and our primal-tangent pair\n",
+    "is just `(xp * yp, xp * yt + xt * yp)`.\n",
+    "(In our earlier example with finite `eps`, this non-vanishing quadratic term\n",
+    "is why we had the ~1e-10 \"truncation error\".)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "37fa5063",
+   "metadata": {},
+   "source": [
+    "Putting this into code, we can write down the forward-AD rules for addition\n",
+    "and multiplication and express `foo` in terms of these:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "57222038",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.862460Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.862160Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.868704Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.867858Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "DualNumber(primal=10.0, tangent=7.0)\n"
+     ]
+    }
+   ],
+   "source": [
+    "from dataclasses import dataclass\n",
+    "\n",
+    "# A primal-tangent pair is conventionally called a \"dual number\"\n",
+    "@dataclass\n",
+    "class DualNumber:\n",
+    "  primal : float\n",
+    "  tangent : float\n",
+    "\n",
+    "def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:\n",
+    "  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)\n",
+    "\n",
+    "def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:\n",
+    "  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)\n",
+    "\n",
+    "def foo_dual(x : DualNumber) -> DualNumber:\n",
+    "  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))\n",
+    "\n",
+    "print(foo_dual(DualNumber(2.0, 1.0)))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "54947cc7",
+   "metadata": {},
+   "source": [
+    "That works! But rewriting `foo` to use the `_dual` versions of addition and\n",
+    "multiplication was a bit tedious. Let's get back to the main program and use\n",
+    "our interpretation machinery to do the rewrite automatically."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "25edb5f4",
+   "metadata": {},
+   "source": [
+    "## JVP Interpreter"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "945933c6",
+   "metadata": {},
+   "source": [
+    "We'll set up a new interpreter called `JVPInterpreter` (\"JVP\" for\n",
+    "\"Jacobian-vector product\") which propagates these dual numbers instead of\n",
+    "ordinary values. Its rules for `add` and `mul` operate on dual numbers. They\n",
+    "cast constant arguments to dual numbers as needed by calling\n",
+    "`JVPInterpreter.lift`. In our manually rewritten version above we did\n",
+    "that by replacing the literal `3.0` with `DualNumber(3.0, 0.0)`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "d17a4fb0",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.871469Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.871220Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.880456Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.879794Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(10.0, 7.0)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# This is like DualNumber above except that it also has a pointer to the\n",
+    "# interpreter it belongs to, which is needed to avoid \"perturbation confusion\"\n",
+    "# in higher-order differentiation.\n",
+    "@dataclass\n",
+    "class TaggedDualNumber:\n",
+    "  interpreter : Interpreter\n",
+    "  primal : float\n",
+    "  tangent : float\n",
+    "\n",
+    "class JVPInterpreter(Interpreter):\n",
+    "  def __init__(self, prev_interpreter: Interpreter):\n",
+    "    # We keep a pointer to the interpreter that was current when this\n",
+    "    # interpreter was first invoked. That's the context in which our\n",
+    "    # rules should run.\n",
+    "    self.prev_interpreter = prev_interpreter\n",
+    "\n",
+    "  def interpret_op(self, op, args):\n",
+    "    args = tuple(self.lift(arg) for arg in args)\n",
+    "    with set_interpreter(self.prev_interpreter):\n",
+    "      match op:\n",
+    "        case Op.add:\n",
+    "          # Notice that we use `add` and `mul` here, which are the\n",
+    "          # interpreter-dispatching functions defined earlier.\n",
+    "          x, y = args\n",
+    "          return self.dual_number(\n",
+    "            add(x.primal, y.primal),\n",
+    "            add(x.tangent, y.tangent))\n",
+    "\n",
+    "        case Op.mul:\n",
+    "          x, y = args\n",
+    "          return self.dual_number(\n",
+    "            mul(x.primal, y.primal),\n",
+    "            add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))\n",
+    "\n",
+    "  def dual_number(self, primal, tangent):\n",
+    "    return TaggedDualNumber(self, primal, tangent)\n",
+    "\n",
+    "  # Lift a constant value (constant with respect to this interpreter) to\n",
+    "  # a TaggedDualNumber.\n",
+    "  def lift(self, x):\n",
+    "    if isinstance(x, TaggedDualNumber) and x.interpreter is self:\n",
+    "      return x\n",
+    "    else:\n",
+    "      return self.dual_number(x, 0.0)\n",
+    "\n",
+    "def jvp(f, primal, tangent):\n",
+    "  jvp_interpreter = JVPInterpreter(current_interpreter)\n",
+    "  dual_number_in = jvp_interpreter.dual_number(primal, tangent)\n",
+    "  with set_interpreter(jvp_interpreter):\n",
+    "    result = f(dual_number_in)\n",
+    "  dual_number_out = jvp_interpreter.lift(result)\n",
+    "  return dual_number_out.primal, dual_number_out.tangent\n",
+    "\n",
+    "# Let's try it out:\n",
+    "print(jvp(foo, 2.0, 1.0))\n",
+    "\n",
+    "# Because we were careful to consider nesting interpreters, higher-order AD\n",
+    "# works out of the box:\n",
+    "\n",
+    "def derivative(f, x):\n",
+    "  _, tangent = jvp(f, x, 1.0)\n",
+    "  return tangent\n",
+    "\n",
+    "def nth_order_derivative(n, f, x):\n",
+    "  if n == 0:\n",
+    "    return f(x)\n",
+    "  else:\n",
+    "    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "3acc3839",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.882187Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.882009Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.885190Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.884635Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "10.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(nth_order_derivative(0, foo, 2.0))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "187eb028",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.886848Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.886685Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.889507Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.889081Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "7.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(nth_order_derivative(1, foo, 2.0))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "9f0dde6d",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.891061Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.890896Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.894142Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.893701Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(nth_order_derivative(2, foo, 2.0))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "4d086fb3",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.895591Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.895398Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.898277Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.897870Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# The rest are zero because `foo` is only a second-order polynomial\n",
+    "print(nth_order_derivative(3, foo, 2.0))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "e3164405",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.899736Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.899545Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.902719Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.902303Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(nth_order_derivative(4, foo, 2.0))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e51ca61",
+   "metadata": {},
+   "source": [
+    "There are some subtleties worth discussing. First, how do you tell if\n",
+    "something is constant with respect to differentiation? It's tempting to say\n",
+    "\"it's a constant if and only if it's not a dual number\". But actually dual\n",
+    "numbers created by a *different* JVPInterpreter also need to be considered\n",
+    "constants with respect to the JVPInterpreter we're currently handling. That's\n",
+    "why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This\n",
+    "comes up in higher-order differentiation when there are multiple\n",
+    "JVPInterpreters in scope. The sort of bug where you accidentally interpret a\n",
+    "dual number from a different interpreter as non-constant is sometimes called\n",
+    "\"perturbation confusion\" in the literature. Here's an example program that\n",
+    "would have given the wrong answer if we hadn't had the\n",
+    "`and x.interpreter is self` check in `JVPInterpreter.lift`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "ae1449a0",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.904294Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.904105Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.907284Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.906874Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "def f(x):\n",
+    "  # g is constant in its (ignored) argument `y`. Its derivative should be zero\n",
+    "  # but our AD will mess it up if we don't distinguish perturbations from\n",
+    "  # different interpreters.\n",
+    "  def g(y):\n",
+    "    return x\n",
+    "  should_be_zero = derivative(g, 0.0)\n",
+    "  return mul(x, should_be_zero)\n",
+    "\n",
+    "print(derivative(f, 0.0))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "884d6f62",
+   "metadata": {},
+   "source": [
+    "Another subtlety: the `add` and `mul` rules in `JVPInterpreter.interpret_op`\n",
+    "describe addition and multiplication on dual numbers in terms of addition and\n",
+    "multiplication on the primal and tangent components. But we don't use ordinary\n",
+    "`+` and `*` for this. Instead we use our own `add` and `mul` functions which\n",
+    "dispatch to the current interpreter. Before calling them we set the current\n",
+    "interpreter to be the *previous* interpreter, i.e. the interpreter that was\n",
+    "current when `JVPInterpreter` was first invoked. If we didn't do this we'd\n",
+    "have an infinite recursion, with `add` and `mul` dispatching to\n",
+    "`JVPInterpreter` endlessly. The advantage of using our own `add` and `mul`\n",
+    "instead of ordinary `+` and `*` is that it means we can nest these\n",
+    "interpreters and do higher-order AD."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e03446b3",
+   "metadata": {},
+   "source": [
+    "At this point you might be wondering: have we just reinvented operator\n",
+    "overloading? Python overloads the infix ops `+` and `*` to dispatch to the\n",
+    "argument's `__add__` and `__mul__`. Could we have just used that mechanism\n",
+    "instead of this whole interpreter business? Yes, actually. Indeed, the early\n",
+    "automatic differentiation (AD) literature uses the term \"operator overloading\"\n",
+    "to describe this style of AD implementation. One detail is that we can't rely\n",
+    "exclusively on Python's built-in overloading because that only lets us overload\n",
+    "a handful of built-in infix ops whereas we eventually want to overload\n",
+    "NumPy-level operations like `sin` and `cos`. So we need our own mechanism."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e2035ea",
+   "metadata": {},
+   "source": [
+    "But there's a more important difference: our dispatch is based on *context*\n",
+    "whereas traditional Python-style overloading is based on *data*. This is\n",
+    "actually a recent development for JAX. The earliest versions of JAX looked\n",
+    "more like traditional data-based overloading. An interpreter (a \"trace\" in JAX\n",
+    "jargon) for an operation would be chosen based on data attached to the\n",
+    "arguments to that operation. We've gradually made the interpreter-dispatch\n",
+    "decision rely more and more on context rather than data (omnistaging [link],\n",
+    "stackless [link]). The reason to prefer context-based interpretation over\n",
+    "data-based interpretation is that it makes the implementation much simpler.\n",
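+    "\n",
+    "To make the contrast concrete, here's a sketch of the data-based alternative\n",
+    "(an illustration only, not JAX's actual implementation): a dual number that\n",
+    "relies on Python's built-in overloading, so the dispatch decision travels\n",
+    "with the values themselves.\n",
+    "\n",
+    "```python\n",
+    "@dataclass\n",
+    "class OverloadedDual:\n",
+    "  primal: float\n",
+    "  tangent: float\n",
+    "\n",
+    "  def __add__(self, other):\n",
+    "    return OverloadedDual(self.primal + other.primal,\n",
+    "                          self.tangent + other.tangent)\n",
+    "\n",
+    "  def __mul__(self, other):\n",
+    "    return OverloadedDual(self.primal * other.primal,\n",
+    "                          self.primal * other.tangent + self.tangent * other.primal)\n",
+    "```\n",
+    "\n",
+    "This handles a single level of differentiation fine, but the dispatch\n",
+    "decision now lives in the data, so nested interpreters have to tag and untag\n",
+    "values to keep their perturbations apart."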
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e5e55ee",
+   "metadata": {},
+   "source": [
+    "All that said, we do *also* want to take advantage of Python's built-in\n",
+    "overloading mechanism. That way we get the syntactic convenience of using\n",
+    "infix operators `+` and `*` instead of writing out `add(..)` and `mul(..)`.\n",
+    "But we'll put that aside for now."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c9a22b61",
+   "metadata": {},
+   "source": [
+    "## Staging to an untyped IR"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "480e2328",
+   "metadata": {},
+   "source": [
+    "The two program transformations we've seen so far -- evaluation and JVP --\n",
+    "both traverse the input program from top to bottom. They visit the operations\n",
+    "one by one in the same order as ordinary evaluation. A convenient thing about\n",
+    "top-to-bottom transformations is that they can be implemented eagerly, or\n",
+    "\"online\", meaning that we can evaluate the program from top to bottom and\n",
+    "perform the necessary transformations as we go. We never look at the entire\n",
+    "program at once."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e4adad7b",
+   "metadata": {},
+   "source": [
+    "But not all transformations work this way. For example, dead-code elimination\n",
+    "requires traversing from bottom to top, collecting usage statistics on the way\n",
+    "up and eliminating pure operations whose results have no uses. Another\n",
+    "bottom-to-top transformation is AD transposition, which we use to implement\n",
+    "reverse-mode AD. For these we need to first \"stage\" the program into an IR\n",
+    "(intermediate representation), a data structure representing the program,\n",
+    "which we can then traverse in any order we like. Building this IR from a\n",
+    "Python program will be the goal of our third and final interpreter."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9c8ba558",
+   "metadata": {},
+   "source": [
+    "First, let's define the IR. We'll use an untyped ANF (A-normal form) IR to\n",
+    "start. A function (we call IR functions \"jaxprs\" in JAX) will have a list of\n",
+    "formal parameters, a list of operations, and a return value. Each argument to\n",
+    "an operation must be an \"atom\", which is either a variable or a literal. The\n",
+    "return value of the function is also an atom."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "2100d92e",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2025-02-12T20:21:55.909146Z",
+     "iopub.status.busy": "2025-02-12T20:21:55.908956Z",
+     "iopub.status.idle": "2025-02-12T20:21:55.914391Z",
+     "shell.execute_reply": "2025-02-12T20:21:55.913886Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "Var = str           # Variables are just strings in this untyped IR\n",
+    "Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals\n",
+    "\n",
+    "# Equation - a single line in our IR like `z = mul(x, y)`\n",
+    "@dataclass\n",
+    "class Equation:\n",
+    "  var  : Var               # The variable name of the result\n",
+    "  op   : Op                # The primitive operation we're applying\n",
+    "  args : tuple[Atom, ...]  # The arguments we're applying the primitive operation to\n",
+    "\n",
+    "# We call an IR function a \"Jaxpr\", for \"JAX expression\"\n",
+    "@dataclass\n",
+    "class Jaxpr:\n",
+    "  parameters : list[Var]       # The function's formal parameters (arguments)\n",
+    "  equations  : list[Equation]  # The body of the function, a list of instructions/equations\n",
+    "  return_val : Atom            # The function's return value\n",
+    "\n",
+    "  def __str__(self):\n",
+    "    lines = []\n",
+    "    lines.append(', '.join(self.parameters) + ' ->')\n",
+    "    for eqn in self.equations:\n",
+    "      args_str = ', '.join(str(arg) for arg in eqn.args)\n",
+    "      lines.append(f'  {eqn.var} = {eqn.op}({args_str})')\n",
+    "    lines.append(str(self.return_val))\n",
+    "    return '\\n'.join(lines)"
+   ]
+  },
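+  {
+   "cell_type": "markdown",
+   "id": "5c1a9e77",
+   "metadata": {},
+   "source": [
+    "As a quick check that the data structure hangs together, here's a jaxpr built\n",
+    "by hand for `x -> x * (x + 3.0)` (an illustration only, with made-up variable\n",
+    "names; `build_jaxpr` below constructs these automatically):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5d2b0f88",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jaxpr_by_hand = Jaxpr(\n",
+    "  parameters = [\"x\"],\n",
+    "  equations  = [Equation(\"t\", Op.add, (\"x\", 3.0)),\n",
+    "                Equation(\"out\", Op.mul, (\"x\", \"t\"))],\n",
+    "  return_val = \"out\")\n",
+    "print(jaxpr_by_hand)"
+   ]
+  },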
"text": [ + "v_1 ->\n", + " v_2 = Op.add(v_1, 3.0)\n", + " v_3 = Op.mul(v_1, v_2)\n", + "v_3\n" + ] + } + ], + "source": [ + "print(build_jaxpr(foo, 1))" + ] + }, + { + "cell_type": "markdown", + "id": "67deabe8", + "metadata": {}, + "source": [ + "We can also evaluate our IR by writing an explicit interpreter that traverses\n", + "the operations one by one:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "6a20cc84", + "metadata": { + "execution": { + "iopub.execute_input": "2025-02-12T20:21:55.925838Z", + "iopub.status.busy": "2025-02-12T20:21:55.925646Z", + "iopub.status.idle": "2025-02-12T20:21:55.929596Z", + "shell.execute_reply": "2025-02-12T20:21:55.929187Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10.0\n" + ] + } + ], + "source": [ + "def eval_jaxpr(jaxpr, args):\n", + " # An environment mapping variables to values\n", + " env = dict(zip(jaxpr.parameters, args))\n", + " def eval_atom(x): return env[x] if isinstance(x, Var) else x\n", + " for eqn in jaxpr.equations:\n", + " args = tuple(eval_atom(x) for x in eqn.args)\n", + " env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)\n", + " return eval_atom(jaxpr.return_val)\n", + "\n", + "print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))" + ] + }, + { + "cell_type": "markdown", + "id": "c6250492", + "metadata": {}, + "source": [ + "We've written this interpreter in terms of `current_interpreter.interpret_op`\n", + "which means we've done a full round-trip: interpretable Python program to IR\n", + "to interpretable Python program. Since the result is \"interpretable\" we can\n", + "differentiate it again, or stage it out or anything we like:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "831924b8", + "metadata": { + "execution": { + "iopub.execute_input": "2025-02-12T20:21:55.931176Z", + "iopub.status.busy": "2025-02-12T20:21:55.930983Z", + "iopub.status.idle": "2025-02-12T20:21:55.933902Z", + "shell.execute_reply": "2025-02-12T20:21:55.933490Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(10.0, 7.0)\n" + ] + } + ], + "source": [ + "print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))" + ] + }, + { + "cell_type": "markdown", + "id": "d3ac5873", + "metadata": {}, + "source": [ + "## Up next..." + ] + }, + { + "cell_type": "markdown", + "id": "c3451298", + "metadata": {}, + "source": [ + "That's it for part one of this tutorial. We've done two primitives, three\n", + "interpreters and the tracing mechanism that weaves them together. In the next\n", + "part we'll add types other than floats, error handling, compilation,\n", + "reverse-mode AD and higher-order primtives. Note that the second part is\n", + "structured differently. Rather than trying to have a top-to-bottom order that\n", + "obeys both code dependencies (e.g. data structures need to be defined before\n", + "they're used) and pedagogical dependencies (concepts need to be introduced\n", + "before they're implemented) we're going with a single file that can be approached\n", + "in any order." 
+  {
+   "cell_type": "markdown",
+   "id": "d3ac5873",
+   "metadata": {},
+   "source": [
+    "## Up next..."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c3451298",
+   "metadata": {},
+   "source": [
+    "That's it for part one of this tutorial. We've done two primitives, three\n",
+    "interpreters and the tracing mechanism that weaves them together. In the next\n",
+    "part we'll add types other than floats, error handling, compilation,\n",
+    "reverse-mode AD and higher-order primitives. Note that the second part is\n",
+    "structured differently. Rather than trying to have a top-to-bottom order that\n",
+    "obeys both code dependencies (e.g. data structures need to be defined before\n",
+    "they're used) and pedagogical dependencies (concepts need to be introduced\n",
+    "before they're implemented), we're going with a single file that can be\n",
+    "approached in any order."
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "formats": "ipynb,md:myst,py:light"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/autodidax2_part1.md b/docs/autodidax2_part1.md
new file mode 100644
index 000000000..70dd0e4b6
--- /dev/null
+++ b/docs/autodidax2_part1.md
@@ -0,0 +1,547 @@
+---
+jupytext:
+  formats: ipynb,md:myst,py:light
+  text_representation:
+    extension: .md
+    format_name: myst
+    format_version: 0.13
+    jupytext_version: 1.16.4
+kernelspec:
+  display_name: Python 3 (ipykernel)
+  language: python
+  name: python3
+---
+
+```{raw-cell}
+
+---
+Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+---
+```
+
+# Autodidax2, part 1: JAX from scratch, again
+
++++
+
+If you want to understand how JAX works you could try reading the code. But
+the code is complicated, often for no good reason. This notebook presents a
+stripped-back version without the cruft. It's a minimal version of JAX from
+first principles. Enjoy!
+
++++
+
+## Main idea: context-sensitive interpretation
+
++++
+
+JAX is two things:
+ 1. a set of primitive operations (roughly the NumPy API)
+ 2. a set of interpreters over those primitives (compilation, AD, etc.)
+
+In this minimal version of JAX we'll start with just two primitive operations,
+addition and multiplication, and we'll add interpreters one by one. Suppose we
+have a user-defined function like this:
+
+```{code-cell} ipython3
+def foo(x):
+  return mul(x, add(x, 3.0))
+```
+
+We want to be able to interpret `foo` in different ways without changing its
+implementation: we want to evaluate it on concrete values, differentiate it,
+stage it out to an IR, compile it and so on.
+
++++
+
+Here's how we'll do it. For each of these interpretations we'll define an
+`Interpreter` object with a rule for handling each primitive operation. We'll
+keep track of the *current* interpreter using a global context variable. The
+user-facing functions `add` and `mul` will dispatch to the current
+interpreter. At the beginning of the program the current interpreter will be
+the "evaluating" interpreter which just evaluates the operations on ordinary
+concrete data. Here's what this all looks like so far.
+
+```{code-cell} ipython3
+from enum import Enum, auto
+from contextlib import contextmanager
+from typing import Any
+
+# The full (closed) set of primitive operations
+class Op(Enum):
+  add = auto()  # addition on floats
+  mul = auto()  # multiplication on floats
+
+# Interpreters have rules for handling each primitive operation.
+class Interpreter:
+  def interpret_op(self, op: Op, args: tuple[Any, ...]):
+    assert False, "subclass should implement this"
+
+# Our first interpreter is the "evaluating interpreter" which performs ordinary
+# concrete evaluation.
+class EvalInterpreter(Interpreter):
+  def interpret_op(self, op, args):
+    assert all(isinstance(arg, float) for arg in args)
+    match op:
+      case Op.add:
+        x, y = args
+        return x + y
+      case Op.mul:
+        x, y = args
+        return x * y
+      case _:
+        raise ValueError(f"Unrecognized primitive op: {op}")
+
+# The current interpreter is initially the evaluating interpreter.
+current_interpreter = EvalInterpreter()
+
+# A context manager for temporarily changing the current interpreter
+@contextmanager
+def set_interpreter(new_interpreter):
+  global current_interpreter
+  prev_interpreter = current_interpreter
+  try:
+    current_interpreter = new_interpreter
+    yield
+  finally:
+    current_interpreter = prev_interpreter
+
+# The user-facing functions `mul` and `add` dispatch to the current interpreter.
+def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
+def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))
+```
+
+At this point we can call `foo` with ordinary concrete inputs and see the
+results:
+
+```{code-cell} ipython3
+print(foo(2.0))
+```
+
+## Aside: forward-mode automatic differentiation
+
++++
+
+For our second interpreter we're going to try forward-mode automatic
+differentiation (AD). Here's a quick introduction to forward-mode AD in case
+this is the first time you've come across it. Otherwise skip ahead to the
+"JVP Interpreter" section.
+
++++
+
+Suppose we're interested in the derivative of `foo(x)` evaluated at `x=2.0`.
+We could approximate it with finite differences:
+
+```{code-cell} ipython3
+print((foo(2.00001) - foo(2.0)) / 0.00001)
+```
+
+The answer is close to 7.0 as expected. But computing it this way required two
+evaluations of the function (not to mention the roundoff error and truncation
+error). Here's a funny thing though. We can almost get the answer with a
+single evaluation:
+
+```{code-cell} ipython3
+print(foo(2.00001))
+```
+
+The answer we're looking for, 7.0, is right there in the insignificant digits!
+
++++
+
+Here's one way to think about what's happening. The initial argument to `foo`,
+`2.00001`, carries two pieces of data: a "primal" value, 2.0, and a "tangent"
+value, `1.0`. The representation of this primal-tangent pair, `2.00001`, is
+the sum of the two, with the tangent scaled by a small fixed epsilon, `1e-5`.
+Ordinary evaluation of `foo(2.00001)` propagates this primal-tangent pair,
+producing `10.0000700001` as the result. The primal and tangent components are
+well separated in scale so we can visually interpret the result as the
+primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at
+the end.
+
++++
+
+The idea with forward-mode differentiation is to do the same thing but exactly
+and explicitly (eyeballing floats doesn't really scale). We'll represent the
+primal-tangent pair as an actual pair instead of folding them both into a
+single floating point number. For each primitive operation we'll have a rule
+that describes how to propagate these primal-tangent pairs. Let's work out the
+rules for our two primitives.
+
++++
+
+Addition is easy. Consider `x + y` where `x = xp + xt * eps` and `y = yp + yt * eps`
+("p" for "primal", "t" for "tangent"):
+
+    x + y = (xp + xt * eps) + (yp + yt * eps)
+          = (xp + yp)             # primal component
+          + (xt + yt) * eps       # tangent component
+
+The result is a first-order polynomial in `eps` and we can read off the
+primal-tangent pair as (xp + yp, xt + yt).
+
++++
+
+Multiplication is more interesting:
+
+    x * y = (xp + xt * eps) * (yp + yt * eps)
+          = (xp * yp)                    # primal component
+          + (xp * yt + xt * yp) * eps    # tangent component
+          + (xt * yt) * eps * eps        # quadratic component, vanishes in the eps->0 limit
+
+Now we have a second-order polynomial. But as epsilon goes to zero the
+quadratic term vanishes and our primal-tangent pair
+is just `(xp * yp, xp * yt + xt * yp)`.
+(In our earlier example with finite `eps`, this non-vanishing quadratic term
+is why we had the ~1e-10 "truncation error".)
+
++++
+
+Putting this into code, we can write down the forward-AD rules for addition
+and multiplication and express `foo` in terms of these:
+
+```{code-cell} ipython3
+from dataclasses import dataclass
+
+# A primal-tangent pair is conventionally called a "dual number"
+@dataclass
+class DualNumber:
+  primal : float
+  tangent : float
+
+def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:
+  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)
+
+def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:
+  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)
+
+def foo_dual(x : DualNumber) -> DualNumber:
+  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))
+
+print(foo_dual(DualNumber(2.0, 1.0)))
+```
+
+That works! But rewriting `foo` to use the `_dual` versions of addition and
+multiplication was a bit tedious. Let's get back to the main program and use
+our interpretation machinery to do the rewrite automatically.
+
++++
+
+## JVP Interpreter
+
++++
+
+We'll set up a new interpreter called `JVPInterpreter` ("JVP" for
+"Jacobian-vector product") which propagates these dual numbers instead of
+ordinary values. Its rules for `add` and `mul` operate on dual numbers. They
+cast constant arguments to dual numbers as needed by calling
+`JVPInterpreter.lift`. In our manually rewritten version above we did
+that by replacing the literal `3.0` with `DualNumber(3.0, 0.0)`.
+
+```{code-cell} ipython3
+# This is like DualNumber above except that it also has a pointer to the
+# interpreter it belongs to, which is needed to avoid "perturbation confusion"
+# in higher-order differentiation.
+@dataclass
+class TaggedDualNumber:
+  interpreter : Interpreter
+  primal : float
+  tangent : float
+
+class JVPInterpreter(Interpreter):
+  def __init__(self, prev_interpreter: Interpreter):
+    # We keep a pointer to the interpreter that was current when this
+    # interpreter was first invoked. That's the context in which our
+    # rules should run.
+    self.prev_interpreter = prev_interpreter
+
+  def interpret_op(self, op, args):
+    args = tuple(self.lift(arg) for arg in args)
+    with set_interpreter(self.prev_interpreter):
+      match op:
+        case Op.add:
+          # Notice that we use `add` and `mul` here, which are the
+          # interpreter-dispatching functions defined earlier.
+          x, y = args
+          return self.dual_number(
+            add(x.primal, y.primal),
+            add(x.tangent, y.tangent))
+
+        case Op.mul:
+          x, y = args
+          return self.dual_number(
+            mul(x.primal, y.primal),
+            add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))
+
+  def dual_number(self, primal, tangent):
+    return TaggedDualNumber(self, primal, tangent)
+
+  # Lift a constant value (constant with respect to this interpreter) to
+  # a TaggedDualNumber.
+  def lift(self, x):
+    if isinstance(x, TaggedDualNumber) and x.interpreter is self:
+      return x
+    else:
+      return self.dual_number(x, 0.0)
+
+def jvp(f, primal, tangent):
+  jvp_interpreter = JVPInterpreter(current_interpreter)
+  dual_number_in = jvp_interpreter.dual_number(primal, tangent)
+  with set_interpreter(jvp_interpreter):
+    result = f(dual_number_in)
+  dual_number_out = jvp_interpreter.lift(result)
+  return dual_number_out.primal, dual_number_out.tangent
+
+# Let's try it out:
+print(jvp(foo, 2.0, 1.0))
+
+# Because we were careful to consider nesting interpreters, higher-order AD
+# works out of the box:
+
+def derivative(f, x):
+  _, tangent = jvp(f, x, 1.0)
+  return tangent
+
+def nth_order_derivative(n, f, x):
+  if n == 0:
+    return f(x)
+  else:
+    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)
+```
+
+```{code-cell} ipython3
+print(nth_order_derivative(0, foo, 2.0))
+```
+
+```{code-cell} ipython3
+print(nth_order_derivative(1, foo, 2.0))
+```
+
+```{code-cell} ipython3
+print(nth_order_derivative(2, foo, 2.0))
+```
+
+```{code-cell} ipython3
+# The rest are zero because `foo` is only a second-order polynomial
+print(nth_order_derivative(3, foo, 2.0))
+```
+
+```{code-cell} ipython3
+print(nth_order_derivative(4, foo, 2.0))
+```
+
+There are some subtleties worth discussing. First, how do you tell if
+something is constant with respect to differentiation? It's tempting to say
+"it's a constant if and only if it's not a dual number". But actually dual
+numbers created by a *different* JVPInterpreter also need to be considered
+constants with respect to the JVPInterpreter we're currently handling. That's
+why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This
+comes up in higher-order differentiation when there are multiple
+JVPInterpreters in scope. The sort of bug where you accidentally interpret a
+dual number from a different interpreter as non-constant is sometimes called
+"perturbation confusion" in the literature. Here's an example program that
+would have given the wrong answer if we hadn't had the
+`and x.interpreter is self` check in `JVPInterpreter.lift`.
+
+```{code-cell} ipython3
+def f(x):
+  # g is constant in its (ignored) argument `y`. Its derivative should be zero
+  # but our AD will mess it up if we don't distinguish perturbations from
+  # different interpreters.
+  def g(y):
+    return x
+  should_be_zero = derivative(g, 0.0)
+  return mul(x, should_be_zero)
+
+print(derivative(f, 0.0))
+```
+
+Another subtlety: the `add` and `mul` rules in `JVPInterpreter.interpret_op`
+describe addition and multiplication on dual numbers in terms of addition and
+multiplication on the primal and tangent components. But we don't use ordinary
+`+` and `*` for this. Instead we use our own `add` and `mul` functions which
+dispatch to the current interpreter. Before calling them we set the current
+interpreter to be the *previous* interpreter, i.e. the interpreter that was
+current when `JVPInterpreter` was first invoked. If we didn't do this we'd
+have an infinite recursion, with `add` and `mul` dispatching to
+`JVPInterpreter` endlessly. The advantage of using our own `add` and `mul`
+instead of ordinary `+` and `*` is that it means we can nest these interpreters
+and do higher-order AD.
+
++++
+
+At this point you might be wondering: have we just reinvented operator
+overloading? Python overloads the infix ops `+` and `*` to dispatch to the
+argument's `__add__` and `__mul__`. Could we have just used that mechanism
+instead of this whole interpreter business? Yes, actually. Indeed, the early
+automatic differentiation (AD) literature uses the term "operator overloading"
+to describe this style of AD implementation. One detail is that we can't rely
+exclusively on Python's built-in overloading because that only lets us overload
+a handful of built-in infix ops whereas we eventually want to overload
+NumPy-level operations like `sin` and `cos`. So we need our own mechanism.
+
++++
+
+But there's a more important difference: our dispatch is based on *context*
+whereas traditional Python-style overloading is based on *data*. This is
+actually a recent development for JAX. The earliest versions of JAX looked
+more like traditional data-based overloading. An interpreter (a "trace" in JAX
+jargon) for an operation would be chosen based on data attached to the
+arguments to that operation. We've gradually made the interpreter-dispatch
+decision rely more and more on context rather than data (omnistaging [link],
+stackless [link]). The reason to prefer context-based interpretation over
+data-based interpretation is that it makes the implementation much simpler.
+
++++
+
+All that said, we do *also* want to take advantage of Python's built-in
+overloading mechanism. That way we get the syntactic convenience of using
+infix operators `+` and `*` instead of writing out `add(..)` and `mul(..)`.
+But we'll put that aside for now.
+
++++
+
+## Staging to an untyped IR
+
++++
+
+The two program transformations we've seen so far -- evaluation and JVP --
+both traverse the input program from top to bottom. They visit the operations
+one by one in the same order as ordinary evaluation. A convenient thing about
+top-to-bottom transformations is that they can be implemented eagerly, or
+"online", meaning that we can evaluate the program from top to bottom and
+perform the necessary transformations as we go. We never look at the entire
+program at once.
+
++++
+
+But not all transformations work this way. For example, dead-code elimination
+requires traversing from bottom to top, collecting usage statistics on the way
+up and eliminating pure operations whose results have no uses. Another
+bottom-to-top transformation is AD transposition, which we use to implement
+reverse-mode AD. For these we need to first "stage" the program into an IR
+(intermediate representation), a data structure representing the program, which
+we can then traverse in any order we like. Building this IR from a Python
+program will be the goal of our third and final interpreter.
+
++++
+
+First, let's define the IR. We'll use an untyped ANF (A-normal form) IR to
+start. A function (we call IR functions "jaxprs" in JAX) will have a list of
+formal parameters, a list of operations, and a return value. Each argument to
+an operation must be an "atom", which is either a variable or a literal. The
+return value of the function is also an atom.
+
+```{code-cell} ipython3
+Var = str           # Variables are just strings in this untyped IR
+Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals
+
+# Equation - a single line in our IR like `z = mul(x, y)`
+@dataclass
+class Equation:
+  var  : Var               # The variable name of the result
+  op   : Op                # The primitive operation we're applying
+  args : tuple[Atom, ...]  # The arguments we're applying the primitive operation to
+
+# We call an IR function a "Jaxpr", for "JAX expression"
+@dataclass
+class Jaxpr:
+  parameters : list[Var]       # The function's formal parameters (arguments)
+  equations  : list[Equation]  # The body of the function, a list of instructions/equations
+  return_val : Atom            # The function's return value
+
+  def __str__(self):
+    lines = []
+    lines.append(', '.join(self.parameters) + ' ->')
+    for eqn in self.equations:
+      args_str = ', '.join(str(arg) for arg in eqn.args)
+      lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
+    lines.append(str(self.return_val))
+    return '\n'.join(lines)
+```
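+
+As a quick check that the data structure hangs together, here's a jaxpr built
+by hand for `x -> x * (x + 3.0)` (an illustration only, with made-up variable
+names; `build_jaxpr` below constructs these automatically):
+
+```{code-cell} ipython3
+jaxpr_by_hand = Jaxpr(
+  parameters = ["x"],
+  equations  = [Equation("t", Op.add, ("x", 3.0)),
+                Equation("out", Op.mul, ("x", "t"))],
+  return_val = "out")
+print(jaxpr_by_hand)
+```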
+
+To build the IR from a Python function we define a `StagingInterpreter` that
+takes each operation and adds it to a growing list of all the operations we've
+seen so far:
+
+```{code-cell} ipython3
+class StagingInterpreter(Interpreter):
+  def __init__(self):
+    self.equations = []    # A mutable list of all the ops we've seen so far
+    self.name_counter = 0  # Counter for generating unique names
+
+  def fresh_var(self):
+    self.name_counter += 1
+    return "v_" + str(self.name_counter)
+
+  def interpret_op(self, op, args):
+    binder = self.fresh_var()
+    self.equations.append(Equation(binder, op, args))
+    return binder
+
+def build_jaxpr(f, num_args):
+  interpreter = StagingInterpreter()
+  parameters = [interpreter.fresh_var() for _ in range(num_args)]
+  with set_interpreter(interpreter):
+    result = f(*parameters)
+  return Jaxpr(parameters, interpreter.equations, result)
+```
+
+Now we can construct an IR for a Python program and print it out:
+
+```{code-cell} ipython3
+print(build_jaxpr(foo, 1))
+```
+
+We can also evaluate our IR by writing an explicit interpreter that traverses
+the operations one by one:
+
+```{code-cell} ipython3
+def eval_jaxpr(jaxpr, args):
+  # An environment mapping variables to values
+  env = dict(zip(jaxpr.parameters, args))
+  def eval_atom(x): return env[x] if isinstance(x, Var) else x
+  for eqn in jaxpr.equations:
+    args = tuple(eval_atom(x) for x in eqn.args)
+    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)
+  return eval_atom(jaxpr.return_val)
+
+print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
+```
+
+We've written this interpreter in terms of `current_interpreter.interpret_op`
+which means we've done a full round-trip: interpretable Python program to IR
+to interpretable Python program. Since the result is "interpretable" we can
+differentiate it again, or stage it out or anything we like:
+
+```{code-cell} ipython3
+print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
+```
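+
+For the same reason we can also re-stage the staged program and get an
+equivalent jaxpr back, up to variable renaming (a quick demonstration):
+
+```{code-cell} ipython3
+jaxpr = build_jaxpr(foo, 1)
+print(build_jaxpr(lambda x: eval_jaxpr(jaxpr, (x,)), 1))
+```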
+
+## Up next...
+
++++
+
+That's it for part one of this tutorial. We've done two primitives, three
+interpreters and the tracing mechanism that weaves them together. In the next
+part we'll add types other than floats, error handling, compilation,
+reverse-mode AD and higher-order primitives. Note that the second part is
+structured differently. Rather than trying to have a top-to-bottom order that
+obeys both code dependencies (e.g. data structures need to be defined before
+they're used) and pedagogical dependencies (concepts need to be introduced
+before they're implemented), we're going with a single file that can be
+approached in any order.
diff --git a/docs/autodidax2_part1.py b/docs/autodidax2_part1.py
new file mode 100644
index 000000000..bfe59df35
--- /dev/null
+++ b/docs/autodidax2_part1.py
@@ -0,0 +1,491 @@
+# ---
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# jupyter:
+#   jupytext:
+#     formats: ipynb,md:myst,py:light
+#     text_representation:
+#       extension: .py
+#       format_name: light
+#       format_version: '1.5'
+#       jupytext_version: 1.16.4
+#   kernelspec:
+#     display_name: Python 3 (ipykernel)
+#     language: python
+#     name: python3
+# ---
+
+# # Autodidax2, part 1: JAX from scratch, again
+
+# If you want to understand how JAX works you could try reading the code. But
+# the code is complicated, often for no good reason. This notebook presents a
+# stripped-back version without the cruft. It's a minimal version of JAX from
+# first principles. Enjoy!
+
+# ## Main idea: context-sensitive interpretation
+
+# JAX is two things:
+#  1. a set of primitive operations (roughly the NumPy API)
+#  2. a set of interpreters over those primitives (compilation, AD, etc.)
+#
+# In this minimal version of JAX we'll start with just two primitive operations,
+# addition and multiplication, and we'll add interpreters one by one. Suppose we
+# have a user-defined function like this:

+def foo(x):
+  return mul(x, add(x, 3.0))
+
+
+# We want to be able to interpret `foo` in different ways without changing its
+# implementation: we want to evaluate it on concrete values, differentiate it,
+# stage it out to an IR, compile it and so on.
+
+# Here's how we'll do it. For each of these interpretations we'll define an
+# `Interpreter` object with a rule for handling each primitive operation. We'll
+# keep track of the *current* interpreter using a global context variable. The
+# user-facing functions `add` and `mul` will dispatch to the current
+# interpreter. At the beginning of the program the current interpreter will be
+# the "evaluating" interpreter which just evaluates the operations on ordinary
+# concrete data. Here's what this all looks like so far.
+
+# +
+from enum import Enum, auto
+from contextlib import contextmanager
+from typing import Any
+
+# The full (closed) set of primitive operations
+class Op(Enum):
+  add = auto()  # addition on floats
+  mul = auto()  # multiplication on floats
+
+# Interpreters have rules for handling each primitive operation.
+class Interpreter:
+  def interpret_op(self, op: Op, args: tuple[Any, ...]):
+    assert False, "subclass should implement this"
+
+# Our first interpreter is the "evaluating interpreter" which performs ordinary
+# concrete evaluation.
+class EvalInterpreter(Interpreter):
+  def interpret_op(self, op, args):
+    assert all(isinstance(arg, float) for arg in args)
+    match op:
+      case Op.add:
+        x, y = args
+        return x + y
+      case Op.mul:
+        x, y = args
+        return x * y
+      case _:
+        raise ValueError(f"Unrecognized primitive op: {op}")
+
+# The current interpreter is initially the evaluating interpreter.
+current_interpreter = EvalInterpreter()
+
+# A context manager for temporarily changing the current interpreter
+@contextmanager
+def set_interpreter(new_interpreter):
+  global current_interpreter
+  prev_interpreter = current_interpreter
+  try:
+    current_interpreter = new_interpreter
+    yield
+  finally:
+    current_interpreter = prev_interpreter
+
+# The user-facing functions `mul` and `add` dispatch to the current interpreter.
+def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
+def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))
+
+
+# -
+
+# At this point we can call `foo` with ordinary concrete inputs and see the
+# results:
+
+print(foo(2.0))
+
+# ## Aside: forward-mode automatic differentiation
+
+# For our second interpreter we're going to try forward-mode automatic
+# differentiation (AD). Here's a quick introduction to forward-mode AD in case
+# this is the first time you've come across it. Otherwise skip ahead to the
+# "JVP Interpreter" section.
+
+# Suppose we're interested in the derivative of `foo(x)` evaluated at `x=2.0`.
+# We could approximate it with finite differences:
+
+print((foo(2.00001) - foo(2.0)) / 0.00001)
+
+# The answer is close to 7.0 as expected. But computing it this way required two
+# evaluations of the function (not to mention the roundoff error and truncation
+# error). Here's a funny thing though. We can almost get the answer with a
+# single evaluation:
+
+print(foo(2.00001))
+
+# The answer we're looking for, 7.0, is right there in the insignificant digits!
+
+# Here's one way to think about what's happening. The initial argument to `foo`,
+# `2.00001`, carries two pieces of data: a "primal" value, 2.0, and a "tangent"
+# value, `1.0`. The representation of this primal-tangent pair, `2.00001`, is
+# the sum of the two, with the tangent scaled by a small fixed epsilon, `1e-5`.
+# Ordinary evaluation of `foo(2.00001)` propagates this primal-tangent pair,
+# producing `10.0000700001` as the result. The primal and tangent components are
+# well separated in scale so we can visually interpret the result as the
+# primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at
+# the end.
+
+# The idea with forward-mode differentiation is to do the same thing but exactly
+# and explicitly (eyeballing floats doesn't really scale). We'll represent the
+# primal-tangent pair as an actual pair instead of folding them both into a
+# single floating point number. For each primitive operation we'll have a rule
+# that describes how to propagate these primal-tangent pairs. Let's work out the
+# rules for our two primitives.
+
+# Addition is easy. Consider `x + y` where `x = xp + xt * eps` and `y = yp + yt * eps`
+# ("p" for "primal", "t" for "tangent"):
+#
+#     x + y = (xp + xt * eps) + (yp + yt * eps)
+#           = (xp + yp)             # primal component
+#           + (xt + yt) * eps       # tangent component
+#
+# The result is a first-order polynomial in `eps` and we can read off the
+# primal-tangent pair as (xp + yp, xt + yt).
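+
+# As a quick numeric sanity check (an illustration only, with made-up values),
+# we can confirm the addition rule with a small but finite `eps`:
+
+eps = 1e-5
+xp, xt, yp, yt = 2.0, 1.0, 3.0, 0.5
+x_enc = xp + xt * eps  # encode the primal-tangent pair (xp, xt) as one float
+y_enc = yp + yt * eps
+# The sum matches the predicted polynomial (xp + yp) + (xt + yt) * eps,
+# up to float roundoff:
+print((x_enc + y_enc) - ((xp + yp) + (xt + yt) * eps))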
+
+# Multiplication is more interesting:
+#
+#     x * y = (xp + xt * eps) * (yp + yt * eps)
+#           = (xp * yp)                   # primal component
+#           + (xp * yt + xt * yp) * eps   # tangent component
+#           + (xt * yt) * eps * eps       # quadratic component, vanishes in the eps->0 limit
+#
+# Now we have a second-order polynomial. But as epsilon goes to zero the
+# quadratic term vanishes and our primal-tangent pair
+# is just `(xp * yp, xp * yt + xt * yp)`.
+# (In our earlier example with finite `eps`, this term not vanishing is
+# why we had the ~1e-10 "truncation error".)
+
+# Putting this into code, we can write down the forward-AD rules for addition
+# and multiplication and express `foo` in terms of these:
+
+# +
+from dataclasses import dataclass
+
+# A primal-tangent pair is conventionally called a "dual number"
+@dataclass
+class DualNumber:
+    primal : float
+    tangent : float
+
+def add_dual(x : DualNumber, y : DualNumber) -> DualNumber:
+    return DualNumber(x.primal + y.primal, x.tangent + y.tangent)
+
+def mul_dual(x : DualNumber, y : DualNumber) -> DualNumber:
+    return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)
+
+def foo_dual(x : DualNumber) -> DualNumber:
+    return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))
+
+print(foo_dual(DualNumber(2.0, 1.0)))
+
+
+# -
+
+# That works! But rewriting `foo` to use the `_dual` versions of addition and
+# multiplication was a bit tedious. Let's get back to the main program and use
+# our interpretation machinery to do the rewrite automatically.
+
+# ## JVP Interpreter
+
+# We'll set up a new interpreter called `JVPInterpreter` ("JVP" for
+# "Jacobian-vector product") which propagates these dual numbers instead of
+# ordinary values. Its `interpret_op` method has rules for `add` and `mul` that
+# operate on dual numbers. The rules cast constant arguments to dual numbers as
+# needed by calling `JVPInterpreter.lift`. In our manually rewritten version
+# above we did that by replacing the literal `3.0` with `DualNumber(3.0, 0.0)`.
+
+# +
+# This is like DualNumber above except that it also has a pointer to the
+# interpreter it belongs to, which is needed to avoid "perturbation confusion"
+# in higher-order differentiation.
+@dataclass
+class TaggedDualNumber:
+    interpreter : Interpreter
+    primal : float
+    tangent : float
+
+class JVPInterpreter(Interpreter):
+    def __init__(self, prev_interpreter: Interpreter):
+        # We keep a pointer to the interpreter that was current when this
+        # interpreter was first invoked. That's the context in which our
+        # rules should run.
+        self.prev_interpreter = prev_interpreter
+
+    def interpret_op(self, op, args):
+        args = tuple(self.lift(arg) for arg in args)
+        with set_interpreter(self.prev_interpreter):
+            match op:
+                case Op.add:
+                    # Notice that we use `add` and `mul` here, which are the
+                    # interpreter-dispatching functions defined earlier.
+                    x, y = args
+                    return self.dual_number(
+                        add(x.primal, y.primal),
+                        add(x.tangent, y.tangent))
+
+                case Op.mul:
+                    x, y = args
+                    return self.dual_number(
+                        mul(x.primal, y.primal),
+                        add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))
+
+    def dual_number(self, primal, tangent):
+        return TaggedDualNumber(self, primal, tangent)
+
+    # Lift a constant value (constant with respect to this interpreter) to
+    # a TaggedDualNumber.
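+    # A TaggedDualNumber made by a *different* JVPInterpreter counts as a
+    # constant here too; see the discussion of "perturbation confusion" below.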
+    def lift(self, x):
+        if isinstance(x, TaggedDualNumber) and x.interpreter is self:
+            return x
+        else:
+            return self.dual_number(x, 0.0)
+
+def jvp(f, primal, tangent):
+    jvp_interpreter = JVPInterpreter(current_interpreter)
+    dual_number_in = jvp_interpreter.dual_number(primal, tangent)
+    with set_interpreter(jvp_interpreter):
+        result = f(dual_number_in)
+    dual_number_out = jvp_interpreter.lift(result)
+    return dual_number_out.primal, dual_number_out.tangent
+
+# Let's try it out:
+print(jvp(foo, 2.0, 1.0))
+
+# Because we were careful to consider nesting interpreters, higher-order AD
+# works out of the box:
+
+def derivative(f, x):
+    _, tangent = jvp(f, x, 1.0)
+    return tangent
+
+def nth_order_derivative(n, f, x):
+    if n == 0:
+        return f(x)
+    else:
+        return derivative(lambda x: nth_order_derivative(n - 1, f, x), x)
+
+
+# -
+
+print(nth_order_derivative(0, foo, 2.0))
+
+print(nth_order_derivative(1, foo, 2.0))
+
+print(nth_order_derivative(2, foo, 2.0))
+
+# The rest are zero because `foo` is only a second-order polynomial
+print(nth_order_derivative(3, foo, 2.0))
+
+print(nth_order_derivative(4, foo, 2.0))
+
+
+# There are some subtleties worth discussing. First, how do you tell if
+# something is constant with respect to differentiation? It's tempting to say
+# "it's a constant if and only if it's not a dual number". But actually dual
+# numbers created by a *different* JVPInterpreter also need to be considered
+# constants with respect to the JVPInterpreter we're currently handling. That's
+# why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This
+# comes up in higher-order differentiation when there are multiple
+# JVPInterpreters in scope. The sort of bug where you accidentally interpret a
+# dual number from a different interpreter as non-constant is sometimes called
+# "perturbation confusion" in the literature. Here's an example program that
+# would have given the wrong answer if we hadn't had the
+# `and x.interpreter is self` check in `JVPInterpreter.lift`.
+
+# +
+def f(x):
+    # g is constant in its (ignored) argument `y`. Its derivative should be zero,
+    # but our AD will mess it up if we don't distinguish perturbations from
+    # different interpreters.
+    def g(y):
+        return x
+    should_be_zero = derivative(g, 0.0)
+    return mul(x, should_be_zero)
+
+print(derivative(f, 0.0))
+# -
+
+# Another subtlety: the `add` and `mul` rules in `JVPInterpreter.interpret_op`
+# describe addition and multiplication on dual numbers in terms of addition and
+# multiplication on the primal and tangent components. But we don't use ordinary
+# `+` and `*` for this. Instead we use our own `add` and `mul` functions, which
+# dispatch to the current interpreter. Before calling them we set the current
+# interpreter to be the *previous* interpreter, i.e. the interpreter that was
+# current when `JVPInterpreter` was first invoked. If we didn't do this we'd
+# have an infinite recursion, with `add` and `mul` dispatching to
+# `JVPInterpreter` endlessly. The advantage of using our own `add` and `mul`
+# instead of ordinary `+` and `*` is that we can nest these interpreters and do
+# higher-order AD.
+
+# At this point you might be wondering: have we just reinvented operator
+# overloading? Python overloads the infix ops `+` and `*` to dispatch to the
+# argument's `__add__` and `__mul__`. Could we have just used that mechanism
+# instead of this whole interpreter business? Yes, actually. Indeed, the early
+# automatic differentiation (AD) literature uses the term "operator overloading"
+# to describe this style of AD implementation. One detail is that we can't rely
+# exclusively on Python's built-in overloading because that only lets us
+# overload a handful of built-in infix ops, whereas we eventually want to
+# overload NumPy-level operations like `sin` and `cos`. So we need our own
+# mechanism.
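+
+# To make the comparison concrete, here's a minimal sketch of the data-based
+# overloading style. (`OverloadedDual` is a hypothetical name; this class is
+# just for contrast and isn't part of our implementation.)
+
+# +
+@dataclass
+class OverloadedDual:
+    primal : float
+    tangent : float
+
+    def _lift(self, other):
+        # Treat plain floats as constants with a zero tangent.
+        return other if isinstance(other, OverloadedDual) else OverloadedDual(other, 0.0)
+
+    def __add__(self, other):
+        other = self._lift(other)
+        return OverloadedDual(self.primal + other.primal,
+                              self.tangent + other.tangent)
+
+    def __mul__(self, other):
+        other = self._lift(other)
+        return OverloadedDual(self.primal * other.primal,
+                              self.primal * other.tangent + self.tangent * other.primal)
+
+# Dispatch rides along with the *data*: any function built from `+` and `*` on
+# these objects gets differentiated, with no ambient interpreter involved.
+print((lambda x: x * (x + 3.0))(OverloadedDual(2.0, 1.0)))
+# -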
+
+# But there's a more important difference: our dispatch is based on *context*,
+# whereas traditional Python-style overloading is based on *data*. This is
+# actually a recent development for JAX. The earliest versions of JAX looked
+# more like traditional data-based overloading. An interpreter (a "trace" in JAX
+# jargon) for an operation would be chosen based on data attached to the
+# arguments to that operation. We've gradually made the interpreter-dispatch
+# decision rely more and more on context rather than data (omnistaging [link],
+# stackless [link]). The reason to prefer context-based interpretation over
+# data-based interpretation is that it makes the implementation much simpler.
+
+# All that said, we do *also* want to take advantage of Python's built-in
+# overloading mechanism. That way we get the syntactic convenience of using
+# infix operators `+` and `*` instead of writing out `add(..)` and `mul(..)`.
+# But we'll put that aside for now.
+
+# ## Staging to an untyped IR
+
+# The two program transformations we've seen so far -- evaluation and JVP --
+# both traverse the input program from top to bottom. They visit the operations
+# one by one in the same order as ordinary evaluation. A convenient thing about
+# top-to-bottom transformations is that they can be implemented eagerly, or
+# "online", meaning that we can evaluate the program from top to bottom and
+# perform the necessary transformations as we go. We never look at the entire
+# program at once.
+
+# But not all transformations work this way. For example, dead-code elimination
+# requires traversing from bottom to top, collecting usage statistics on the way
+# up and eliminating pure operations whose results have no uses. Another
+# bottom-to-top transformation is AD transposition, which we use to implement
+# reverse-mode AD. For these we need to first "stage" the program into an IR
+# (intermediate representation), a data structure representing the program,
+# which we can then traverse in any order we like. Building this IR from a
+# Python program will be the goal of our third and final interpreter.
+
+# First, let's define the IR. We'll do an untyped ANF (A-normal form) IR to
+# start. A function (we call IR functions "jaxprs" in JAX) will have a list of
+# formal parameters, a list of operations, and a return value. Each argument to
+# an operation must be an "atom", which is either a variable or a literal. The
+# return value of the function is also an atom.
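+
+# As a preview of where we're headed, staging `foo` should produce a jaxpr
+# along these lines (the `v_*` names are generated by the interpreter we're
+# about to write, so the exact names are an implementation detail):
+#
+#     v_1 ->
+#       v_2 = Op.add(v_1, 3.0)
+#       v_3 = Op.mul(v_1, v_2)
+#     v_3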
+
+# +
+Var = str  # Variables are just strings in this untyped IR
+Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals
+
+# Equation - a single line in our IR like `z = mul(x, y)`
+@dataclass
+class Equation:
+    var : Var                # The variable name of the result
+    op : Op                  # The primitive operation we're applying
+    args : tuple[Atom, ...]  # The arguments we're applying the primitive operation to
+
+# We call an IR function a "Jaxpr", for "JAX expression"
+@dataclass
+class Jaxpr:
+    parameters : list[Var]      # The function's formal parameters (arguments)
+    equations : list[Equation]  # The body of the function, a list of instructions/equations
+    return_val : Atom           # The function's return value
+
+    def __str__(self):
+        lines = []
+        lines.append(', '.join(self.parameters) + ' ->')
+        for eqn in self.equations:
+            args_str = ', '.join(str(arg) for arg in eqn.args)
+            lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
+        lines.append(str(self.return_val))  # str() because the atom may be a float literal
+        return '\n'.join(lines)
+
+
+# -
+
+# To build the IR from a Python function we define a `StagingInterpreter` that
+# takes each operation and adds it to a growing list of all the operations we've
+# seen so far:
+
+# +
+class StagingInterpreter(Interpreter):
+    def __init__(self):
+        self.equations = []    # A mutable list of all the ops we've seen so far
+        self.name_counter = 0  # Counter for generating unique names
+
+    def fresh_var(self):
+        self.name_counter += 1
+        return "v_" + str(self.name_counter)
+
+    def interpret_op(self, op, args):
+        binder = self.fresh_var()
+        self.equations.append(Equation(binder, op, args))
+        return binder
+
+def build_jaxpr(f, num_args):
+    interpreter = StagingInterpreter()
+    parameters = [interpreter.fresh_var() for _ in range(num_args)]
+    with set_interpreter(interpreter):
+        result = f(*parameters)
+    return Jaxpr(parameters, interpreter.equations, result)
+
+
+# -
+
+# Now we can construct an IR for a Python program and print it out:
+
+print(build_jaxpr(foo, 1))
+
+
+# We can also evaluate our IR by writing an explicit interpreter that traverses
+# the operations one by one:
+
+# +
+def eval_jaxpr(jaxpr, args):
+    # An environment mapping variables to values
+    env = dict(zip(jaxpr.parameters, args))
+    def eval_atom(x): return env[x] if isinstance(x, Var) else x
+    for eqn in jaxpr.equations:
+        op_args = tuple(eval_atom(x) for x in eqn.args)
+        env[eqn.var] = current_interpreter.interpret_op(eqn.op, op_args)
+    return eval_atom(jaxpr.return_val)
+
+print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
+# -
+
+# We've written this interpreter in terms of `current_interpreter.interpret_op`,
+# which means we've done a full round-trip: interpretable Python program to IR
+# to interpretable Python program. Since the result is "interpretable" we can
+# differentiate it again, or stage it out, or anything we like:
+
+print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
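+
+# For instance, staging the round-tripped program (a quick check, not new
+# machinery -- the lambda just wraps it as a function of one argument) gives
+# back an equivalent jaxpr:
+
+print(build_jaxpr(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 1))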
+
+# ## Up next...
+
+# That's it for part one of this tutorial. We've done two primitives, three
+# interpreters and the tracing mechanism that weaves them together. In the next
+# part we'll add types other than floats, error handling, compilation,
+# reverse-mode AD and higher-order primitives. Note that the second part is
+# structured differently. Rather than trying to have a top-to-bottom order that
+# obeys both code dependencies (e.g. data structures need to be defined before
+# they're used) and pedagogical dependencies (concepts need to be introduced
+# before they're implemented), we're going with a single file that can be
+# approached in any order.
diff --git a/docs/conf.py b/docs/conf.py
index 295347c75..13281dff8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -138,6 +138,7 @@ exclude_patterns = [
     'pallas/tpu/matmul.md',
     'jep/9407-type-promotion.md',
     'autodidax.md',
+    'autodidax2_part1.md',
     'sharded-computation.md',
     'ffi.ipynb',
 ]
diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst
index f89122f94..81e1f5c99 100644
--- a/docs/contributor_guide.rst
+++ b/docs/contributor_guide.rst
@@ -24,4 +24,5 @@ some of JAX's (extensible) internals.
    :caption: Design and internals
 
    autodidax
+   autodidax2_part1
    jep/index