mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 17:56:06 +00:00
778 lines
23 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "M-hPMKlwXjMr"
|
|
},
|
|
"source": [
|
|
"# Writing custom Jaxpr interpreters in JAX"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "r-3vMiKRYXPJ"
|
|
},
|
|
"source": [
|
|
"JAX offers several composable function transformations (`jit`, `grad`, `vmap`,\n",
|
|
"etc.) that enable writing concise, accelerated code. \n",
|
|
"\n",
|
|
"Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "s27RDKvKXFL8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as onp\n",
|
|
"import jax\n",
|
|
"import jax.numpy as np\n",
|
|
"from jax import jit, grad, vmap\n",
|
|
"from jax import random"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "jb_8mEsJboVM"
|
|
},
|
|
"source": [
|
|
"## What is JAX doing?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "KxR2WK0Ubs0R"
|
|
},
|
|
"source": [
"JAX provides a NumPy-like API for numerical computing which can be used as is, but JAX's true power comes from composable function transformations. Take the `jit` function transformation, which takes in a function and returns a semantically identical function that is lazily compiled by XLA for accelerators.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 54
|
|
},
|
|
"colab_type": "code",
|
|
"id": "HmlMcICOcSXR",
|
|
"outputId": "546bf21b-7d03-4364-a087-5802792abbb0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/usr/local/google/home/sharadmv/workspace/jax/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.\n",
|
|
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"x = random.normal(random.PRNGKey(0), (5000, 5000))\n",
|
|
"def f(w, b, x):\n",
"  return np.tanh(np.dot(x, w) + b)\n",
|
|
"fast_f = jit(f)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "gA8V51wZdsjh"
|
|
},
|
|
"source": [
"When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "2Th1vYLVaFBz"
|
|
},
|
|
"source": [
|
|
"## Jaxpr tracer\n",
|
|
"\n",
"A tracer of special importance in JAX is the Jaxpr tracer, which records ops into a Jaxpr (JAX expression). A Jaxpr is a data structure that can be evaluated like a mini functional programming language and \n",
|
|
"thus Jaxprs are a useful intermediate representation\n",
|
|
"for function transformation. \n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "pH7s63lpaHJO"
|
|
},
|
|
"source": [
|
|
"To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a \"pretty-printing\" transformation:\n",
|
|
"it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.\n",
|
|
"Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.\n",
|
|
"Let's use it to look at how some example Jaxprs\n",
|
|
"are structured."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 547
|
|
},
|
|
"colab_type": "code",
|
|
"id": "RSxEiWi-EeYW",
|
|
"outputId": "ee3d4f5c-97d3-4db3-a012-20898920cadb"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"foo\n",
|
|
"=====\n",
|
|
"invars: [a]\n",
|
|
"outvars: [b]\n",
|
|
"constvars: []\n",
|
|
"freevars: []\n",
|
|
"equation: [a, 1] add [b] {}\n",
|
|
"\n",
|
|
"jaxpr: { lambda ; ; a.\n",
|
|
" let b = add a 1\n",
|
|
" in [b] }\n",
|
|
"\n",
|
|
"\n",
|
|
"bar\n",
|
|
"=====\n",
|
|
"invars: [a, b, c]\n",
|
|
"outvars: [g, c]\n",
|
|
"constvars: [f]\n",
|
|
"freevars: []\n",
|
|
"equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None}\n",
|
|
"equation: [d, b] add [e] {}\n",
|
|
"equation: [e, f] add [g] {}\n",
|
|
"\n",
|
|
"jaxpr: { lambda f ; ; a b c.\n",
|
|
" let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
|
|
" precision=None ] a c\n",
|
|
" e = add d b\n",
|
|
" g = add e f\n",
|
|
" in [g, c] }\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"def examine_jaxpr(jaxpr):\n",
"  print(\"invars:\", jaxpr.invars)\n",
"  print(\"outvars:\", jaxpr.outvars)\n",
"  print(\"constvars:\", jaxpr.constvars)\n",
"  print(\"freevars:\", jaxpr.freevars)\n",
"  for eqn in jaxpr.eqns:\n",
"    print(\"equation:\", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)\n",
"  print()\n",
"  print(\"jaxpr:\", jaxpr)\n",
"\n",
"def foo(x):\n",
"  return x + 1\n",
"print(\"foo\")\n",
"print(\"=====\")\n",
"examine_jaxpr(jax.make_jaxpr(foo)(5))\n",
"\n",
"print()\n",
"\n",
"def bar(w, b, x):\n",
"  return np.dot(w, x) + b + np.ones(5), x\n",
"print(\"bar\")\n",
"print(\"=====\")\n",
"examine_jaxpr(jax.make_jaxpr(bar)(np.ones((5, 10)), np.ones(5), np.ones(10)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "k-HxK9iagnH6"
|
|
},
|
|
"source": [
"* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to the Jaxpr, analogous to arguments in Python functions\n",
"* `jaxpr.outvars` - the `outvars` of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr returns a list of outputs, even when there is only one.\n",
"* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later)\n",
"* `jaxpr.freevars` - these can arise when nesting `jit` and `pmap` transformations; we won't worry about them in this colab.\n",
"* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation has a list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has `params`, a dictionary of parameters.\n",
|
|
"\n",
|
|
"All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "NwY7TurYn6sr"
|
|
},
|
|
"source": [
|
|
"### Why are Jaxprs useful?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "UEy6RorCgdYt"
|
|
},
|
|
"source": [
"Jaxprs are simple program representations that are easy to transform. And because JAX lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "qizTKpbno_ua"
|
|
},
|
|
"source": [
|
|
"## Your first interpreter: `invert`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "OIto-KX4pD7j"
|
|
},
|
|
"source": [
|
|
"Let's try to implement a simple function \"inverter\", which takes in the output of the original function and returns the inputs that produced those outputs. For now, let's focus on simple, unary functions which are composed of other invertible unary functions.\n",
|
|
"\n",
|
|
"Goal:\n",
|
|
"```python\n",
|
|
"def f(x):\n",
"  return np.exp(np.tanh(x))\n",
|
|
"f_inv = inverse(f)\n",
|
|
"assert np.allclose(f_inv(f(1.0)), 1.0)\n",
|
|
"```\n",
|
|
"\n",
|
|
"The way we'll implement this is by (1) tracing `f` into a Jaxpr, then (2) interpreting the Jaxpr *backwards*. While interpreting the Jaxpr backwards, for each equation we'll look up the primitive's inverse in a table and apply it.\n",
|
|
"\n",
|
|
"### 1. Tracing a function\n",
|
|
"\n",
|
|
"We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "BHkg_3P1pXJj"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Importing Jax functions useful for tracing/interpreting.\n",
|
|
"import numpy as onp\n",
|
|
"from functools import wraps\n",
|
|
"\n",
|
|
"from jax import api_util\n",
|
|
"from jax import core\n",
|
|
"from jax import lax\n",
|
|
"from jax import linear_util as lu\n",
|
|
"from jax import tree_util\n",
|
|
"from jax.abstract_arrays import ShapedArray\n",
|
|
"from jax.interpreters import partial_eval as pe\n",
|
|
"from jax.util import safe_map"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "aqHqPjuBqrl9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def make_jaxpr2(fun):\n",
"  \n",
"  def pv_like(x):\n",
"    # ShapedArrays are abstract values that carry around\n",
"    # shape and dtype information\n",
"    aval = ShapedArray(onp.shape(x), onp.result_type(x))\n",
"    return pe.PartialVal((aval, core.unit))\n",
"\n",
"  @wraps(fun)\n",
"  def jaxpr_const_maker(*args, **kwargs):\n",
"    # Set up fun for transformation\n",
"    wrapped = lu.wrap_init(fun)\n",
"    # Flatten input args\n",
"    jax_args, in_tree = tree_util.tree_flatten((args, kwargs))\n",
"    # Transform fun to accept flat args\n",
"    # and return a flat list result\n",
"    jaxtree_fun, out_tree = api_util.flatten_fun(wrapped, in_tree)\n",
"    # Abstract and partial-val's flat args\n",
"    pvals = safe_map(pv_like, jax_args)\n",
"    # Trace function into Jaxpr\n",
"    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)\n",
"    return jaxpr, consts, (in_tree, out_tree())\n",
"  return jaxpr_const_maker"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "CpTml2PTrzZ4"
|
|
},
|
|
"source": [
"This function first flattens its arguments into a list, which are then abstracted and wrapped as partial values. The `pe.trace_to_jaxpr` function is then used to trace a function into a Jaxpr\n",
|
|
"from a list of partial value inputs."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 119
|
|
},
|
|
"colab_type": "code",
|
|
"id": "Tc1REN5aq_fH",
|
|
"outputId": "2e6f2833-5139-48ef-da3f-01ea9cd6a7d3"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{ lambda ; ; a.\n",
|
|
" let b = tanh a\n",
|
|
" c = exp b\n",
|
|
" in [c] }\n",
|
|
"\n",
|
|
"()\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x):\n",
"  return np.exp(np.tanh(x))\n",
|
|
"jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))\n",
|
|
"print(jaxpr)\n",
|
|
"print(consts)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "WmZ3BcmZsbfR"
|
|
},
|
|
"source": [
"This particular function doesn't produce any constants, but in general, this is how you both trace into a Jaxpr and extract the constants.\n",
|
|
"\n",
|
|
"### 2. Evaluating a Jaxpr\n",
|
|
"\n",
|
|
"\n",
|
|
"Before we write a custom Jaxpr interpreter, let's first implement the \"default\" interpreter, `eval_jaxpr`, which evaluates the Jaxpr as-is, computing the same values that the original, un-transformed Python function would. \n",
|
|
"\n",
|
|
"To do this, we first create an environment to store the values for each of the variables, and update the environment with each equation we evaluate in the Jaxpr."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "ACMxjIHStHwD"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def eval_jaxpr(jaxpr, consts, *args):\n",
"  # Mapping from variable -> value\n",
"  env = {}\n",
"  \n",
"  def read(var):\n",
"    # Literals are values baked into the Jaxpr\n",
"    if type(var) is core.Literal:\n",
"      return var.val\n",
"    return env[var]\n",
"\n",
"  def write(var, val):\n",
"    env[var] = val\n",
"\n",
"  # Bind args and consts to environment\n",
"  write(core.unitvar, core.unit)\n",
"  safe_map(write, jaxpr.invars, args)\n",
"  safe_map(write, jaxpr.constvars, consts)\n",
"\n",
"  # Loop through equations and evaluate primitives using `bind`\n",
"  for eqn in jaxpr.eqns:\n",
"    # Read inputs to equation from environment\n",
"    invals = safe_map(read, eqn.invars)\n",
"    # `bind` is how a primitive is called\n",
"    outvals = eqn.primitive.bind(*invals, **eqn.params)\n",
"    # Primitives may return multiple outputs or not\n",
"    if not eqn.primitive.multiple_results:\n",
"      outvals = [outvals]\n",
"    # Write the results of the primitive into the environment\n",
"    safe_map(write, eqn.outvars, outvals)\n",
"  # Read the final result of the Jaxpr from the environment\n",
"  return safe_map(read, jaxpr.outvars)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 51
|
|
},
|
|
"colab_type": "code",
|
|
"id": "mGHPc3NruCFV",
|
|
"outputId": "3f459905-02ad-4f43-e70f-3db0a5ff58e4"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[DeviceArray([2.14168763, 2.14168763, 2.14168763, 2.14168763, 2.14168763],\n",
|
|
" dtype=float32)]"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))\n",
|
|
"eval_jaxpr(jaxpr, consts, np.ones(5))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "XhZhzbVBvAiT"
|
|
},
|
|
"source": [
"Notice that `eval_jaxpr` will always return a list even if the original function does not. To \"unflatten\" the list into what the function was originally supposed to return, we can use the `out_tree` object returned by `make_jaxpr2`. \n",
|
|
"\n",
|
|
"Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/master/jax/core.py#L185-L212)) to see the edge cases that this interpreter does not cover."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "0vb2ZoGrCMM4"
|
|
},
|
|
"source": [
|
|
"\n",
|
|
"### Custom `inverse` Jaxpr interpreter\n",
|
|
"\n",
|
|
"An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n",
|
|
"\n",
|
|
"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/master/jax/interpreters/ad.py#L141-L187).\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "gSMIT2z1vUpO"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"inverse_registry = {}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "JgrpMgDyCrC7"
|
|
},
|
|
"source": [
"We'll now register inverses for some of the primitives. By convention, primitives in JAX end in `_p` and a lot of the popular ones live in `lax`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "fUerorGkCqhw"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"inverse_registry[lax.exp_p] = np.log\n",
|
|
"inverse_registry[lax.tanh_p] = np.arctanh"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "mDtH_lYDC5WK"
|
|
},
|
|
"source": [
|
|
"`inverse` will first trace the function, then custom-interpret the Jaxpr. Let's set up a simple skeleton."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "jGNfV6JJC1B3"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def inverse(fun):\n",
"  @wraps(fun)\n",
"  def wrapped(*args, **kwargs):\n",
"    # Since we assume unary functions, we won't\n",
"    # worry about flattening and\n",
"    # unflattening arguments\n",
"    jaxpr, consts, _ = make_jaxpr2(fun)(*args, **kwargs)\n",
"    out = inverse_jaxpr(jaxpr, consts, *args)\n",
"    return out[0]\n",
"  return wrapped"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "g6v6wV7SDM7g"
|
|
},
|
|
"source": [
|
|
"Now we just need to define `inverse_jaxpr`, which will walk through the Jaxpr backward and invert primitives when it can."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "uUAd-L-BDKT5"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def inverse_jaxpr(jaxpr, consts, *args):\n",
"  env = {}\n",
"  \n",
"  def read(var):\n",
"    if type(var) is core.Literal:\n",
"      return var.val\n",
"    return env[var]\n",
"\n",
"  def write(var, val):\n",
"    env[var] = val\n",
"  # Args now correspond to Jaxpr outvars\n",
"  write(core.unitvar, core.unit)\n",
"  safe_map(write, jaxpr.outvars, args)\n",
"  safe_map(write, jaxpr.constvars, consts)\n",
"\n",
"  # Looping backward\n",
"  for eqn in jaxpr.eqns[::-1]:\n",
"    # outvars are now invars\n",
"    invals = safe_map(read, eqn.outvars)\n",
"    if eqn.primitive not in inverse_registry:\n",
"      raise NotImplementedError(\"{} does not have registered inverse.\".format(\n",
"          eqn.primitive\n",
"      ))\n",
"    # Assuming a unary function\n",
"    outval = inverse_registry[eqn.primitive](*invals)\n",
"    safe_map(write, eqn.invars, [outval])\n",
"  return safe_map(read, jaxpr.invars)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "M8i3wGbVERhA"
|
|
},
|
|
"source": [
|
|
"That's it!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 71
|
|
},
|
|
"colab_type": "code",
|
|
"id": "cjEKWso-D5Bu",
|
|
"outputId": "54e73a42-a908-448c-d6d5-e6d0fe9adbf9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/usr/local/google/home/sharadmv/workspace/jax/jax/numpy/lax_numpy.py:220: RuntimeWarning: divide by zero encountered in arctanh\n",
|
|
" return _dtype(op(*args))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x):\n",
"  return np.exp(np.tanh(x))\n",
|
|
"f_inv = inverse(f)\n",
|
|
"assert np.allclose(f_inv(f(1.0)), 1.0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Ny7Oo4WLHdXt"
|
|
},
|
|
"source": [
|
|
"Importantly, you can trace through a Jaxpr interpreter."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 238
|
|
},
|
|
"colab_type": "code",
|
|
"id": "j6ov_rveHmTb",
|
|
"outputId": "a7a9b4be-f284-4e78-c469-d84e352d838d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{ lambda ; ; a.\n",
|
|
" let b = log a\n",
|
|
" c = abs b\n",
|
|
" d = le c 1.0\n",
|
|
" e = add b 1.0\n",
|
|
" f = sub 1.0 b\n",
|
|
" g = div e f\n",
|
|
" h = log g\n",
|
|
" i = mul h 0.5\n",
|
|
" j = tie_in b nan\n",
|
|
" k = broadcast[ sizes=() ] j\n",
|
|
" l = select d i k\n",
|
|
" in [l] }"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.make_jaxpr(inverse(f))(f(1.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "yfWVBsKwH0j6"
|
|
},
|
|
"source": [
"That's all it takes to add a new transformation to the system, and you get composition with all the others for free! For example, we can use `jit`, `vmap`, and `grad` with `inverse`!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {
|
|
"colab": {
|
|
"height": 51
|
|
},
|
|
"colab_type": "code",
|
|
"id": "3tjNk21CH4yZ",
|
|
"outputId": "4c6d2090-0558-4171-8dce-0ea681ef6e53"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray([ 0. , 15.58493137, 2.25512528, 1.31550276,\n",
|
|
" 1. ], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jit(vmap(grad(inverse(f))))((np.arange(5) + 1.) / 5.)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "APtG-u_6E4tK"
|
|
},
|
|
"source": [
|
|
"## Exercises for the reader\n",
|
|
"\n",
|
|
"* Handle primitives with multiple arguments where inputs are partially known, for example `lax.add_p`, `lax.mul_p`.\n",
"* Handle `xla_call` and `xla_pmap` primitives, which will not work with either `eval_jaxpr` or `inverse_jaxpr` as written."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"name": "Writing custom interpreters in Jax",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|