mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
1083 lines
34 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "raw",
|
|
"id": "e515a630",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"Copyright 2025 The JAX Authors.\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
|
"you may not use this file except in compliance with the License.\n",
|
|
"You may obtain a copy of the License at\n",
|
|
"\n",
|
|
" https://www.apache.org/licenses/LICENSE-2.0\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software\n",
|
|
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
|
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
|
"See the License for the specific language governing permissions and\n",
|
|
"limitations under the License.\n",
|
|
"\n",
|
|
"---"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "25ac15cc",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Autodidax2, part 1: JAX from scratch, again"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6999020a",
|
|
"metadata": {},
|
|
"source": [
|
"If you want to understand how JAX works you could try reading the code. But\n",
|
|
"the code is complicated, often for no good reason. This notebook presents a\n",
|
|
"stripped-back version without the cruft. It's a minimal version of JAX from\n",
|
|
"first principles. Enjoy!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "62dde49f",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Main idea: context-sensitive interpretation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d13d5272",
|
|
"metadata": {},
|
|
"source": [
|
|
"JAX is two things:\n",
|
|
" 1. a set of primitive operations (roughly the NumPy API)\n",
|
|
" 2. a set of interpreters over those primitives (compilation, AD, etc.)\n",
|
|
"\n",
|
|
"In this minimal version of JAX we'll start with just two primitive operations,\n",
|
|
"addition and multiplication, and we'll add interpreters one by one. Suppose we\n",
|
|
"have a user-defined function like this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "9f179429",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.809256Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.808149Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.827374Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.826143Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def foo(x):\n",
"  return mul(x, add(x, 3.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "809d67a6",
|
|
"metadata": {},
|
|
"source": [
|
|
"We want to be able to interpret `foo` in different ways without changing its\n",
|
|
"implementation: we want to evaluate it on concrete values, differentiate it,\n",
|
|
"stage it out to an IR, compile it and so on."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a4235f52",
|
|
"metadata": {},
|
|
"source": [
|
|
"Here's how we'll do it. For each of these interpretations we'll define an\n",
|
|
"`Interpreter` object with a rule for handling each primitive operation. We'll\n",
|
|
"keep track of the *current* interpreter using a global context variable. The\n",
|
|
"user-facing functions `add` and `mul` will dispatch to the current\n",
|
|
"interpreter. At the beginning of the program the current interpreter will be\n",
|
|
"the \"evaluating\" interpreter which just evaluates the operations on ordinary\n",
|
|
"concrete data. Here's what this all looks like so far."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "65b26bdc",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.830948Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.830603Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.842672Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.841832Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"from enum import Enum, auto\n",
"from contextlib import contextmanager\n",
"from typing import Any\n",
"\n",
"# The full (closed) set of primitive operations\n",
"class Op(Enum):\n",
"  add = auto()  # addition on floats\n",
"  mul = auto()  # multiplication on floats\n",
"\n",
"# Interpreters have rules for handling each primitive operation.\n",
"class Interpreter:\n",
"  def interpret_op(self, op: Op, args: tuple[Any, ...]):\n",
"    assert False, \"subclass should implement this\"\n",
"\n",
"# Our first interpreter is the \"evaluating interpreter\" which performs ordinary\n",
"# concrete evaluation.\n",
"class EvalInterpreter(Interpreter):\n",
"  def interpret_op(self, op, args):\n",
"    assert all(isinstance(arg, float) for arg in args)\n",
"    match op:\n",
"      case Op.add:\n",
"        x, y = args\n",
"        return x + y\n",
"      case Op.mul:\n",
"        x, y = args\n",
"        return x * y\n",
"      case _:\n",
"        raise ValueError(f\"Unrecognized primitive op: {op}\")\n",
"\n",
"# The current interpreter is initially the evaluating interpreter.\n",
"current_interpreter = EvalInterpreter()\n",
"\n",
"# A context manager for temporarily changing the current interpreter\n",
"@contextmanager\n",
"def set_interpreter(new_interpreter):\n",
"  global current_interpreter\n",
"  prev_interpreter = current_interpreter\n",
"  try:\n",
"    current_interpreter = new_interpreter\n",
"    yield\n",
"  finally:\n",
"    current_interpreter = prev_interpreter\n",
"\n",
"# The user-facing functions `mul` and `add` dispatch to the current interpreter.\n",
"def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))\n",
"def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e7d4ff3a",
|
|
"metadata": {},
|
|
"source": [
|
|
"At this point we can call `foo` with ordinary concrete inputs and see the\n",
|
|
"results:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "5aa8511c",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.846387Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.846085Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.850202Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.849420Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(foo(2.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "175587ca",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Aside: forward-mode automatic differentiation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e003ba3f",
|
|
"metadata": {},
|
|
"source": [
|
|
"For our second interpreter we're going to try forward-mode automatic\n",
|
|
"differentiation (AD). Here's a quick introduction to forward-mode AD in case\n",
|
|
"this is the first time you've come across it. Otherwise skip ahead to the\n",
|
"\"JVP Interpreter\" section."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3bb1dde9",
|
|
"metadata": {},
|
|
"source": [
|
|
"Suppose we're interested in the derivative of `foo(x)` evaluated at `x=2.0`.\n",
|
|
"We could approximate it with finite differences:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "9151fcd4",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.852192Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.852015Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.855275Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.854676Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"7.000009999913458\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print((foo(2.00001) - foo(2.0)) / 0.00001)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9c3ce8ae",
|
|
"metadata": {},
|
|
"source": [
|
|
"The answer is close to 7.0 as expected. But computing it this way required two\n",
|
|
"evaluations of the function (not to mention the roundoff error and truncation\n",
|
|
"error). Here's a funny thing though. We can almost get the answer with a\n",
|
|
"single evaluation:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "cba962a2",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.857141Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.856974Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.859864Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.859432Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10.0000700001\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(foo(2.00001))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a256c797",
|
|
"metadata": {},
|
|
"source": [
|
|
"The answer we're looking for, 7.0, is right there in the insignificant digits!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "171a1aab",
|
|
"metadata": {},
|
|
"source": [
|
|
"Here's one way to think about what's happening. The initial argument to `foo`,\n",
|
|
"`2.00001`, carries two pieces of data: a \"primal\" value, 2.0, and a \"tangent\"\n",
|
|
"value, `1.0`. The representation of this primal-tangent pair, `2.00001`, is\n",
|
|
"the sum of the two, with the tangent scaled by a small fixed epsilon, `1e-5`.\n",
|
|
"Ordinary evaluation of `foo(2.00001)` propagates this primal-tangent pair,\n",
|
|
"producing `10.0000700001` as the result. The primal and tangent components are\n",
|
|
"well separated in scale so we can visually interpret the result as the\n",
|
"primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at\n",
|
|
"the end."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "47420177",
|
|
"metadata": {},
|
|
"source": [
|
|
"The idea with forward-mode differentiation is to do the same thing but exactly\n",
|
|
"and explicitly (eyeballing floats doesn't really scale). We'll represent the\n",
|
|
"primal-tangent pair as an actual pair instead of folding them both into a\n",
|
|
"single floating point number. For each primitive operation we'll have a rule\n",
|
"that describes how to propagate these primal-tangent pairs. Let's work out the\n",
|
|
"rules for our two primitives."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "309dc70d",
|
|
"metadata": {},
|
|
"source": [
|
"Addition is easy. Consider `x + y` where `x = xp + xt * eps` and `y = yp + yt * eps`\n",
"(\"p\" for \"primal\", \"t\" for \"tangent\"):\n",
"\n",
"    x + y = (xp + xt * eps) + (yp + yt * eps)\n",
"          = (xp + yp)              # primal component\n",
"          + (xt + yt) * eps        # tangent component\n",
"\n",
"The result is a first-order polynomial in `eps` and we can read off the\n",
"primal-tangent pair as `(xp + yp, xt + yt)`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "59302b21",
|
|
"metadata": {},
|
|
"source": [
|
"Multiplication is more interesting:\n",
"\n",
"    x * y = (xp + xt * eps) * (yp + yt * eps)\n",
"          = (xp * yp)                        # primal component\n",
"          + (xp * yt + xt * yp) * eps        # tangent component\n",
"          + (xt * yt) * eps * eps            # quadratic component, vanishes in the eps->0 limit\n",
"\n",
"Now we have a second-order polynomial. But as epsilon goes to zero the\n",
"quadratic term vanishes and our primal-tangent pair\n",
"is just `(xp * yp, xp * yt + xt * yp)`.\n",
"(In our earlier example with finite `eps`, this term not vanishing is\n",
"why we had the ~1e-10 \"truncation error\".)"
|
|
]
|
|
},
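{
"cell_type": "markdown",
"id": "7b1e4a90",
"metadata": {},
"source": [
"As a quick numerical check of the multiplication rule above, the sketch below\n",
"expands the product with a finite `eps` and compares it against the three terms\n",
"of the polynomial. It is not part of the original notebook, and the concrete\n",
"values are illustrative assumptions chosen to match `foo(2.0)`, where the second\n",
"factor is `x + 3.0`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b1e4a91",
"metadata": {},
"outputs": [],
"source": [
"# A sketch, not part of the original notebook: the values below are\n",
"# illustrative assumptions matching foo(2.0), i.e. x * (x + 3.0) at x = 2.0.\n",
"eps = 1e-5\n",
"xp, xt = 2.0, 1.0  # x: primal 2.0, tangent 1.0\n",
"yp, yt = 5.0, 1.0  # y = x + 3.0: primal 5.0, tangent 1.0\n",
"\n",
"# Direct evaluation with the tangent folded into a single float.\n",
"direct = (xp + xt * eps) * (yp + yt * eps)\n",
"\n",
"# The three terms of the second-order polynomial in eps.\n",
"primal = xp * yp\n",
"tangent = xp * yt + xt * yp\n",
"quadratic = xt * yt * eps * eps\n",
"\n",
"print(primal, tangent)\n",
"# The polynomial reproduces the direct product up to floating-point roundoff.\n",
"print(abs(direct - (primal + tangent * eps + quadratic)) < 1e-12)"
]
},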
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "37fa5063",
|
|
"metadata": {},
|
|
"source": [
|
|
"Putting this into code, we can write down the forward-AD rules for addition\n",
|
|
"and multiplication and express `foo` in terms of these:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "57222038",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.862460Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.862160Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.868704Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.867858Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"DualNumber(primal=10.0, tangent=7.0)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"from dataclasses import dataclass\n",
"\n",
"# A primal-tangent pair is conventionally called a \"dual number\"\n",
"@dataclass\n",
"class DualNumber:\n",
"  primal: float\n",
"  tangent: float\n",
"\n",
"def add_dual(x: DualNumber, y: DualNumber) -> DualNumber:\n",
"  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)\n",
"\n",
"def mul_dual(x: DualNumber, y: DualNumber) -> DualNumber:\n",
"  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)\n",
"\n",
"def foo_dual(x: DualNumber) -> DualNumber:\n",
"  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))\n",
"\n",
"print(foo_dual(DualNumber(2.0, 1.0)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "54947cc7",
|
|
"metadata": {},
|
|
"source": [
|
|
"That works! But rewriting `foo` to use the `_dual` versions of addition and\n",
|
|
"multiplication was a bit tedious. Let's get back to the main program and use\n",
|
|
"our interpretation machinery to do the rewrite automatically."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "25edb5f4",
|
|
"metadata": {},
|
|
"source": [
|
|
"## JVP Interpreter"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "945933c6",
|
|
"metadata": {},
|
|
"source": [
|
|
"We'll set up a new interpreter called `JVPInterpreter` (\"JVP\" for\n",
|
|
"\"Jacobian-vector product\") which propagates these dual numbers instead of\n",
|
"ordinary values. The `JVPInterpreter` has rules for `add` and `mul` that operate\n",
"on dual numbers. They cast constant arguments to dual numbers as needed by\n",
|
|
"calling `JVPInterpreter.lift`. In our manually rewritten version above we did\n",
|
|
"that by replacing the literal `3.0` with `DualNumber(3.0, 0.0)`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "d17a4fb0",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.871469Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.871220Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.880456Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.879794Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(10.0, 7.0)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# This is like DualNumber above except that it also has a pointer to the\n",
"# interpreter it belongs to, which is needed to avoid \"perturbation confusion\"\n",
"# in higher order differentiation.\n",
"@dataclass\n",
"class TaggedDualNumber:\n",
"  interpreter: Interpreter\n",
"  primal: float\n",
"  tangent: float\n",
"\n",
"class JVPInterpreter(Interpreter):\n",
"  def __init__(self, prev_interpreter: Interpreter):\n",
"    # We keep a pointer to the interpreter that was current when this\n",
"    # interpreter was first invoked. That's the context in which our\n",
"    # rules should run.\n",
"    self.prev_interpreter = prev_interpreter\n",
"\n",
"  def interpret_op(self, op, args):\n",
"    # Every argument is lifted to a dual number of this interpreter up front.\n",
"    args = tuple(self.lift(arg) for arg in args)\n",
"    with set_interpreter(self.prev_interpreter):\n",
"      match op:\n",
"        case Op.add:\n",
"          # Notice that we use `add` and `mul` here, which are the\n",
"          # interpreter-dispatching functions defined earlier.\n",
"          x, y = args\n",
"          return self.dual_number(\n",
"              add(x.primal, y.primal),\n",
"              add(x.tangent, y.tangent))\n",
"\n",
"        case Op.mul:\n",
"          x, y = args\n",
"          return self.dual_number(\n",
"              mul(x.primal, y.primal),\n",
"              add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))\n",
"\n",
"        case _:\n",
"          raise ValueError(f\"Unrecognized primitive op: {op}\")\n",
"\n",
"  def dual_number(self, primal, tangent):\n",
"    return TaggedDualNumber(self, primal, tangent)\n",
"\n",
"  # Lift a constant value (constant with respect to this interpreter) to\n",
"  # a TaggedDualNumber.\n",
"  def lift(self, x):\n",
"    if isinstance(x, TaggedDualNumber) and x.interpreter is self:\n",
"      return x\n",
"    else:\n",
"      return self.dual_number(x, 0.0)\n",
"\n",
"def jvp(f, primal, tangent):\n",
"  jvp_interpreter = JVPInterpreter(current_interpreter)\n",
"  dual_number_in = jvp_interpreter.dual_number(primal, tangent)\n",
"  with set_interpreter(jvp_interpreter):\n",
"    result = f(dual_number_in)\n",
"  dual_number_out = jvp_interpreter.lift(result)\n",
"  return dual_number_out.primal, dual_number_out.tangent\n",
"\n",
"# Let's try it out:\n",
"print(jvp(foo, 2.0, 1.0))\n",
"\n",
"# Because we were careful to consider nesting interpreters, higher-order AD\n",
"# works out of the box:\n",
"\n",
"def derivative(f, x):\n",
"  _, tangent = jvp(f, x, 1.0)\n",
"  return tangent\n",
"\n",
"def nth_order_derivative(n, f, x):\n",
"  if n == 0:\n",
"    return f(x)\n",
"  else:\n",
"    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "3acc3839",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.882187Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.882009Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.885190Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.884635Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(nth_order_derivative(0, foo, 2.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "187eb028",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.886848Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.886685Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.889507Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.889081Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"7.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(nth_order_derivative(1, foo, 2.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "9f0dde6d",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.891061Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.890896Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.894142Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.893701Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(nth_order_derivative(2, foo, 2.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "4d086fb3",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.895591Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.895398Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.898277Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.897870Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# The rest are zero because `foo` is only a second-order polynomial\n",
|
|
"print(nth_order_derivative(3, foo, 2.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "e3164405",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.899736Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.899545Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.902719Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.902303Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(nth_order_derivative(4, foo, 2.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2e51ca61",
|
|
"metadata": {},
|
|
"source": [
|
|
"There are some subtleties worth discussing. First, how do you tell if\n",
|
|
"something is constant with respect to differentiation? It's tempting to say\n",
|
|
"\"it's a constant if and only if it's not a dual number\". But actually dual\n",
|
|
"numbers created by a *different* JVPInterpreter also need to be considered\n",
|
"constants with respect to the JVPInterpreter we're currently handling. That's\n",
"why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This\n",
"comes up in higher order differentiation when there are multiple JVPInterpreters\n",
|
|
"in scope. The sort of bug where you accidentally interpret a dual number from\n",
|
|
"a different interpreter as non-constant is sometimes called \"perturbation\n",
|
|
"confusion\" in the literature. Here's an example program that would have given\n",
|
|
"the wrong answer if we hadn't had the `and x.interpreter is self` check in\n",
|
|
"`JVPInterpreter.lift`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "ae1449a0",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.904294Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.904105Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.907284Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.906874Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"def f(x):\n",
"  # g is constant in its (ignored) argument `y`. Its derivative should be zero\n",
"  # but our AD will mess it up if we don't distinguish perturbations from\n",
"  # different interpreters.\n",
"  def g(y):\n",
"    return x\n",
"  should_be_zero = derivative(g, 0.0)\n",
"  return mul(x, should_be_zero)\n",
"\n",
"print(derivative(f, 0.0))"
|
|
]
|
|
},
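{
"cell_type": "markdown",
"id": "7b1e4a92",
"metadata": {},
"source": [
"To see the failure concretely, here is a self-contained sketch of the same\n",
"machinery with the tag check made optional. It is an illustration rather than\n",
"the notebook's implementation: every `toy_`/`Toy` name and the `check_tag` flag\n",
"are hypothetical, and ops are plain strings to keep it short. With the check\n",
"enabled the derivative of `toy_f` should come out as zero; with it disabled the\n",
"inner interpreter mistakes the outer dual number for its own and a nonzero\n",
"tangent leaks through."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b1e4a93",
"metadata": {},
"outputs": [],
"source": [
"# A sketch, not part of the original notebook. All `toy_`/`Toy` names are\n",
"# hypothetical; `check_tag=False` deliberately reintroduces the bug.\n",
"from contextlib import contextmanager\n",
"from dataclasses import dataclass\n",
"from typing import Any\n",
"\n",
"class ToyEval:\n",
"  def interpret_op(self, op, args):\n",
"    x, y = args\n",
"    return x + y if op == 'add' else x * y\n",
"\n",
"toy_current = ToyEval()\n",
"\n",
"@contextmanager\n",
"def toy_set_interpreter(new):\n",
"  global toy_current\n",
"  prev = toy_current\n",
"  try:\n",
"    toy_current = new\n",
"    yield\n",
"  finally:\n",
"    toy_current = prev\n",
"\n",
"def toy_add(x, y): return toy_current.interpret_op('add', (x, y))\n",
"def toy_mul(x, y): return toy_current.interpret_op('mul', (x, y))\n",
"\n",
"@dataclass\n",
"class ToyDual:\n",
"  interpreter: Any\n",
"  primal: Any\n",
"  tangent: Any\n",
"\n",
"class ToyJVP:\n",
"  def __init__(self, prev, check_tag):\n",
"    self.prev = prev\n",
"    self.check_tag = check_tag\n",
"\n",
"  def lift(self, x):\n",
"    # With check_tag=False, any dual number is treated as our own.\n",
"    if isinstance(x, ToyDual) and (x.interpreter is self or not self.check_tag):\n",
"      return x\n",
"    return ToyDual(self, x, 0.0)\n",
"\n",
"  def interpret_op(self, op, args):\n",
"    x, y = (self.lift(arg) for arg in args)\n",
"    with toy_set_interpreter(self.prev):\n",
"      if op == 'add':\n",
"        return ToyDual(self, toy_add(x.primal, y.primal), toy_add(x.tangent, y.tangent))\n",
"      return ToyDual(self, toy_mul(x.primal, y.primal),\n",
"                     toy_add(toy_mul(x.primal, y.tangent), toy_mul(x.tangent, y.primal)))\n",
"\n",
"def toy_derivative(f, x, check_tag):\n",
"  interpreter = ToyJVP(toy_current, check_tag)\n",
"  with toy_set_interpreter(interpreter):\n",
"    out = f(ToyDual(interpreter, x, 1.0))\n",
"  return interpreter.lift(out).tangent\n",
"\n",
"def toy_f(x, check_tag):\n",
"  def g(y):\n",
"    return x  # constant in y, so its derivative should be zero\n",
"  should_be_zero = toy_derivative(g, 0.0, check_tag)\n",
"  return toy_mul(x, should_be_zero)\n",
"\n",
"print(toy_derivative(lambda x: toy_f(x, True), 0.0, True))    # tags checked\n",
"print(toy_derivative(lambda x: toy_f(x, False), 0.0, False))  # tags ignored"
]
},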
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "884d6f62",
|
|
"metadata": {},
|
|
"source": [
|
"Another subtlety: the `add` and `mul` rules in `JVPInterpreter.interpret_op` describe\n",
|
|
"addition and multiplication on dual numbers in terms of addition and\n",
|
|
"multiplication on the primal and tangent components. But we don't use ordinary\n",
|
|
"`+` and `*` for this. Instead we use our own `add` and `mul` functions which\n",
|
|
"dispatch to the current interpreter. Before calling them we set the current\n",
|
|
"interpreter to be the *previous* interpreter, i.e. the interpreter that was\n",
|
|
"current when `JVPInterpreter` was first invoked. If we didn't do this we'd\n",
|
|
"have an infinite recursion, with `add` and `mul` dispatching to\n",
|
"`JVPInterpreter` endlessly. The advantage of using our own `add` and `mul` instead\n",
|
|
"of ordinary `+` and `*` is that it means we can nest these interpreters and do\n",
|
|
"higher-order AD."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e03446b3",
|
|
"metadata": {},
|
|
"source": [
|
|
"At this point you might be wondering: have we just reinvented operator\n",
|
|
"overloading? Python overloads the infix ops `+` and `*` to dispatch to the\n",
|
|
"argument's `__add__` and `__mul__`. Could we have just used that mechanism\n",
|
|
"instead of this whole interpreter business? Yes, actually. Indeed, the earlier\n",
|
|
"automatic differentiation (AD) literature uses the term \"operator overloading\"\n",
|
|
"to describe this style of AD implementation. One detail is that we can't rely\n",
|
|
"exclusively on Python built-in overloading because that only lets us overload\n",
|
|
"a handful of built-in infix ops whereas we eventually want to overload\n",
|
|
"numpy-level operations like `sin` and `cos`. So we need our own mechanism."
|
|
]
|
|
},
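{
"cell_type": "markdown",
"id": "7b1e4a94",
"metadata": {},
"source": [
"For comparison, here is a minimal sketch of the data-based, operator-overloading\n",
"style for our two primitives. It is an illustration under assumptions, not how\n",
"JAX dispatches: `OverloadedDual` and `foo_infix` are hypothetical names, and\n",
"dispatch happens through the `__add__` and `__mul__` methods of the argument\n",
"rather than through a current interpreter."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b1e4a95",
"metadata": {},
"outputs": [],
"source": [
"# A sketch, not part of the original notebook: data-based dispatch through\n",
"# Python's operator overloading. `OverloadedDual` and `foo_infix` are\n",
"# hypothetical names.\n",
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class OverloadedDual:\n",
"  primal: float\n",
"  tangent: float\n",
"\n",
"  @staticmethod\n",
"  def lift(x):\n",
"    # Plain numbers are constants, so they get a zero tangent.\n",
"    return x if isinstance(x, OverloadedDual) else OverloadedDual(x, 0.0)\n",
"\n",
"  def __add__(self, other):\n",
"    other = OverloadedDual.lift(other)\n",
"    return OverloadedDual(self.primal + other.primal, self.tangent + other.tangent)\n",
"\n",
"  def __mul__(self, other):\n",
"    other = OverloadedDual.lift(other)\n",
"    return OverloadedDual(self.primal * other.primal,\n",
"                          self.primal * other.tangent + self.tangent * other.primal)\n",
"\n",
"  # Addition and multiplication commute, so the reflected versions can reuse them.\n",
"  __radd__ = __add__\n",
"  __rmul__ = __mul__\n",
"\n",
"def foo_infix(x):\n",
"  return x * (x + 3.0)\n",
"\n",
"print(foo_infix(OverloadedDual(2.0, 1.0)))"
]
},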
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2e2035ea",
|
|
"metadata": {},
|
|
"source": [
|
|
"But there's a more important difference: our dispatch is based on *context*\n",
|
|
"whereas traditional Python-style overloading is based on *data*. This is\n",
|
|
"actually a recent development for JAX. The earliest versions of JAX looked\n",
|
|
"more like traditional data-based overloading. An interpreter (a \"trace\" in JAX\n",
|
|
"jargon) for an operation would be chosen based on data attached to the\n",
|
|
"arguments to that operation. We've gradually made the interpreter-dispatch\n",
|
|
"decision rely more and more on context rather than data (omnistaging [link],\n",
|
|
"stackless [link]). The reason to prefer context-based interpretation over\n",
|
|
"data-based interpretation is that it makes the implementation much simpler."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2e5e55ee",
|
|
"metadata": {},
|
|
"source": [
|
|
"All that said, we do *also* want to take advantage of Python's built-in\n",
|
|
"overloading mechanism. That way we get the syntactic convenience of using\n",
|
|
"infix operators `+` and `*` instead of writing out `add(..)` and `mul(..)`.\n",
|
|
"But we'll put that aside for now."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c9a22b61",
|
|
"metadata": {},
|
|
"source": [
|
"## Staging to an untyped IR"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "480e2328",
|
|
"metadata": {},
|
|
"source": [
|
|
"The two program transformations we've seen so far -- evaluation and JVP --\n",
|
|
"both traverse the input program from top to bottom. They visit the operations\n",
|
|
"one by one in the same order as ordinary evaluation. A convenient thing about\n",
|
|
"top-to-bottom transformations is that they can be implemented eagerly, or\n",
|
|
"\"online\", meaning that we can evaluate the program from top to bottom and\n",
|
|
"perform the necessary transformations as we go. We never look at the entire\n",
|
|
"program at once."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e4adad7b",
|
|
"metadata": {},
|
|
"source": [
|
|
"But not all transformations work this way. For example, dead-code elimination\n",
|
|
"requires traversing from bottom to top, collecting usage statistics on the way\n",
|
|
"up and eliminating pure operations whose results have no uses. Another\n",
|
|
"bottom-to-top transformation is AD transposition, which we use to implement\n",
|
|
"reverse-mode AD. For these we need to first \"stage\" the program into an IR\n",
|
"(intermediate representation), a data structure representing the program, which we\n",
|
|
"can then traverse in any order we like. Building this IR from a Python program\n",
|
|
"will be the goal of our third and final interpreter."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9c8ba558",
|
|
"metadata": {},
|
|
"source": [
|
"First, let's define the IR. We'll do an untyped ANF (A-normal form) IR to start. A function\n",
|
|
"(we call IR functions \"jaxprs\" in JAX) will have a list of formal parameters,\n",
|
|
"a list of operations, and a return value. Each argument to an operation must\n",
|
|
"be an \"atom\", which is either a variable or a literal. The return value of the\n",
|
|
"function is also an atom."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "2100d92e",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.909146Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.908956Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.914391Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.913886Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"Var = str           # Variables are just strings in this untyped IR\n",
"Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals\n",
"\n",
"# Equation - a single line in our IR like `z = mul(x, y)`\n",
"@dataclass\n",
"class Equation:\n",
"  var: Var                # The variable name of the result\n",
"  op: Op                  # The primitive operation we're applying\n",
"  args: tuple[Atom, ...]  # The arguments we're applying the primitive operation to\n",
"\n",
"# We call an IR function a \"Jaxpr\", for \"JAX expression\"\n",
"@dataclass\n",
"class Jaxpr:\n",
"  parameters: list[Var]      # The function's formal parameters (arguments)\n",
"  equations: list[Equation]  # The body of the function, a list of instructions/equations\n",
"  return_val: Atom           # The function's return value\n",
"\n",
"  def __str__(self):\n",
"    lines = []\n",
"    lines.append(', '.join(b for b in self.parameters) + ' ->')\n",
"    for eqn in self.equations:\n",
"      args_str = ', '.join(str(arg) for arg in eqn.args)\n",
"      lines.append(f' {eqn.var} = {eqn.op}({args_str})')\n",
"    # The return value may be a float literal, so convert it before joining.\n",
"    lines.append(str(self.return_val))\n",
"    return '\\n'.join(lines)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "36720d5c",
|
|
"metadata": {},
|
|
"source": [
|
|
"To build the IR from a Python function we define a `StagingInterpreter` that\n",
|
|
"takes each operation and adds it to a growing list of all the operations we've\n",
|
|
"seen so far:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "0ed04f2e",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.916077Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.915878Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.920161Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.919753Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"class StagingInterpreter(Interpreter):\n",
"  def __init__(self):\n",
"    self.equations = []    # A mutable list of all the ops we've seen so far\n",
"    self.name_counter = 0  # Counter for generating unique names\n",
"\n",
"  def fresh_var(self):\n",
"    self.name_counter += 1\n",
"    return \"v_\" + str(self.name_counter)\n",
"\n",
"  def interpret_op(self, op, args):\n",
"    binder = self.fresh_var()\n",
"    self.equations.append(Equation(binder, op, args))\n",
"    return binder\n",
"\n",
"def build_jaxpr(f, num_args):\n",
"  interpreter = StagingInterpreter()\n",
"  parameters = tuple(interpreter.fresh_var() for _ in range(num_args))\n",
"  with set_interpreter(interpreter):\n",
"    result = f(*parameters)\n",
"  return Jaxpr(parameters, interpreter.equations, result)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1bde02c9",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now we can construct an IR for a Python program and print it out:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "606d2e23",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.921731Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.921538Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.924256Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.923850Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"v_1 ->\n",
|
|
" v_2 = Op.add(v_1, 3.0)\n",
|
|
" v_3 = Op.mul(v_1, v_2)\n",
|
|
"v_3\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(build_jaxpr(foo, 1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "67deabe8",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can also evaluate our IR by writing an explicit interpreter that traverses\n",
|
|
"the operations one by one:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "6a20cc84",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.925838Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.925646Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.929596Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.929187Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"def eval_jaxpr(jaxpr, args):\n",
"  # An environment mapping variables to values\n",
"  env = dict(zip(jaxpr.parameters, args))\n",
"  def eval_atom(x): return env[x] if isinstance(x, Var) else x\n",
"  for eqn in jaxpr.equations:\n",
"    args = tuple(eval_atom(x) for x in eqn.args)\n",
"    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)\n",
"  return eval_atom(jaxpr.return_val)\n",
"\n",
"print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c6250492",
|
|
"metadata": {},
|
|
"source": [
|
|
"We've written this interpreter in terms of `current_interpreter.interpret_op`\n",
|
|
"which means we've done a full round-trip: interpretable Python program to IR\n",
|
|
"to interpretable Python program. Since the result is \"interpretable\" we can\n",
|
|
"differentiate it again, or stage it out or anything we like:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "831924b8",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2025-02-12T20:21:55.931176Z",
|
|
"iopub.status.busy": "2025-02-12T20:21:55.930983Z",
|
|
"iopub.status.idle": "2025-02-12T20:21:55.933902Z",
|
|
"shell.execute_reply": "2025-02-12T20:21:55.933490Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(10.0, 7.0)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))"
|
|
]
|
|
},
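{
"cell_type": "markdown",
"id": "7b1e4a96",
"metadata": {},
"source": [
"The bottom-to-top traversal described earlier for dead-code elimination can be\n",
"sketched on a tiny list of equations. This is a simplified illustration, not\n",
"part of the original notebook: `Eqn`, `eliminate_dead_code` and `toy_equations`\n",
"are hypothetical names, ops are plain strings, and every op is assumed to be\n",
"pure so that an unused result can always be dropped."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b1e4a97",
"metadata": {},
"outputs": [],
"source": [
"# A sketch, not part of the original notebook. `Eqn` is a simplified,\n",
"# hypothetical stand-in for the `Equation` class above, with ops as strings.\n",
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class Eqn:\n",
"  var: str\n",
"  op: str\n",
"  args: tuple\n",
"\n",
"def eliminate_dead_code(equations, return_val):\n",
"  # Variables known to be needed; float literals are never tracked.\n",
"  live = {return_val} if isinstance(return_val, str) else set()\n",
"  kept = []\n",
"  # Walk from the last equation to the first, keeping only equations\n",
"  # whose result is used by something we already decided to keep.\n",
"  for eqn in reversed(equations):\n",
"    if eqn.var in live:\n",
"      kept.append(eqn)\n",
"      live.update(arg for arg in eqn.args if isinstance(arg, str))\n",
"  kept.reverse()\n",
"  return kept\n",
"\n",
"toy_equations = [\n",
"  Eqn('v_2', 'add', ('v_1', 3.0)),\n",
"  Eqn('v_3', 'mul', ('v_1', 'v_1')),  # dead: v_3 is never used\n",
"  Eqn('v_4', 'mul', ('v_1', 'v_2')),\n",
"]\n",
"for eqn in eliminate_dead_code(toy_equations, 'v_4'):\n",
"  print(eqn.var, '=', eqn.op, eqn.args)"
]
},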
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d3ac5873",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Up next..."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c3451298",
|
|
"metadata": {},
|
|
"source": [
|
|
"That's it for part one of this tutorial. We've done two primitives, three\n",
|
|
"interpreters and the tracing mechanism that weaves them together. In the next\n",
|
|
"part we'll add types other than floats, error handling, compilation,\n",
|
"reverse-mode AD and higher-order primitives. Note that the second part is\n",
|
|
"structured differently. Rather than trying to have a top-to-bottom order that\n",
|
|
"obeys both code dependencies (e.g. data structures need to be defined before\n",
|
|
"they're used) and pedagogical dependencies (concepts need to be introduced\n",
|
|
"before they're implemented) we're going with a single file that can be approached\n",
|
|
"in any order."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"formats": "ipynb,md:myst,py:light"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|