mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
Part 1 of a new autodidax based on "stackless"
commit 9145366f6f
parent c7199fe8a5
docs/autodidax2_part1.ipynb (1082 lines, new file): file diff suppressed because it is too large
docs/autodidax2_part1.md (547 lines, new file)

@@ -0,0 +1,547 @@

---
jupytext:
  formats: ipynb,md:myst,py:light
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

```{raw-cell}

---
Copyright 2025 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

---
```

# Autodidax2, part 1: JAX from scratch, again

+++

If you want to understand how JAX works you could try reading the code. But
the code is complicated, often for no good reason. This notebook presents a
stripped-back version without the cruft. It's a minimal version of JAX from
first principles. Enjoy!

+++

## Main idea: context-sensitive interpretation

+++

JAX is two things:
1. a set of primitive operations (roughly the NumPy API)
2. a set of interpreters over those primitives (compilation, AD, etc.)

In this minimal version of JAX we'll start with just two primitive operations,
addition and multiplication, and we'll add interpreters one by one. Suppose we
have a user-defined function like this:

```{code-cell} ipython3
def foo(x):
  return mul(x, add(x, 3.0))
```

We want to be able to interpret `foo` in different ways without changing its
implementation: we want to evaluate it on concrete values, differentiate it,
stage it out to an IR, compile it and so on.

+++

Here's how we'll do it. For each of these interpretations we'll define an
`Interpreter` object with a rule for handling each primitive operation. We'll
keep track of the *current* interpreter using a global context variable. The
user-facing functions `add` and `mul` will dispatch to the current
interpreter. At the beginning of the program the current interpreter will be
the "evaluating" interpreter which just evaluates the operations on ordinary
concrete data. Here's what this all looks like so far.

```{code-cell} ipython3
from enum import Enum, auto
from contextlib import contextmanager
from typing import Any

# The full (closed) set of primitive operations
class Op(Enum):
  add = auto()  # addition on floats
  mul = auto()  # multiplication on floats

# Interpreters have rules for handling each primitive operation.
class Interpreter:
  def interpret_op(self, op: Op, args: tuple[Any, ...]):
    assert False, "subclass should implement this"

# Our first interpreter is the "evaluating interpreter" which performs ordinary
# concrete evaluation.
class EvalInterpreter:
  def interpret_op(self, op, args):
    assert all(isinstance(arg, float) for arg in args)
    match op:
      case Op.add:
        x, y = args
        return x + y
      case Op.mul:
        x, y = args
        return x * y
      case _:
        raise ValueError(f"Unrecognized primitive op: {op}")

# The current interpreter is initially the evaluating interpreter.
current_interpreter = EvalInterpreter()

# A context manager for temporarily changing the current interpreter
@contextmanager
def set_interpreter(new_interpreter):
  global current_interpreter
  prev_interpreter = current_interpreter
  try:
    current_interpreter = new_interpreter
    yield
  finally:
    current_interpreter = prev_interpreter

# The user-facing functions `mul` and `add` dispatch to the current interpreter.
def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))
```

At this point we can call `foo` with ordinary concrete inputs and see the
results:

```{code-cell} ipython3
print(foo(2.0))
```
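
To see the context-based dispatch in action, here is a small illustrative
interpreter (a sketch for this writeup only; `LoggingInterpreter` is a name
made up here, not part of the machinery we build below). It logs each primitive
and then defers to the interpreter that was current when it was installed:

```{code-cell} ipython3
class LoggingInterpreter(Interpreter):
  def __init__(self, prev_interpreter):
    self.prev_interpreter = prev_interpreter

  def interpret_op(self, op, args):
    print(f"interpreting {op} with args {args}")
    # Run the underlying rule in the previous interpreter's context.
    with set_interpreter(self.prev_interpreter):
      return self.prev_interpreter.interpret_op(op, args)

with set_interpreter(LoggingInterpreter(current_interpreter)):
  print(foo(2.0))
```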

## Aside: forward-mode automatic differentiation

+++

For our second interpreter we're going to try forward-mode automatic
differentiation (AD). Here's a quick introduction to forward-mode AD in case
this is the first time you've come across it. Otherwise skip ahead to the
"JVP Interpreter" section.

+++

Suppose we're interested in the derivative of `foo(x)` evaluated at `x=2.0`.
We could approximate it with finite differences:

```{code-cell} ipython3
print((foo(2.00001) - foo(2.0)) / 0.00001)
```

The answer is close to 7.0 as expected. But computing it this way required two
evaluations of the function (not to mention the roundoff error and truncation
error). Here's a funny thing though. We can almost get the answer with a
single evaluation:

```{code-cell} ipython3
print(foo(2.00001))
```

The answer we're looking for, 7.0, is right there in the insignificant digits!

+++

Here's one way to think about what's happening. The initial argument to `foo`,
`2.00001`, carries two pieces of data: a "primal" value, 2.0, and a "tangent"
value, `1.0`. The representation of this primal-tangent pair, `2.00001`, is
the sum of the two, with the tangent scaled by a small fixed epsilon, `1e-5`.
Ordinary evaluation of `foo(2.00001)` propagates this primal-tangent pair,
producing `10.0000700001` as the result. The primal and tangent components are
well separated in scale so we can visually interpret the result as the
primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at
the end.
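
To make that "reading off" concrete, here is a rough decoding of the pair (an
illustration only; the `1e-5` epsilon and the rounding are specific to this
example):

```{code-cell} ipython3
out = foo(2.00001)
primal = round(out, 3)                      # ~10.0
tangent = round((out - primal) / 1e-5, 3)   # ~7.0
print(primal, tangent)
```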

+++

The idea with forward-mode differentiation is to do the same thing but exactly
and explicitly (eyeballing floats doesn't really scale). We'll represent the
primal-tangent pair as an actual pair instead of folding them both into a
single floating point number. For each primitive operation we'll have a rule
that describes how to propagate these primal-tangent pairs. Let's work out the
rules for our two primitives.

+++

Addition is easy. Consider `x + y` where `x = xp + xt * eps` and `y = yp + yt * eps`
("p" for "primal", "t" for "tangent"):

    x + y = (xp + xt * eps) + (yp + yt * eps)
          = (xp + yp)              # primal component
          + (xt + yt) * eps        # tangent component

The result is a first-order polynomial in `eps` and we can read off the
primal-tangent pair as (xp + yp, xt + yt).

+++

Multiplication is more interesting:

    x * y = (xp + xt * eps) * (yp + yt * eps)
          = (xp * yp)                    # primal component
          + (xp * yt + xt * yp) * eps    # tangent component
          + (xt * yt) * eps * eps        # quadratic component, vanishes in the eps->0 limit

Now we have a second-order polynomial. But as epsilon goes to zero the
quadratic term vanishes and our primal-tangent pair
is just `(xp * yp, xp * yt + xt * yp)`.
(In our earlier example with finite `eps`, this term not vanishing is
why we had the ~1e-10 "truncation error".)

+++

Putting this into code, we can write down the forward-AD rules for addition
and multiplication and express `foo` in terms of these:

```{code-cell} ipython3
from dataclasses import dataclass

# A primal-tangent pair is conventionally called a "dual number"
@dataclass
class DualNumber:
  primal : float
  tangent : float

def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)

def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)

def foo_dual(x : DualNumber) -> DualNumber:
  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))

print(foo_dual(DualNumber(2.0, 1.0)))
```

That works! But rewriting `foo` to use the `_dual` versions of addition and
multiplication was a bit tedious. Let's get back to the main program and use
our interpretation machinery to do the rewrite automatically.

+++

## JVP Interpreter

+++

We'll set up a new interpreter called `JVPInterpreter` ("JVP" for
"Jacobian-vector product") which propagates these dual numbers instead of
ordinary values. The `JVPInterpreter` has rules for `add` and `mul` that operate
on dual numbers. They cast constant arguments to dual numbers as needed by
calling `JVPInterpreter.lift`. In our manually rewritten version above we did
that by replacing the literal `3.0` with `DualNumber(3.0, 0.0)`.

```{code-cell} ipython3
# This is like DualNumber above except that it also has a pointer to the
# interpreter it belongs to, which is needed to avoid "perturbation confusion"
# in higher order differentiation.
@dataclass
class TaggedDualNumber:
  interpreter : Interpreter
  primal : float
  tangent : float

class JVPInterpreter(Interpreter):
  def __init__(self, prev_interpreter: Interpreter):
    # We keep a pointer to the interpreter that was current when this
    # interpreter was first invoked. That's the context in which our
    # rules should run.
    self.prev_interpreter = prev_interpreter

  def interpret_op(self, op, args):
    args = tuple(self.lift(arg) for arg in args)
    with set_interpreter(self.prev_interpreter):
      match op:
        case Op.add:
          # Notice that we use `add` and `mul` here, which are the
          # interpreter-dispatching functions defined earlier.
          x, y = args
          return self.dual_number(
            add(x.primal, y.primal),
            add(x.tangent, y.tangent))

        case Op.mul:
          x, y = args
          x = self.lift(x)
          y = self.lift(y)
          return self.dual_number(
            mul(x.primal, y.primal),
            add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

  def dual_number(self, primal, tangent):
    return TaggedDualNumber(self, primal, tangent)

  # Lift a constant value (constant with respect to this interpreter) to
  # a TaggedDualNumber.
  def lift(self, x):
    if isinstance(x, TaggedDualNumber) and x.interpreter is self:
      return x
    else:
      return self.dual_number(x, 0.0)

def jvp(f, primal, tangent):
  jvp_interpreter = JVPInterpreter(current_interpreter)
  dual_number_in = jvp_interpreter.dual_number(primal, tangent)
  with set_interpreter(jvp_interpreter):
    result = f(dual_number_in)
  dual_number_out = jvp_interpreter.lift(result)
  return dual_number_out.primal, dual_number_out.tangent

# Let's try it out:
print(jvp(foo, 2.0, 1.0))

# Because we were careful to consider nesting interpreters, higher-order AD
# works out of the box:

def derivative(f, x):
  _, tangent = jvp(f, x, 1.0)
  return tangent

def nth_order_derivative(n, f, x):
  if n == 0:
    return f(x)
  else:
    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)
```

```{code-cell} ipython3
print(nth_order_derivative(0, foo, 2.0))
```

```{code-cell} ipython3
print(nth_order_derivative(1, foo, 2.0))
```

```{code-cell} ipython3
print(nth_order_derivative(2, foo, 2.0))
```

```{code-cell} ipython3
# The rest are zero because `foo` is only a second-order polynomial
print(nth_order_derivative(3, foo, 2.0))
```

```{code-cell} ipython3
print(nth_order_derivative(4, foo, 2.0))
```

There are some subtleties worth discussing. First, how do you tell if
something is constant with respect to differentiation? It's tempting to say
"it's a constant if and only if it's not a dual number". But actually dual
numbers created by a *different* JVPInterpreter also need to be considered
constants with respect to the JVPInterpreter we're currently handling. That's
why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This
comes up in higher order differentiation when there are multiple JVPInterpreters
in scope. The sort of bug where you accidentally interpret a dual number from
a different interpreter as non-constant is sometimes called "perturbation
confusion" in the literature. Here's an example program that would have given
the wrong answer if we hadn't had the `and x.interpreter is self` check in
`JVPInterpreter.lift`.

```{code-cell} ipython3
def f(x):
  # g is constant in its (ignored) argument `y`. Its derivative should be zero
  # but our AD will mess it up if we don't distinguish perturbations from
  # different interpreters.
  def g(y):
    return x
  should_be_zero = derivative(g, 0.0)
  return mul(x, should_be_zero)

print(derivative(f, 0.0))
```

Another subtlety: the `Op.add` and `Op.mul` rules in `JVPInterpreter.interpret_op`
describe addition and multiplication on dual numbers in terms of addition and
multiplication on the primal and tangent components. But we don't use ordinary
`+` and `*` for this. Instead we use our own `add` and `mul` functions which
dispatch to the current interpreter. Before calling them we set the current
interpreter to be the *previous* interpreter, i.e. the interpreter that was
current when `JVPInterpreter` was first invoked. If we didn't do this we'd
have an infinite recursion, with `add` and `mul` dispatching to
`JVPInterpreter` endlessly. The advantage of using our own `add` and `mul` instead
of ordinary `+` and `*` is that it means we can nest these interpreters and do
higher-order AD.

+++

At this point you might be wondering: have we just reinvented operator
overloading? Python overloads the infix ops `+` and `*` to dispatch to the
argument's `__add__` and `__mul__`. Could we have just used that mechanism
instead of this whole interpreter business? Yes, actually. Indeed, the earlier
automatic differentiation (AD) literature uses the term "operator overloading"
to describe this style of AD implementation. One detail is that we can't rely
exclusively on Python's built-in overloading because that only lets us overload
a handful of built-in infix ops whereas we eventually want to overload
numpy-level operations like `sin` and `cos`. So we need our own mechanism.

+++

But there's a more important difference: our dispatch is based on *context*
whereas traditional Python-style overloading is based on *data*. This is
actually a recent development for JAX. The earliest versions of JAX looked
more like traditional data-based overloading. An interpreter (a "trace" in JAX
jargon) for an operation would be chosen based on data attached to the
arguments to that operation. We've gradually made the interpreter-dispatch
decision rely more and more on context rather than data (omnistaging [link],
stackless [link]). The reason to prefer context-based interpretation over
data-based interpretation is that it makes the implementation much simpler.

+++

All that said, we do *also* want to take advantage of Python's built-in
overloading mechanism. That way we get the syntactic convenience of using
infix operators `+` and `*` instead of writing out `add(..)` and `mul(..)`.
But we'll put that aside for now.
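
As a quick sketch of what that could look like (illustration only; `Box` and
`foo_infix` are names made up for this example, not part of the tutorial's
machinery), a small wrapper class can forward Python's `__add__` and `__mul__`
to our `add` and `mul`, so infix syntax still dispatches to the current
interpreter:

```{code-cell} ipython3
class Box:
  def __init__(self, val):
    self.val = val

  def __add__(self, other):
    other = other.val if isinstance(other, Box) else other
    return Box(add(self.val, other))

  def __mul__(self, other):
    other = other.val if isinstance(other, Box) else other
    return Box(mul(self.val, other))

def foo_infix(x):
  return x * (x + 3.0)  # reads like ordinary Python, still goes through `add`/`mul`

print(foo_infix(Box(2.0)).val)
```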

+++

## Staging to an untyped IR

+++

The two program transformations we've seen so far -- evaluation and JVP --
both traverse the input program from top to bottom. They visit the operations
one by one in the same order as ordinary evaluation. A convenient thing about
top-to-bottom transformations is that they can be implemented eagerly, or
"online", meaning that we can evaluate the program from top to bottom and
perform the necessary transformations as we go. We never look at the entire
program at once.

+++

But not all transformations work this way. For example, dead-code elimination
requires traversing from bottom to top, collecting usage statistics on the way
up and eliminating pure operations whose results have no uses. Another
bottom-to-top transformation is AD transposition, which we use to implement
reverse-mode AD. For these we need to first "stage" the program into an IR
(intermediate representation), a data structure representing the program, which
we can then traverse in any order we like. Building this IR from a Python
program will be the goal of our third and final interpreter.

+++

First, let's define the IR. We'll do an untyped ANF IR to start. A function
(we call IR functions "jaxprs" in JAX) will have a list of formal parameters,
a list of operations, and a return value. Each argument to an operation must
be an "atom", which is either a variable or a literal. The return value of the
function is also an atom.

```{code-cell} ipython3
Var = str           # Variables are just strings in this untyped IR
Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals

# Equation - a single line in our IR like `z = mul(x, y)`
@dataclass
class Equation:
  var : Var           # The variable name of the result
  op : Op             # The primitive operation we're applying
  args : tuple[Atom]  # The arguments we're applying the primitive operation to

# We call an IR function a "Jaxpr", for "JAX expression"
@dataclass
class Jaxpr:
  parameters : list[Var]      # The function's formal parameters (arguments)
  equations : list[Equation]  # The body of the function, a list of instructions/equations
  return_val : Atom           # The function's return value

  def __str__(self):
    lines = []
    lines.append(', '.join(b for b in self.parameters) + ' ->')
    for eqn in self.equations:
      args_str = ', '.join(str(arg) for arg in eqn.args)
      lines.append(f' {eqn.var} = {eqn.op}({args_str})')
    lines.append(str(self.return_val))
    return '\n'.join(lines)
```

To build the IR from a Python function we define a `StagingInterpreter` that
takes each operation and adds it to a growing list of all the operations we've
seen so far:

```{code-cell} ipython3
class StagingInterpreter(Interpreter):
  def __init__(self):
    self.equations = []    # A mutable list of all the ops we've seen so far
    self.name_counter = 0  # Counter for generating unique names

  def fresh_var(self):
    self.name_counter += 1
    return "v_" + str(self.name_counter)

  def interpret_op(self, op, args):
    binder = self.fresh_var()
    self.equations.append(Equation(binder, op, args))
    return binder

def build_jaxpr(f, num_args):
  interpreter = StagingInterpreter()
  parameters = tuple(interpreter.fresh_var() for _ in range(num_args))
  with set_interpreter(interpreter):
    result = f(*parameters)
  return Jaxpr(parameters, interpreter.equations, result)
```

Now we can construct an IR for a Python program and print it out:

```{code-cell} ipython3
print(build_jaxpr(foo, 1))
```
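
Once the program is just data we can traverse it in whatever order we like. As
a taste of the bottom-to-top passes mentioned earlier, here is a minimal
dead-code elimination sketch (illustration only; `dce_jaxpr` and
`foo_with_dead_code` are names made up for this example):

```{code-cell} ipython3
def dce_jaxpr(jaxpr):
  # Walk the equations bottom-to-top, keeping only those whose results are needed.
  needed = {jaxpr.return_val} if isinstance(jaxpr.return_val, Var) else set()
  kept = []
  for eqn in reversed(jaxpr.equations):
    if eqn.var in needed:
      kept.append(eqn)
      needed.update(arg for arg in eqn.args if isinstance(arg, Var))
  return Jaxpr(jaxpr.parameters, kept[::-1], jaxpr.return_val)

def foo_with_dead_code(x):
  unused = add(x, 1.0)  # this result is never used
  return mul(x, add(x, 3.0))

print(dce_jaxpr(build_jaxpr(foo_with_dead_code, 1)))
```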

We can also evaluate our IR by writing an explicit interpreter that traverses
the operations one by one:

```{code-cell} ipython3
def eval_jaxpr(jaxpr, args):
  # An environment mapping variables to values
  env = dict(zip(jaxpr.parameters, args))
  def eval_atom(x): return env[x] if isinstance(x, Var) else x
  for eqn in jaxpr.equations:
    args = tuple(eval_atom(x) for x in eqn.args)
    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)
  return eval_atom(jaxpr.return_val)

print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
```

We've written this interpreter in terms of `current_interpreter.interpret_op`
which means we've done a full round-trip: interpretable Python program to IR
to interpretable Python program. Since the result is "interpretable" we can
differentiate it again, or stage it out or anything we like:

```{code-cell} ipython3
print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
```
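
For instance (a quick sanity check rather than anything we need later), staging
out the evaluation of a jaxpr recovers an equivalent jaxpr:

```{code-cell} ipython3
print(build_jaxpr(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 1))
```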

## Up next...

+++

That's it for part one of this tutorial. We've done two primitives, three
interpreters and the tracing mechanism that weaves them together. In the next
part we'll add types other than floats, error handling, compilation,
reverse-mode AD and higher-order primitives. Note that the second part is
structured differently. Rather than trying to have a top-to-bottom order that
obeys both code dependencies (e.g. data structures need to be defined before
they're used) and pedagogical dependencies (concepts need to be introduced
before they're implemented) we're going with a single file that can be approached
in any order.
docs/autodidax2_part1.py (491 lines, new file): the same tutorial as autodidax2_part1.md, in jupytext py:light format
@@ -138,6 +138,7 @@ exclude_patterns = [
     'pallas/tpu/matmul.md',
     'jep/9407-type-promotion.md',
     'autodidax.md',
+    'autodidax2_part1.md',
     'sharded-computation.md',
     'ffi.ipynb',
 ]
@@ -24,4 +24,5 @@ some of JAX's (extensible) internals.
 :caption: Design and internals

 autodidax
+autodidax2_part1
 jep/index