# ---
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# jupyter:
#   jupytext:
#     formats: ipynb,md:myst,py:light
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Autodidax2, part 1: JAX from scratch, again

# If you want to understand how JAX works you could try reading the code. But
# the code is complicated, often for no good reason. This notebook presents a
# stripped-back version without the cruft. It's a minimal version of JAX from
# first principles. Enjoy!

# ## Main idea: context-sensitive interpretation

# JAX is two things:
# 1. a set of primitive operations (roughly the NumPy API)
# 2. a set of interpreters over those primitives (compilation, AD, etc.)
#
# In this minimal version of JAX we'll start with just two primitive operations,
# addition and multiplication, and we'll add interpreters one by one. Suppose we
# have a user-defined function like this:

def foo(x):
  return mul(x, add(x, 3.0))


# We want to be able to interpret `foo` in different ways without changing its
# implementation: we want to evaluate it on concrete values, differentiate it,
# stage it out to an IR, compile it and so on.

# Here's how we'll do it. For each of these interpretations we'll define an
# `Interpreter` object with a rule for handling each primitive operation. We'll
# keep track of the *current* interpreter using a global context variable. The
# user-facing functions `add` and `mul` will dispatch to the current
# interpreter. At the beginning of the program the current interpreter will be
# the "evaluating" interpreter which just evaluates the operations on ordinary
# concrete data. Here's what this all looks like so far.

# +
from enum import Enum, auto
from contextlib import contextmanager
from typing import Any

# The full (closed) set of primitive operations
class Op(Enum):
  add = auto()  # addition on floats
  mul = auto()  # multiplication on floats

# Interpreters have rules for handling each primitive operation.
class Interpreter:
  def interpret_op(self, op: Op, args: tuple[Any, ...]):
    assert False, "subclass should implement this"

# Our first interpreter is the "evaluating interpreter" which performs ordinary
# concrete evaluation.
class EvalInterpreter:
  def interpret_op(self, op, args):
    assert all(isinstance(arg, float) for arg in args)
    match op:
      case Op.add:
        x, y = args
        return x + y
      case Op.mul:
        x, y = args
        return x * y
      case _:
        raise ValueError(f"Unrecognized primitive op: {op}")

# The current interpreter is initially the evaluating interpreter.
current_interpreter = EvalInterpreter()

# A context manager for temporarily changing the current interpreter
@contextmanager
def set_interpreter(new_interpreter):
  global current_interpreter
  prev_interpreter = current_interpreter
  try:
    current_interpreter = new_interpreter
    yield
  finally:
    current_interpreter = prev_interpreter

# The user-facing functions `mul` and `add` dispatch to the current interpreter.
def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))

# -

# At this point we can call `foo` with ordinary concrete inputs and see the
# results:

print(foo(2.0))

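# (This should print `10.0`: the evaluating interpreter just computes
# `2.0 * (2.0 + 3.0)`.)
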
# ## Aside: forward-mode automatic differentiation

# For our second interpreter we're going to try forward-mode automatic
# differentiation (AD). Here's a quick introduction to forward-mode AD in case
# this is the first time you've come across it. Otherwise skip ahead to the
# "JVP Interpreter" section.

# Suppose we're interested in the derivative of `foo(x)` evaluated at `x=2.0`.
# We could approximate it with finite differences:

print((foo(2.00001) - foo(2.0)) / 0.00001)

# The answer is close to 7.0 as expected. But computing it this way required two
# evaluations of the function (not to mention the roundoff error and truncation
# error). Here's a funny thing though. We can almost get the answer with a
# single evaluation:

print(foo(2.00001))

# The answer we're looking for, 7.0, is right there in the insignificant digits!

# Here's one way to think about what's happening. The initial argument to `foo`,
# `2.00001`, carries two pieces of data: a "primal" value, `2.0`, and a "tangent"
# value, `1.0`. The representation of this primal-tangent pair, `2.00001`, is
# the sum of the two, with the tangent scaled by a small fixed epsilon, `1e-5`.
# Ordinary evaluation of `foo(2.00001)` propagates this primal-tangent pair,
# producing `10.0000700001` as the result. The primal and tangent components are
# well separated in scale so we can visually interpret the result as the
# primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at
# the end.

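# To make that concrete, here's the same computation written out, with
# `eps = 1e-5` playing the role of the small epsilon:
#
#     foo(2 + eps) = (2 + eps) * ((2 + eps) + 3)
#                  = (2 + eps) * (5 + eps)
#                  = 10               # primal component
#                  + 7 * eps          # tangent component
#                  + eps * eps        # the ~1e-10 truncation error
#
# which is exactly the `10.0000700001` we saw above.
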
# The idea with forward-mode differentiation is to do the same thing but exactly
# and explicitly (eyeballing floats doesn't really scale). We'll represent the
# primal-tangent pair as an actual pair instead of folding them both into a
# single floating point number. For each primitive operation we'll have a rule
# that describes how to propagate these primal-tangent pairs. Let's work out the
# rules for our two primitives.

# Addition is easy. Consider `x + y` where `x = xp + xt * eps` and `y = yp + yt * eps`
# ("p" for "primal", "t" for "tangent"):
#
#     x + y = (xp + xt * eps) + (yp + yt * eps)
#           = (xp + yp)             # primal component
#           + (xt + yt) * eps       # tangent component
#
# The result is a first-order polynomial in `eps` and we can read off the
# primal-tangent pair as (xp + yp, xt + yt).

# Multiplication is more interesting:
#
#     x * y = (xp + xt * eps) * (yp + yt * eps)
#           = (xp * yp)                     # primal component
#           + (xp * yt + xt * yp) * eps     # tangent component
#           + (xt * yt) * eps * eps         # quadratic component, vanishes in the eps->0 limit
#
# Now we have a second order polynomial. But as epsilon goes to zero the
# quadratic term vanishes and our primal-tangent pair is just
# `(xp * yp, xp * yt + xt * yp)`.
# (In our earlier example with finite `eps` this term not vanishing is
# why we had the 1e-10 "truncation error".)

# Putting this into code, we can write down the forward-AD rules for addition
# and multiplication and express `foo` in terms of these:

# +
from dataclasses import dataclass

# A primal-tangent pair is conventionally called a "dual number"
@dataclass
class DualNumber:
  primal : float
  tangent : float

def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)

def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)

def foo_dual(x : DualNumber) -> DualNumber:
  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))

print(foo_dual(DualNumber(2.0, 1.0)))

# -

# That works! But rewriting `foo` to use the `_dual` versions of addition and
# multiplication was a bit tedious. Let's get back to the main program and use
# our interpretation machinery to do the rewrite automatically.

# ## JVP Interpreter

# We'll set up a new interpreter called `JVPInterpreter` ("JVP" for
# "Jacobian-vector product") which propagates these dual numbers instead of
# ordinary values. The `JVPInterpreter` has rules for `add` and `mul` that
# operate on dual numbers. They cast constant arguments to dual numbers as
# needed by calling `JVPInterpreter.lift`. In our manually rewritten version
# above we did that by replacing the literal `3.0` with `DualNumber(3.0, 0.0)`.

# +
# This is like DualNumber above except that it also has a pointer to the
# interpreter it belongs to, which is needed to avoid "perturbation confusion"
# in higher order differentiation.
@dataclass
class TaggedDualNumber:
  interpreter : Interpreter
  primal : float
  tangent : float

class JVPInterpreter(Interpreter):
  def __init__(self, prev_interpreter: Interpreter):
    # We keep a pointer to the interpreter that was current when this
    # interpreter was first invoked. That's the context in which our
    # rules should run.
    self.prev_interpreter = prev_interpreter

  def interpret_op(self, op, args):
    args = tuple(self.lift(arg) for arg in args)
    with set_interpreter(self.prev_interpreter):
      match op:
        case Op.add:
          # Notice that we use `add` and `mul` here, which are the
          # interpreter-dispatching functions defined earlier.
          x, y = args
          return self.dual_number(
              add(x.primal, y.primal),
              add(x.tangent, y.tangent))

        case Op.mul:
          x, y = args
          x = self.lift(x)
          y = self.lift(y)
          return self.dual_number(
              mul(x.primal, y.primal),
              add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

  def dual_number(self, primal, tangent):
    return TaggedDualNumber(self, primal, tangent)

  # Lift a constant value (constant with respect to this interpreter) to
  # a TaggedDualNumber.
  def lift(self, x):
    if isinstance(x, TaggedDualNumber) and x.interpreter is self:
      return x
    else:
      return self.dual_number(x, 0.0)

def jvp(f, primal, tangent):
  jvp_interpreter = JVPInterpreter(current_interpreter)
  dual_number_in = jvp_interpreter.dual_number(primal, tangent)
  with set_interpreter(jvp_interpreter):
    result = f(dual_number_in)
  dual_number_out = jvp_interpreter.lift(result)
  return dual_number_out.primal, dual_number_out.tangent

# Let's try it out:
print(jvp(foo, 2.0, 1.0))

# Because we were careful to consider nesting interpreters, higher-order AD
# works out of the box:

def derivative(f, x):
  _, tangent = jvp(f, x, 1.0)
  return tangent

def nth_order_derivative(n, f, x):
  if n == 0:
    return f(x)
  else:
    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)

# -

print(nth_order_derivative(0, foo, 2.0))

print(nth_order_derivative(1, foo, 2.0))

print(nth_order_derivative(2, foo, 2.0))

# The rest are zero because `foo` is only a second-order polynomial
print(nth_order_derivative(3, foo, 2.0))

print(nth_order_derivative(4, foo, 2.0))

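# As a sanity check against calculus by hand: `foo(x) = x * (x + 3) = x^2 + 3x`,
# so at `x = 2.0` the function value is `10.0`, the first derivative `2x + 3` is
# `7.0`, the second derivative is `2.0`, and everything beyond that is `0.0`,
# matching the values printed above.
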
# There are some subtleties worth discussing. First, how do you tell if
# something is constant with respect to differentiation? It's tempting to say
# "it's a constant if and only if it's not a dual number". But actually dual
# numbers created by a *different* JVPInterpreter also need to be considered
# constants with respect to the JVPInterpreter we're currently handling. That's
# why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This
# comes up in higher order differentiation when there are multiple
# JVPInterpreters in scope. The sort of bug where you accidentally interpret a
# dual number from a different interpreter as non-constant is sometimes called
# "perturbation confusion" in the literature. Here's an example program that
# would have given the wrong answer if we hadn't had the
# `and x.interpreter is self` check in `JVPInterpreter.lift`.

# +
def f(x):
  # g is constant in its (ignored) argument `y`. Its derivative should be zero
  # but our AD will mess it up if we don't distinguish perturbations from
  # different interpreters.
  def g(y):
    return x
  should_be_zero = derivative(g, 0.0)
  return mul(x, should_be_zero)

print(derivative(f, 0.0))
# -

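# With the tag check in place this prints `0.0`, as it should. Without the
# `x.interpreter is self` condition, the inner interpreter's `lift` would have
# treated the outer dual number as one of its own, `should_be_zero` would have
# come out as `1.0`, and `derivative(f, 0.0)` would have wrongly reported `1.0`.
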
# Another subtlety: the `JVPInterpreter` rules for addition and multiplication
# describe addition and multiplication on dual numbers in terms of addition and
# multiplication on the primal and tangent components. But we don't use ordinary
# `+` and `*` for this. Instead we use our own `add` and `mul` functions, which
# dispatch to the current interpreter. Before calling them we set the current
# interpreter to be the *previous* interpreter, i.e. the interpreter that was
# current when `JVPInterpreter` was first invoked. If we didn't do this we'd
# have an infinite recursion, with `add` and `mul` dispatching to
# `JVPInterpreter` endlessly. The advantage of using our own `add` and `mul`
# instead of ordinary `+` and `*` is that we can nest these interpreters and do
# higher-order AD.

# At this point you might be wondering: have we just reinvented operator
# overloading? Python overloads the infix ops `+` and `*` to dispatch to the
# argument's `__add__` and `__mul__`. Could we have just used that mechanism
# instead of this whole interpreter business? Yes, actually. Indeed, the earlier
# automatic differentiation (AD) literature uses the term "operator overloading"
# to describe this style of AD implementation. One detail is that we can't rely
# exclusively on Python's built-in overloading because that only lets us
# overload a handful of built-in infix ops whereas we eventually want to
# overload numpy-level operations like `sin` and `cos`. So we need our own
# mechanism.

# But there's a more important difference: our dispatch is based on *context*
# whereas traditional Python-style overloading is based on *data*. This is
# actually a recent development for JAX. The earliest versions of JAX looked
# more like traditional data-based overloading. An interpreter (a "trace" in JAX
# jargon) for an operation would be chosen based on data attached to the
# arguments to that operation. We've gradually made the interpreter-dispatch
# decision rely more and more on context rather than data (omnistaging [link],
# stackless [link]). The reason to prefer context-based interpretation over
# data-based interpretation is that it makes the implementation much simpler.

# All that said, we do *also* want to take advantage of Python's built-in
# overloading mechanism. That way we get the syntactic convenience of using
# infix operators `+` and `*` instead of writing out `add(..)` and `mul(..)`.
# But we'll put that aside for now.

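# As a quick taste of that sugar, though (just a sketch; we won't use it in what
# follows), we could forward Python's infix operators straight to our
# dispatching `add` and `mul`. `foo_infix` is a made-up example name:

# +
TaggedDualNumber.__add__ = lambda self, other: add(self, other)
TaggedDualNumber.__radd__ = lambda self, other: add(other, self)
TaggedDualNumber.__mul__ = lambda self, other: mul(self, other)
TaggedDualNumber.__rmul__ = lambda self, other: mul(other, self)

def foo_infix(x):
  return x * (x + 3.0)

print(jvp(foo_infix, 2.0, 1.0))  # same (10.0, 7.0) as before
# -
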
# ## Staging to an untyped IR

# The two program transformations we've seen so far -- evaluation and JVP --
# both traverse the input program from top to bottom. They visit the operations
# one by one in the same order as ordinary evaluation. A convenient thing about
# top-to-bottom transformations is that they can be implemented eagerly, or
# "online", meaning that we can evaluate the program from top to bottom and
# perform the necessary transformations as we go. We never look at the entire
# program at once.

# But not all transformations work this way. For example, dead-code elimination
# requires traversing from bottom to top, collecting usage statistics on the way
# up and eliminating pure operations whose results have no uses. Another
# bottom-to-top transformation is AD transposition, which we use to implement
# reverse-mode AD. For these we need to first "stage" the program into an IR
# (intermediate representation), a data structure representing the program,
# which we can then traverse in any order we like. Building this IR from a
# Python program will be the goal of our third and final interpreter.

# First, let's define the IR. We'll use an untyped ANF IR to start. A function
# (we call IR functions "jaxprs" in JAX) will have a list of formal parameters,
# a list of operations, and a return value. Each argument to an operation must
# be an "atom", which is either a variable or a literal. The return value of the
# function is also an atom.

# +
Var = str  # Variables are just strings in this untyped IR
Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals

# Equation - a single line in our IR like `z = mul(x, y)`
@dataclass
class Equation:
  var : Var                # The variable name of the result
  op : Op                  # The primitive operation we're applying
  args : tuple[Atom, ...]  # The arguments we're applying the primitive operation to

# We call an IR function a "Jaxpr", for "JAX expression"
@dataclass
class Jaxpr:
  parameters : list[Var]      # The function's formal parameters (arguments)
  equations : list[Equation]  # The body of the function, a list of instructions/equations
  return_val : Atom           # The function's return value

  def __str__(self):
    lines = []
    lines.append(', '.join(b for b in self.parameters) + ' ->')
    for eqn in self.equations:
      args_str = ', '.join(str(arg) for arg in eqn.args)
      lines.append(f' {eqn.var} = {eqn.op}({args_str})')
    lines.append(str(self.return_val))
    return '\n'.join(lines)

# -

# To build the IR from a Python function we define a `StagingInterpreter` that
# takes each operation and adds it to a growing list of all the operations we've
# seen so far:

# +
class StagingInterpreter(Interpreter):
  def __init__(self):
    self.equations = []    # A mutable list of all the ops we've seen so far
    self.name_counter = 0  # Counter for generating unique names

  def fresh_var(self):
    self.name_counter += 1
    return "v_" + str(self.name_counter)

  def interpret_op(self, op, args):
    binder = self.fresh_var()
    self.equations.append(Equation(binder, op, args))
    return binder

def build_jaxpr(f, num_args):
  interpreter = StagingInterpreter()
  parameters = tuple(interpreter.fresh_var() for _ in range(num_args))
  with set_interpreter(interpreter):
    result = f(*parameters)
  return Jaxpr(parameters, interpreter.equations, result)

# -

# Now we can construct an IR for a Python program and print it out:

print(build_jaxpr(foo, 1))

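# For `foo` this should print a jaxpr with one parameter, two equations and the
# return variable, something like:
#
#     v_1 ->
#      v_2 = Op.add(v_1, 3.0)
#      v_3 = Op.mul(v_1, v_2)
#     v_3
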
# We can also evaluate our IR by writing an explicit interpreter that traverses
# the operations one by one:

# +
def eval_jaxpr(jaxpr, args):
  # An environment mapping variables to values
  env = dict(zip(jaxpr.parameters, args))
  def eval_atom(x): return env[x] if isinstance(x, Var) else x
  for eqn in jaxpr.equations:
    args = tuple(eval_atom(x) for x in eqn.args)
    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)
  return eval_atom(jaxpr.return_val)

print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
# -

# We've written this interpreter in terms of `current_interpreter.interpret_op`
# which means we've done a full round-trip: interpretable Python program to IR
# to interpretable Python program. Since the result is "interpretable" we can
# differentiate it again, or stage it out or anything we like:

print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))

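# Staging it out again works just as well, producing an equivalent jaxpr:

print(build_jaxpr(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 1))

# And since a jaxpr is just data we can also traverse it bottom-to-top. Here's a
# minimal sketch of the dead-code elimination pass mentioned earlier, purely as
# an illustration (`dce_jaxpr` and `bar` are names made up for this sketch, not
# part of JAX):

# +
def dce_jaxpr(jaxpr):
  # Walk the equations bottom-to-top, keeping only those whose result is used,
  # starting from the variables needed by the return value.
  used = {jaxpr.return_val} if isinstance(jaxpr.return_val, Var) else set()
  live_equations = []
  for eqn in reversed(jaxpr.equations):
    if eqn.var in used:
      live_equations.append(eqn)
      used.update(arg for arg in eqn.args if isinstance(arg, Var))
  return Jaxpr(jaxpr.parameters, list(reversed(live_equations)), jaxpr.return_val)

# `bar` computes a value it never uses, so DCE should drop its `add` equation.
def bar(x):
  unused = add(x, 1.0)
  return mul(x, x)

print(dce_jaxpr(build_jaxpr(bar, 1)))
# -
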
# ## Up next...

# That's it for part one of this tutorial. We've done two primitives, three
# interpreters and the tracing mechanism that weaves them together. In the next
# part we'll add types other than floats, error handling, compilation,
# reverse-mode AD and higher-order primitives. Note that the second part is
# structured differently. Rather than trying to have a top-to-bottom order that
# obeys both code dependencies (e.g. data structures need to be defined before
# they're used) and pedagogical dependencies (concepts need to be introduced
# before they're implemented) we're going with a single file that can be
# approached in any order.