Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)

Commit 68be5b5085 (parent afff0e09aa)

CI: update ruff to v0.6.1
@@ -26,7 +26,7 @@ repos:
files: \.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.6.1
hooks:
- id: ruff
@@ -510,8 +510,8 @@
"outputs": [],
"source": [
"image_partitions = P(1, 1, 4, 2)\n",
"sharded_conv = sharded_jit(conv, \n",
" in_parts=(image_partitions, None), \n",
"sharded_conv = sharded_jit(conv,\n",
" in_parts=(image_partitions, None),\n",
" out_parts=image_partitions)\n",
"\n",
"sharded_conv(image, kernel)"
@@ -877,7 +877,7 @@
" def g(z):\n",
" return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
" \n",
"\n",
"f(x)"
]
},
@@ -950,17 +950,6 @@
"per_example_hess = pmap(input_hess) # pmap!\n",
"per_example_hess(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "u3ggM_WYZ8QC"
},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -67,7 +67,6 @@
"source": [
"from functools import partial\n",
"import jax\n",
"from jax import jit, pmap\n",
"from jax import lax\n",
"from jax import tree_util\n",
"import jax.numpy as jnp\n",
@@ -640,7 +640,7 @@ def our_jacrev(f):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
@@ -654,7 +654,7 @@ from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
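The only change in the `our_jacfwd` hunk is the missing space after `=`. For reference, a minimal runnable sketch of the cell as it reads after the fix; the imports and the closing check are assumptions, not part of the diff:

```python
# Sketch of the corrected our_jacfwd cell; imports and the test are assumed.
import jax.numpy as jnp
from jax import jvp, vmap
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        # Push each standard basis vector through a JVP to get one Jacobian column.
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

f = lambda x: jnp.sin(x) * x      # hypothetical test function
x = jnp.arange(3.0)
assert jnp.allclose(builtin_jacfwd(f)(x), our_jacfwd(f)(x))
```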
@@ -3317,7 +3317,6 @@
],
"source": [
"# @title\n",
"from jax import dtypes\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",
@@ -3802,7 +3801,6 @@
],
"source": [
"# @title\n",
"from jax import dtypes\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",
@@ -908,7 +908,6 @@ display.HTML(table.to_html())
:tags: [hide-input]
# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
@@ -963,7 +962,6 @@ display.HTML(table.to_html())
:tags: [hide-input]
# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
@@ -226,7 +226,6 @@
],
"source": [
"import jax.numpy as jnp\n",
"import jax.lax as lax\n",
"from jax import make_jaxpr\n",
"\n",
"# lax.fori_loop\n",
@@ -1031,7 +1030,6 @@
}
],
"source": [
"from jax import random\n",
"key = random.key(0)\n",
"key"
]
@@ -1105,8 +1103,8 @@
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(\" \\---SPLIT --> new key \", key)\n",
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
@@ -1140,8 +1138,8 @@
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(\" \\---SPLIT --> new key \", key)\n",
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
@@ -1701,7 +1699,7 @@
],
"source": [
"init_val = 0\n",
"cond_fun = lambda x: x<10\n",
"cond_fun = lambda x: x < 10\n",
"body_fun = lambda x: x+1\n",
"lax.while_loop(cond_fun, body_fun, init_val)\n",
"# --> array(10, dtype=int32)"
@@ -130,7 +130,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or
:outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr
# lax.fori_loop
@@ -471,7 +470,6 @@ The random state is described by a special array element that we call a __key__:
:id: yPHE7KTWgAWs
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
from jax import random
key = random.key(0)
key
```
@@ -504,8 +502,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
+++ {"id": "tqtFVE4MthO3"}
@@ -519,8 +517,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
+++ {"id": "0KLYUluz3lN3"}
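These hunks only switch the prints to raw strings so the `\-` sequences are not treated as string escapes; the key-splitting pattern itself is unchanged. A self-contained sketch of the cell after the change (every line here appears in the hunks above):

```python
from jax import random

key = random.key(0)
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
# Raw strings keep the backslashes literal instead of forming a "\-" escape.
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```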
@@ -805,7 +803,7 @@ def while_loop(cond_fun, body_fun, init_val):
:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
init_val = 0
cond_fun = lambda x: x<10
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
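Here only the spacing in `cond_fun` changes. The loop as a small runnable snippet; the `lax` import is an assumption (the notebook imports it in an earlier cell):

```python
from jax import lax

init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x + 1
print(lax.while_loop(cond_fun, body_fun, init_val))  # --> 10
```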
@ -247,7 +247,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax.numpy as jnp\n",
|
||||
"\n",
|
||||
"def log1pexp(x):\n",
|
||||
" return jnp.log(1. + jnp.exp(x))\n",
|
||||
@ -984,7 +983,7 @@
|
||||
" (a, x_star, x_star_bar),\n",
|
||||
" x_star_bar))\n",
|
||||
" return a_bar, jnp.zeros_like(x_star)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"def rev_iter(f, packed, u):\n",
|
||||
" a, x_star, x_star_bar = packed\n",
|
||||
" _, vjp_x = vjp(lambda x: f(a, x), x_star)\n",
|
||||
@ -1884,7 +1883,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from jax import vjp\n",
|
||||
"\n",
|
||||
"y, f_vjp = vjp(f, 3.)\n",
|
||||
"print(y)"
|
||||
@ -1983,7 +1981,7 @@
|
||||
" return x, x\n",
|
||||
"\n",
|
||||
"def debug_bwd(x, g):\n",
|
||||
" import pdb; pdb.set_trace()\n",
|
||||
" pdb.set_trace()\n",
|
||||
" return g\n",
|
||||
"\n",
|
||||
"debug.defvjp(debug_fwd, debug_bwd)"
|
||||
|
@ -145,7 +145,6 @@ Say we want to write a function called `log1pexp`, which computes $x \mapsto \lo
|
||||
:id: 6lWbTvs40ET-
|
||||
:outputId: 8caff99e-add1-4c70-ace3-212c0c5c6f4e
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
def log1pexp(x):
|
||||
return jnp.log(1. + jnp.exp(x))
|
||||
@ -524,7 +523,7 @@ def fixed_point_rev(f, res, x_star_bar):
|
||||
(a, x_star, x_star_bar),
|
||||
x_star_bar))
|
||||
return a_bar, jnp.zeros_like(x_star)
|
||||
|
||||
|
||||
def rev_iter(f, packed, u):
|
||||
a, x_star, x_star_bar = packed
|
||||
_, vjp_x = vjp(lambda x: f(a, x), x_star)
|
||||
@ -965,7 +964,6 @@ print(grad(f)(3.))
|
||||
:id: s1Pn_qCIODcF
|
||||
:outputId: 423d34e0-35b8-4b57-e89d-f70f20e28ea9
|
||||
|
||||
from jax import vjp
|
||||
|
||||
y, f_vjp = vjp(f, 3.)
|
||||
print(y)
|
||||
@ -1015,7 +1013,7 @@ def debug_fwd(x):
|
||||
return x, x
|
||||
|
||||
def debug_bwd(x, g):
|
||||
import pdb; pdb.set_trace()
|
||||
pdb.set_trace()
|
||||
return g
|
||||
|
||||
debug.defvjp(debug_fwd, debug_bwd)
|
||||
|
@ -30,9 +30,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import functools\n",
|
||||
"from typing import Optional\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
|
@ -26,9 +26,7 @@ This tutorial discusses parallelism via `jax.Array`, the unified array object mo
|
||||
```{code-cell}
|
||||
:id: FNxScTfq3vGF
|
||||
|
||||
import os
|
||||
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -15,12 +15,12 @@
|
||||
"*necula@google.com*, October 2019.\n",
|
||||
"\n",
|
||||
"JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n",
|
||||
"`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, \n",
|
||||
"`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,\n",
|
||||
"which means that as the Python function executes\n",
|
||||
"the only operations it applies to the data are either inspections of data\n",
|
||||
"attributes such as shape or type, or special operations called JAX primitives.\n",
|
||||
"In particular, a JAX-traceable function is sometimes invoked by JAX with\n",
|
||||
"abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, \n",
|
||||
"abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,\n",
|
||||
"which captures the type and the shape of values, but not the concrete data values.\n",
|
||||
"JAX primitives know how to operate on both concrete data\n",
|
||||
"values and on the JAX abstract values.\n",
|
||||
@ -30,7 +30,7 @@
|
||||
"to ensure that these transformations\n",
|
||||
"can be composed, e.g., `jit(jacfwd(grad(f)))`.\n",
|
||||
"\n",
|
||||
"There are pre-defined JAX primitives corresponding to most XLA operations, \n",
|
||||
"There are pre-defined JAX primitives corresponding to most XLA operations,\n",
|
||||
"e.g., add, matmul, sin, cos, indexing.\n",
|
||||
"JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n",
|
||||
"using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n",
|
||||
@ -42,8 +42,8 @@
|
||||
"**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n",
|
||||
"\n",
|
||||
"Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n",
|
||||
"as \"multiply_add(x, y, z) = x * y + z\". \n",
|
||||
"This function operates on 3 identically-shaped tensors of floating point \n",
|
||||
"as \"multiply_add(x, y, z) = x * y + z\".\n",
|
||||
"This function operates on 3 identically-shaped tensors of floating point\n",
|
||||
"values and performs the operations pointwise."
|
||||
]
|
||||
},
|
||||
@ -56,7 +56,7 @@
|
||||
"## Using existing primitives\n",
|
||||
"\n",
|
||||
"The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n",
|
||||
"functions that are themselves written using JAX primitives, e.g., those \n",
|
||||
"functions that are themselves written using JAX primitives, e.g., those\n",
|
||||
"defined in the `jax.lax` module:"
|
||||
]
|
||||
},
|
||||
@ -165,7 +165,7 @@
|
||||
" return str(v)\n",
|
||||
" def pp_values(args):\n",
|
||||
" return \", \".join([pp(arg) for arg in args])\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" @functools.wraps(func)\n",
|
||||
" def func_wrapper(*args):\n",
|
||||
" _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n",
|
||||
@ -199,7 +199,7 @@
|
||||
"id": "Qf4eLrLCFYDl"
|
||||
},
|
||||
"source": [
|
||||
"Instead of using `jax.lax` primitives directly, we can use other functions \n",
|
||||
"Instead of using `jax.lax` primitives directly, we can use other functions\n",
|
||||
"that are already written in terms of those primitives, such as those in `jax.numpy`:"
|
||||
]
|
||||
},
|
||||
@ -244,7 +244,7 @@
|
||||
"def square_add_numpy(a, b):\n",
|
||||
" return multiply_add_numpy(a, a, b)\n",
|
||||
"\n",
|
||||
"print(\"\\nNormal evaluation:\") \n",
|
||||
"print(\"\\nNormal evaluation:\")\n",
|
||||
"print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n",
|
||||
"print(\"\\nGradient evaluation:\")\n",
|
||||
"print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))"
|
||||
@ -257,13 +257,13 @@
|
||||
},
|
||||
"source": [
|
||||
"Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n",
|
||||
"`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further \n",
|
||||
"below in this colab). \n",
|
||||
"It is important to remember that a JAX-traceable function must be able to \n",
|
||||
"`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further\n",
|
||||
"below in this colab).\n",
|
||||
"It is important to remember that a JAX-traceable function must be able to\n",
|
||||
"operate not only on concrete arguments but also on special abstract arguments\n",
|
||||
"that JAX may use to abstract the function execution.\n",
|
||||
"\n",
|
||||
"The JAX traceability property is satisfied as long as the function is written \n",
|
||||
"The JAX traceability property is satisfied as long as the function is written\n",
|
||||
"in terms of JAX primitives."
|
||||
]
|
||||
},
|
||||
@ -277,7 +277,7 @@
|
||||
"\n",
|
||||
"The right way to add support for multiply-add is in terms of existing\n",
|
||||
"JAX primitives, as shown above. However, in order to demonstrate how JAX\n",
|
||||
"primitives work let us pretend that we want to add a new primitive to \n",
|
||||
"primitives work let us pretend that we want to add a new primitive to\n",
|
||||
"JAX for the multiply-add functionality."
|
||||
]
|
||||
},
|
||||
@@ -295,9 +295,9 @@
"@trace(\"multiply_add_prim\")\n",
"def multiply_add_prim(x, y, z):\n",
" \"\"\"The JAX-traceable way to use the JAX primitive.\n",
" \n",
"\n",
" Note that the traced arguments must be passed as positional arguments\n",
" to `bind`. \n",
" to `bind`.\n",
" \"\"\"\n",
" return multiply_add_p.bind(x, y, z)\n",
"\n",
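This hunk touches only docstring whitespace in `multiply_add_prim`. For context on how the wrapper is used, a minimal sketch of a primitive defined this way; the `core.Primitive`/`def_impl` setup and the NumPy implementation body are assumptions drawn from the surrounding notebook cells (they are not part of this hunk), and the notebook's `@trace` decorator is omitted:

```python
# Minimal sketch, assuming the notebook's core.Primitive setup; not part of the diff.
import numpy as np
from jax import core

multiply_add_p = core.Primitive("multiply_add")  # create the primitive

def multiply_add_impl(x, y, z):
    # Concrete implementation; only ever called with concrete values (assumed body).
    return np.add(np.multiply(x, y), z)

multiply_add_p.def_impl(multiply_add_impl)

def multiply_add_prim(x, y, z):
    """The JAX-traceable way to use the JAX primitive.

    Note that the traced arguments must be passed as positional arguments
    to `bind`.
    """
    return multiply_add_p.bind(x, y, z)

print(multiply_add_prim(2., 3., 10.))  # 16.0
```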
@ -392,7 +392,7 @@
|
||||
"\n",
|
||||
" This function does not need to be JAX traceable.\n",
|
||||
" Args:\n",
|
||||
" x, y, z: the concrete arguments of the primitive. Will only be called with \n",
|
||||
" x, y, z: the concrete arguments of the primitive. Will only be called with\n",
|
||||
" concrete values.\n",
|
||||
" Returns:\n",
|
||||
" the concrete result of the primitive.\n",
|
||||
@ -485,17 +485,17 @@
|
||||
},
|
||||
"source": [
|
||||
"#### Abstract evaluation rules\n",
|
||||
"In order to JIT the function, and for other transformations as well, \n",
|
||||
"JAX first evaluates it abstractly using only the \n",
|
||||
"In order to JIT the function, and for other transformations as well,\n",
|
||||
"JAX first evaluates it abstractly using only the\n",
|
||||
"shape and type of the arguments. This abstract evaluation serves multiple\n",
|
||||
"purposes:\n",
|
||||
"\n",
|
||||
" * Gets the sequence of JAX primitives that are used in the computation. This \n",
|
||||
" sequence will be compiled. \n",
|
||||
" * Computes the shape and type of all vectors and operations used in the computation. \n",
|
||||
" * Gets the sequence of JAX primitives that are used in the computation. This\n",
|
||||
" sequence will be compiled.\n",
|
||||
" * Computes the shape and type of all vectors and operations used in the computation.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. \n",
|
||||
"For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.\n",
|
||||
"In the latter case, JAX uses the actual concrete value wrapped as an abstract value."
|
||||
]
|
||||
},
|
||||
@ -527,7 +527,7 @@
|
||||
" \"\"\"Abstract evaluation of the primitive.\n",
|
||||
"\n",
|
||||
" This function does not need to be JAX traceable. It will be invoked with\n",
|
||||
" abstractions of the actual arguments. \n",
|
||||
" abstractions of the actual arguments.\n",
|
||||
" Args:\n",
|
||||
" xs, ys, zs: abstractions of the arguments.\n",
|
||||
" Result:\n",
|
||||
@ -603,7 +603,7 @@
|
||||
"\n",
|
||||
"JAX compilation works by compiling each primitive into a graph of XLA operations.\n",
|
||||
"\n",
|
||||
"This is the biggest hurdle to adding new functionality to JAX, because the \n",
|
||||
"This is the biggest hurdle to adding new functionality to JAX, because the\n",
|
||||
"set of XLA operations is limited, and JAX already has pre-defined primitives\n",
|
||||
"for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++."
|
||||
]
|
||||
@ -642,7 +642,7 @@
|
||||
},
|
||||
"source": [
|
||||
"Now we succeed to JIT. Notice below that JAX first evaluates the function\n",
|
||||
"abstractly, which triggers the `multiply_add_abstract_eval` function, and \n",
|
||||
"abstractly, which triggers the `multiply_add_abstract_eval` function, and\n",
|
||||
"then compiles the set of primitives it has encountered, including `multiply_add`.\n",
|
||||
"At this point JAX invokes `multiply_add_xla_translation`."
|
||||
]
|
||||
@ -682,7 +682,7 @@
|
||||
"source": [
|
||||
"Below is another use of `jit` where we compile only\n",
|
||||
"with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n",
|
||||
"in the third argument to `multiply_add_abstract_eval` being \n",
|
||||
"in the third argument to `multiply_add_abstract_eval` being\n",
|
||||
"`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n",
|
||||
"both `ShapedArray` and `ConcreteArray`."
|
||||
]
|
||||
@ -711,7 +711,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"assert api.jit(lambda x, y: square_add_prim(x, y), \n",
|
||||
"assert api.jit(lambda x, y: square_add_prim(x, y),\n",
|
||||
" static_argnums=1)(2., 10.) == 14."
|
||||
]
|
||||
},
|
||||
@ -794,16 +794,16 @@
|
||||
"def multiply_add_value_and_jvp(arg_values, arg_tangents):\n",
|
||||
" \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n",
|
||||
"\n",
|
||||
" Given values of the arguments and perturbation of the arguments (tangents), \n",
|
||||
" Given values of the arguments and perturbation of the arguments (tangents),\n",
|
||||
" compute the output of the primitive and the perturbation of the output.\n",
|
||||
"\n",
|
||||
" This method must be JAX-traceable. JAX may invoke it with abstract values \n",
|
||||
" This method must be JAX-traceable. JAX may invoke it with abstract values\n",
|
||||
" for the arguments and tangents.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" arg_values: a tuple of arguments\n",
|
||||
" arg_tangents: a tuple with the tangents of the arguments. The tuple has \n",
|
||||
" the same length as the arg_values. Some of the tangents may also be the \n",
|
||||
" arg_tangents: a tuple with the tangents of the arguments. The tuple has\n",
|
||||
" the same length as the arg_values. Some of the tangents may also be the\n",
|
||||
" special value ad.Zero to specify a zero tangent.\n",
|
||||
" Returns:\n",
|
||||
" a pair of the primal output and the tangent.\n",
|
||||
@ -811,26 +811,26 @@
|
||||
" x, y, z = arg_values\n",
|
||||
" xt, yt, zt = arg_tangents\n",
|
||||
" _trace(\"Primal evaluation:\")\n",
|
||||
" # Now we have a JAX-traceable computation of the output. \n",
|
||||
" # Normally, we can use the ma primitive itself to compute the primal output. \n",
|
||||
" # Now we have a JAX-traceable computation of the output.\n",
|
||||
" # Normally, we can use the ma primitive itself to compute the primal output.\n",
|
||||
" primal_out = multiply_add_prim(x, y, z)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" _trace(\"Tangent evaluation:\")\n",
|
||||
" # We must use a JAX-traceable way to compute the tangent. It turns out that \n",
|
||||
" # We must use a JAX-traceable way to compute the tangent. It turns out that\n",
|
||||
" # the output tangent can be computed as (xt * y + x * yt + zt),\n",
|
||||
" # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n",
|
||||
" \n",
|
||||
" # We do need to deal specially with Zero. Here we just turn it into a \n",
|
||||
" # proper tensor of 0s (of the same shape as 'x'). \n",
|
||||
" # An alternative would be to check for Zero and perform algebraic \n",
|
||||
"\n",
|
||||
" # We do need to deal specially with Zero. Here we just turn it into a\n",
|
||||
" # proper tensor of 0s (of the same shape as 'x').\n",
|
||||
" # An alternative would be to check for Zero and perform algebraic\n",
|
||||
" # simplification of the output tangent computation.\n",
|
||||
" def make_zero(tan):\n",
|
||||
" return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan \n",
|
||||
" \n",
|
||||
" return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan\n",
|
||||
"\n",
|
||||
" output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n",
|
||||
" return (primal_out, output_tangent)\n",
|
||||
"\n",
|
||||
"# Register the forward differentiation rule with JAX \n",
|
||||
"# Register the forward differentiation rule with JAX\n",
|
||||
"ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp"
|
||||
]
|
||||
},
|
||||
@ -880,7 +880,7 @@
|
||||
"id": "69QsEcu-lP4u"
|
||||
},
|
||||
"source": [
|
||||
"TO EXPLAIN: \n",
|
||||
"TO EXPLAIN:\n",
|
||||
"\n",
|
||||
" * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n",
|
||||
" * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n",
|
||||
@ -941,7 +941,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"assert api.jit(lambda arg_values, arg_tangents: \n",
|
||||
"assert api.jit(lambda arg_values, arg_tangents:\n",
|
||||
" api.jvp(square_add_prim, arg_values, arg_tangents))(\n",
|
||||
" (2., 10.), (1., 1.)) == (14., 5.)"
|
||||
]
|
||||
@ -953,7 +953,7 @@
|
||||
},
|
||||
"source": [
|
||||
"Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n",
|
||||
"evaluates abstractly both the primal and the tangent evaluation (a total of \n",
|
||||
"evaluates abstractly both the primal and the tangent evaluation (a total of\n",
|
||||
"3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n",
|
||||
"of the primitive."
|
||||
]
|
||||
@ -967,21 +967,21 @@
|
||||
"### Reverse differentiation\n",
|
||||
"\n",
|
||||
"If we attempt now to use reverse differentiation we\n",
|
||||
"see that JAX starts by using the `multiply_add_value_and_jvp` to \n",
|
||||
"see that JAX starts by using the `multiply_add_value_and_jvp` to\n",
|
||||
"compute the forward differentiation for abstract values, but then runs\n",
|
||||
"into a `NotImplementedError`. \n",
|
||||
"into a `NotImplementedError`.\n",
|
||||
"\n",
|
||||
"When computing the reverse differentiation JAX first does abstract evaluation\n",
|
||||
"of the forward differentiation code `multiply_add_value_and_jvp` to obtain a \n",
|
||||
"trace of primitives that compute the output tangent. \n",
|
||||
"of the forward differentiation code `multiply_add_value_and_jvp` to obtain a\n",
|
||||
"trace of primitives that compute the output tangent.\n",
|
||||
"Observe that JAX performs this abstract evaluation with concrete values\n",
|
||||
"for the differentiation point, and abstract values for the tangents. \n",
|
||||
"for the differentiation point, and abstract values for the tangents.\n",
|
||||
"Observe also that JAX uses the special abstract tangent value `Zero` for\n",
|
||||
"the tangent corresponding to the 3rd argument of `ma`. This reflects the \n",
|
||||
"the tangent corresponding to the 3rd argument of `ma`. This reflects the\n",
|
||||
"fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n",
|
||||
"which flows to the 3rd argument to `multiply_add_prim`.\n",
|
||||
"\n",
|
||||
"Observe also that during the abstract evaluation of the tangent we pass the \n",
|
||||
"Observe also that during the abstract evaluation of the tangent we pass the\n",
|
||||
"value 0.0 as the tangent for the 3rd argument. This is due to the use\n",
|
||||
"of the `make_zero` function in the definition of `multiply_add_value_and_jvp`."
|
||||
]
|
||||
@ -1071,7 +1071,7 @@
|
||||
"\n",
|
||||
"As explained above, when computing reverse differentiation JAX obtains\n",
|
||||
"a trace of primitives that compute the tangent using forward differentiation.\n",
|
||||
"Then, **JAX interprets this trace abstractly backwards** and for each \n",
|
||||
"Then, **JAX interprets this trace abstractly backwards** and for each\n",
|
||||
"primitive it applies a **transposition** rule.\n",
|
||||
"\n",
|
||||
"To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n",
|
||||
@ -1082,7 +1082,7 @@
|
||||
" ft = c + yt\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"By construction, the tangent calculation is always linear in the input tangents. \n",
|
||||
"By construction, the tangent calculation is always linear in the input tangents.\n",
|
||||
"The only non-linear operator that may arise in the tangent calculation is multiplication,\n",
|
||||
"but then one of the operands is constant.\n",
|
||||
"\n",
|
||||
@ -1108,8 +1108,8 @@
|
||||
" xct += act * 4.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"One can verify that this computation produces `xct = 4.` and `yct = 3.`, which \n",
|
||||
"are the partial derivatives of the function `f`. \n",
|
||||
"One can verify that this computation produces `xct = 4.` and `yct = 3.`, which\n",
|
||||
"are the partial derivatives of the function `f`.\n",
|
||||
"\n",
|
||||
"JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n",
|
||||
"```\n",
|
||||
@ -1117,10 +1117,10 @@
|
||||
"```\n",
|
||||
"\n",
|
||||
"Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n",
|
||||
"arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned \n",
|
||||
"arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned\n",
|
||||
"for the constant arguments.\n",
|
||||
"\n",
|
||||
"In particular, \n",
|
||||
"In particular,\n",
|
||||
"```\n",
|
||||
" add_transpose(out_ct, _, _) = (out_ct, out_ct)\n",
|
||||
" mult_transpose(out_ct, x, _) = (None, x * out_ct)\n",
|
||||
@ -1140,16 +1140,16 @@
|
||||
"def multiply_add_transpose(ct, x, y, z):\n",
|
||||
" \"\"\"Evaluates the transpose of a linear primitive.\n",
|
||||
"\n",
|
||||
" This method is only used when computing the backward gradient following \n",
|
||||
" value_and_jvp, and is only needed for primitives that are used in the JVP \n",
|
||||
" calculation for some other primitive. We need transposition for multiply_add_prim, \n",
|
||||
" because we have used multiply_add_prim in the computation of the output_tangent in \n",
|
||||
" This method is only used when computing the backward gradient following\n",
|
||||
" value_and_jvp, and is only needed for primitives that are used in the JVP\n",
|
||||
" calculation for some other primitive. We need transposition for multiply_add_prim,\n",
|
||||
" because we have used multiply_add_prim in the computation of the output_tangent in\n",
|
||||
" multiply_add_value_and_jvp.\n",
|
||||
"\n",
|
||||
" In our case, multiply_add is not a linear primitive. However, it is used linearly \n",
|
||||
" In our case, multiply_add is not a linear primitive. However, it is used linearly\n",
|
||||
" w.r.t. tangents in multiply_add_value_and_jvp:\n",
|
||||
" output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Always one of the first two multiplicative arguments is a constant.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
@ -1244,7 +1244,7 @@
|
||||
},
|
||||
"source": [
|
||||
"Notice the two calls to `multiply_add_transpose`. They correspond to the two\n",
|
||||
"uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the \n",
|
||||
"uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the\n",
|
||||
"last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0."
|
||||
]
|
||||
},
|
||||
@ -1254,7 +1254,7 @@
|
||||
"id": "EIJs6FYmPg6c"
|
||||
},
|
||||
"source": [
|
||||
"#### JIT of reverse differentiation \n",
|
||||
"#### JIT of reverse differentiation\n",
|
||||
"\n",
|
||||
"Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n",
|
||||
"abstract values, while in the absence of JIT we used `ConcreteArray`."
|
||||
@ -1397,20 +1397,20 @@
|
||||
"@trace(\"multiply_add_batch\")\n",
|
||||
"def multiply_add_batch(vector_arg_values, batch_axes):\n",
|
||||
" \"\"\"Computes the batched version of the primitive.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" This must be a JAX-traceable function.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Since the multiply_add primitive already operates pointwise on arbitrary\n",
|
||||
" dimension tensors, to batch it we can use the primitive itself. This works as\n",
|
||||
" long as both the inputs have the same dimensions and are batched along the\n",
|
||||
" same axes. The result is batched along the axis that the inputs are batched.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" vector_arg_values: a tuple of two arguments, each being a tensor of matching\n",
|
||||
" shape.\n",
|
||||
" batch_axes: the axes that are being batched. See vmap documentation.\n",
|
||||
" Returns:\n",
|
||||
" a tuple of the result, and the result axis that was batched. \n",
|
||||
" a tuple of the result, and the result axis that was batched.\n",
|
||||
" \"\"\"\n",
|
||||
" assert batch_axes[0] == batch_axes[1]\n",
|
||||
" assert batch_axes[0] == batch_axes[2]\n",
|
||||
|
@ -22,12 +22,12 @@ kernelspec:
|
||||
*necula@google.com*, October 2019.
|
||||
|
||||
JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,
|
||||
`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,
|
||||
`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,
|
||||
which means that as the Python function executes
|
||||
the only operations it applies to the data are either inspections of data
|
||||
attributes such as shape or type, or special operations called JAX primitives.
|
||||
In particular, a JAX-traceable function is sometimes invoked by JAX with
|
||||
abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,
|
||||
abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,
|
||||
which captures the type and the shape of values, but not the concrete data values.
|
||||
JAX primitives know how to operate on both concrete data
|
||||
values and on the JAX abstract values.
|
||||
@ -37,7 +37,7 @@ The JAX-transformed functions must themselves be JAX-traceable functions,
|
||||
to ensure that these transformations
|
||||
can be composed, e.g., `jit(jacfwd(grad(f)))`.
|
||||
|
||||
There are pre-defined JAX primitives corresponding to most XLA operations,
|
||||
There are pre-defined JAX primitives corresponding to most XLA operations,
|
||||
e.g., add, matmul, sin, cos, indexing.
|
||||
JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs
|
||||
using JAX’s implementation of numpy are JAX-traceable and therefore transformable.
|
||||
@ -49,8 +49,8 @@ one can define a new primitive that encapsulates the behavior of the function.
|
||||
**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**
|
||||
|
||||
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically
|
||||
as "multiply_add(x, y, z) = x * y + z".
|
||||
This function operates on 3 identically-shaped tensors of floating point
|
||||
as "multiply_add(x, y, z) = x * y + z".
|
||||
This function operates on 3 identically-shaped tensors of floating point
|
||||
values and performs the operations pointwise.
|
||||
|
||||
+++ {"id": "HIJYIHNTD1yI"}
|
||||
@ -58,7 +58,7 @@ values and performs the operations pointwise.
|
||||
## Using existing primitives
|
||||
|
||||
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other
|
||||
functions that are themselves written using JAX primitives, e.g., those
|
||||
functions that are themselves written using JAX primitives, e.g., those
|
||||
defined in the `jax.lax` module:
|
||||
|
||||
```{code-cell} ipython3
|
||||
@ -134,7 +134,7 @@ def trace(name):
|
||||
return str(v)
|
||||
def pp_values(args):
|
||||
return ", ".join([pp(arg) for arg in args])
|
||||
|
||||
|
||||
@functools.wraps(func)
|
||||
def func_wrapper(*args):
|
||||
_trace_indent("call {}({})".format(name, pp_values(args)))
|
||||
@ -164,7 +164,7 @@ class expectNotImplementedError(object):
|
||||
|
||||
+++ {"id": "Qf4eLrLCFYDl"}
|
||||
|
||||
Instead of using `jax.lax` primitives directly, we can use other functions
|
||||
Instead of using `jax.lax` primitives directly, we can use other functions
|
||||
that are already written in terms of those primitives, such as those in `jax.numpy`:
|
||||
|
||||
```{code-cell} ipython3
|
||||
@ -182,7 +182,7 @@ def multiply_add_numpy(x, y, z):
|
||||
def square_add_numpy(a, b):
|
||||
return multiply_add_numpy(a, a, b)
|
||||
|
||||
print("\nNormal evaluation:")
|
||||
print("\nNormal evaluation:")
|
||||
print("square_add_numpy = ", square_add_numpy(2., 10.))
|
||||
print("\nGradient evaluation:")
|
||||
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
|
||||
@ -191,13 +191,13 @@ print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
|
||||
+++ {"id": "Sg-D8EdeFn4a"}
|
||||
|
||||
Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and
|
||||
`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further
|
||||
below in this colab).
|
||||
It is important to remember that a JAX-traceable function must be able to
|
||||
`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further
|
||||
below in this colab).
|
||||
It is important to remember that a JAX-traceable function must be able to
|
||||
operate not only on concrete arguments but also on special abstract arguments
|
||||
that JAX may use to abstract the function execution.
|
||||
|
||||
The JAX traceability property is satisfied as long as the function is written
|
||||
The JAX traceability property is satisfied as long as the function is written
|
||||
in terms of JAX primitives.
|
||||
|
||||
+++ {"id": "WxrQO7-XGLcg"}
|
||||
@ -206,7 +206,7 @@ in terms of JAX primitives.
|
||||
|
||||
The right way to add support for multiply-add is in terms of existing
|
||||
JAX primitives, as shown above. However, in order to demonstrate how JAX
|
||||
primitives work let us pretend that we want to add a new primitive to
|
||||
primitives work let us pretend that we want to add a new primitive to
|
||||
JAX for the multiply-add functionality.
|
||||
|
||||
```{code-cell} ipython3
|
||||
@ -218,9 +218,9 @@ multiply_add_p = core.Primitive("multiply_add") # Create the primitive
|
||||
@trace("multiply_add_prim")
|
||||
def multiply_add_prim(x, y, z):
|
||||
"""The JAX-traceable way to use the JAX primitive.
|
||||
|
||||
|
||||
Note that the traced arguments must be passed as positional arguments
|
||||
to `bind`.
|
||||
to `bind`.
|
||||
"""
|
||||
return multiply_add_p.bind(x, y, z)
|
||||
|
||||
@ -257,7 +257,7 @@ def multiply_add_impl(x, y, z):
|
||||
|
||||
This function does not need to be JAX traceable.
|
||||
Args:
|
||||
x, y, z: the concrete arguments of the primitive. Will only be called with
|
||||
x, y, z: the concrete arguments of the primitive. Will only be called with
|
||||
concrete values.
|
||||
Returns:
|
||||
the concrete result of the primitive.
|
||||
@ -293,17 +293,17 @@ with expectNotImplementedError():
|
||||
+++ {"id": "rHS1bAGHH44E"}
|
||||
|
||||
#### Abstract evaluation rules
|
||||
In order to JIT the function, and for other transformations as well,
|
||||
JAX first evaluates it abstractly using only the
|
||||
In order to JIT the function, and for other transformations as well,
|
||||
JAX first evaluates it abstractly using only the
|
||||
shape and type of the arguments. This abstract evaluation serves multiple
|
||||
purposes:
|
||||
|
||||
* Gets the sequence of JAX primitives that are used in the computation. This
|
||||
sequence will be compiled.
|
||||
* Computes the shape and type of all vectors and operations used in the computation.
|
||||
* Gets the sequence of JAX primitives that are used in the computation. This
|
||||
sequence will be compiled.
|
||||
* Computes the shape and type of all vectors and operations used in the computation.
|
||||
|
||||
|
||||
For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
|
||||
For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
|
||||
In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
|
||||
|
||||
```{code-cell} ipython3
|
||||
@ -316,7 +316,7 @@ def multiply_add_abstract_eval(xs, ys, zs):
|
||||
"""Abstract evaluation of the primitive.
|
||||
|
||||
This function does not need to be JAX traceable. It will be invoked with
|
||||
abstractions of the actual arguments.
|
||||
abstractions of the actual arguments.
|
||||
Args:
|
||||
xs, ys, zs: abstractions of the arguments.
|
||||
Result:
|
||||
@ -349,7 +349,7 @@ with expectNotImplementedError():
|
||||
|
||||
JAX compilation works by compiling each primitive into a graph of XLA operations.
|
||||
|
||||
This is the biggest hurdle to adding new functionality to JAX, because the
|
||||
This is the biggest hurdle to adding new functionality to JAX, because the
|
||||
set of XLA operations is limited, and JAX already has pre-defined primitives
|
||||
for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.
|
||||
|
||||
@ -378,7 +378,7 @@ mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
|
||||
+++ {"id": "K98LX-VaJkFu"}
|
||||
|
||||
Now we succeed to JIT. Notice below that JAX first evaluates the function
|
||||
abstractly, which triggers the `multiply_add_abstract_eval` function, and
|
||||
abstractly, which triggers the `multiply_add_abstract_eval` function, and
|
||||
then compiles the set of primitives it has encountered, including `multiply_add`.
|
||||
At this point JAX invokes `multiply_add_xla_translation`.
|
||||
|
||||
@ -393,7 +393,7 @@ assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
|
||||
|
||||
Below is another use of `jit` where we compile only
|
||||
with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads
|
||||
in the third argument to `multiply_add_abstract_eval` being
|
||||
in the third argument to `multiply_add_abstract_eval` being
|
||||
`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with
|
||||
both `ShapedArray` and `ConcreteArray`.
|
||||
|
||||
@ -401,7 +401,7 @@ both `ShapedArray` and `ConcreteArray`.
|
||||
:id: mPfTwIBoKOEK
|
||||
:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b
|
||||
|
||||
assert api.jit(lambda x, y: square_add_prim(x, y),
|
||||
assert api.jit(lambda x, y: square_add_prim(x, y),
|
||||
static_argnums=1)(2., 10.) == 14.
|
||||
```
|
||||
|
||||
@ -437,16 +437,16 @@ from jax.interpreters import ad
|
||||
def multiply_add_value_and_jvp(arg_values, arg_tangents):
|
||||
"""Evaluates the primal output and the tangents (Jacobian-vector product).
|
||||
|
||||
Given values of the arguments and perturbation of the arguments (tangents),
|
||||
Given values of the arguments and perturbation of the arguments (tangents),
|
||||
compute the output of the primitive and the perturbation of the output.
|
||||
|
||||
This method must be JAX-traceable. JAX may invoke it with abstract values
|
||||
This method must be JAX-traceable. JAX may invoke it with abstract values
|
||||
for the arguments and tangents.
|
||||
|
||||
Args:
|
||||
arg_values: a tuple of arguments
|
||||
arg_tangents: a tuple with the tangents of the arguments. The tuple has
|
||||
the same length as the arg_values. Some of the tangents may also be the
|
||||
arg_tangents: a tuple with the tangents of the arguments. The tuple has
|
||||
the same length as the arg_values. Some of the tangents may also be the
|
||||
special value ad.Zero to specify a zero tangent.
|
||||
Returns:
|
||||
a pair of the primal output and the tangent.
|
||||
@ -454,26 +454,26 @@ def multiply_add_value_and_jvp(arg_values, arg_tangents):
|
||||
x, y, z = arg_values
|
||||
xt, yt, zt = arg_tangents
|
||||
_trace("Primal evaluation:")
|
||||
# Now we have a JAX-traceable computation of the output.
|
||||
# Normally, we can use the ma primitive itself to compute the primal output.
|
||||
# Now we have a JAX-traceable computation of the output.
|
||||
# Normally, we can use the ma primitive itself to compute the primal output.
|
||||
primal_out = multiply_add_prim(x, y, z)
|
||||
|
||||
|
||||
_trace("Tangent evaluation:")
|
||||
# We must use a JAX-traceable way to compute the tangent. It turns out that
|
||||
# We must use a JAX-traceable way to compute the tangent. It turns out that
|
||||
# the output tangent can be computed as (xt * y + x * yt + zt),
|
||||
# which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.
|
||||
|
||||
# We do need to deal specially with Zero. Here we just turn it into a
|
||||
# proper tensor of 0s (of the same shape as 'x').
|
||||
# An alternative would be to check for Zero and perform algebraic
|
||||
|
||||
# We do need to deal specially with Zero. Here we just turn it into a
|
||||
# proper tensor of 0s (of the same shape as 'x').
|
||||
# An alternative would be to check for Zero and perform algebraic
|
||||
# simplification of the output tangent computation.
|
||||
def make_zero(tan):
|
||||
return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
|
||||
|
||||
return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
|
||||
|
||||
output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
|
||||
return (primal_out, output_tangent)
|
||||
|
||||
# Register the forward differentiation rule with JAX
|
||||
# Register the forward differentiation rule with JAX
|
||||
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
|
||||
```
|
||||
|
||||
@ -487,7 +487,7 @@ assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
|
||||
|
||||
+++ {"id": "69QsEcu-lP4u"}
|
||||
|
||||
TO EXPLAIN:
|
||||
TO EXPLAIN:
|
||||
|
||||
* Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
|
||||
* Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet
|
||||
@ -504,7 +504,7 @@ We can apply JIT to the forward differentiation function:
|
||||
:id: hg-hzVu-N-hv
|
||||
:outputId: 38d32067-e152-4046-ad80-7f95a31ba628
|
||||
|
||||
assert api.jit(lambda arg_values, arg_tangents:
|
||||
assert api.jit(lambda arg_values, arg_tangents:
|
||||
api.jvp(square_add_prim, arg_values, arg_tangents))(
|
||||
(2., 10.), (1., 1.)) == (14., 5.)
|
||||
```
|
||||
@ -512,7 +512,7 @@ assert api.jit(lambda arg_values, arg_tangents:
|
||||
+++ {"id": "jlZt1_v2mU88"}
|
||||
|
||||
Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn
|
||||
evaluates abstractly both the primal and the tangent evaluation (a total of
|
||||
evaluates abstractly both the primal and the tangent evaluation (a total of
|
||||
3 invocations of the `ma` primitive). Then we compile the 3 occurrences
|
||||
of the primitive.
|
||||
|
||||
@ -521,21 +521,21 @@ of the primitive.
|
||||
### Reverse differentiation
|
||||
|
||||
If we attempt now to use reverse differentiation we
|
||||
see that JAX starts by using the `multiply_add_value_and_jvp` to
|
||||
see that JAX starts by using the `multiply_add_value_and_jvp` to
|
||||
compute the forward differentiation for abstract values, but then runs
|
||||
into a `NotImplementedError`.
|
||||
into a `NotImplementedError`.
|
||||
|
||||
When computing the reverse differentiation JAX first does abstract evaluation
|
||||
of the forward differentiation code `multiply_add_value_and_jvp` to obtain a
|
||||
trace of primitives that compute the output tangent.
|
||||
of the forward differentiation code `multiply_add_value_and_jvp` to obtain a
|
||||
trace of primitives that compute the output tangent.
|
||||
Observe that JAX performs this abstract evaluation with concrete values
|
||||
for the differentiation point, and abstract values for the tangents.
|
||||
for the differentiation point, and abstract values for the tangents.
|
||||
Observe also that JAX uses the special abstract tangent value `Zero` for
|
||||
the tangent corresponding to the 3rd argument of `ma`. This reflects the
|
||||
the tangent corresponding to the 3rd argument of `ma`. This reflects the
|
||||
fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,
|
||||
which flows to the 3rd argument to `multiply_add_prim`.
|
||||
|
||||
Observe also that during the abstract evaluation of the tangent we pass the
|
||||
Observe also that during the abstract evaluation of the tangent we pass the
|
||||
value 0.0 as the tangent for the 3rd argument. This is due to the use
|
||||
of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.
|
||||
|
||||
@ -560,7 +560,7 @@ to use the forward differentiation code to compute reverse differentiation.
|
||||
|
||||
As explained above, when computing reverse differentiation JAX obtains
|
||||
a trace of primitives that compute the tangent using forward differentiation.
|
||||
Then, **JAX interprets this trace abstractly backwards** and for each
|
||||
Then, **JAX interprets this trace abstractly backwards** and for each
|
||||
primitive it applies a **transposition** rule.
|
||||
|
||||
To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:
|
||||
@ -571,7 +571,7 @@ To understand what is going on, consider for now a simpler example of the functi
|
||||
ft = c + yt
|
||||
```
|
||||
|
||||
By construction, the tangent calculation is always linear in the input tangents.
|
||||
By construction, the tangent calculation is always linear in the input tangents.
|
||||
The only non-linear operator that may arise in the tangent calculation is multiplication,
|
||||
but then one of the operands is constant.
|
||||
|
||||
@ -597,8 +597,8 @@ of the operation:
|
||||
xct += act * 4.
|
||||
```
|
||||
|
||||
One can verify that this computation produces `xct = 4.` and `yct = 3.`, which
|
||||
are the partial derivatives of the function `f`.
|
||||
One can verify that this computation produces `xct = 4.` and `yct = 3.`, which
|
||||
are the partial derivatives of the function `f`.
|
||||
|
||||
JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:
|
||||
```
|
||||
@ -606,10 +606,10 @@ p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
|
||||
```
|
||||
|
||||
Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other
|
||||
arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned
|
||||
arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned
|
||||
for the constant arguments.
|
||||
|
||||
In particular,
|
||||
In particular,
|
||||
```
|
||||
add_transpose(out_ct, _, _) = (out_ct, out_ct)
|
||||
mult_transpose(out_ct, x, _) = (None, x * out_ct)
|
||||
@ -623,16 +623,16 @@ In particular,
|
||||
def multiply_add_transpose(ct, x, y, z):
|
||||
"""Evaluates the transpose of a linear primitive.
|
||||
|
||||
This method is only used when computing the backward gradient following
|
||||
value_and_jvp, and is only needed for primitives that are used in the JVP
|
||||
calculation for some other primitive. We need transposition for multiply_add_prim,
|
||||
because we have used multiply_add_prim in the computation of the output_tangent in
|
||||
This method is only used when computing the backward gradient following
|
||||
value_and_jvp, and is only needed for primitives that are used in the JVP
|
||||
calculation for some other primitive. We need transposition for multiply_add_prim,
|
||||
because we have used multiply_add_prim in the computation of the output_tangent in
|
||||
multiply_add_value_and_jvp.
|
||||
|
||||
In our case, multiply_add is not a linear primitive. However, it is used linearly
|
||||
In our case, multiply_add is not a linear primitive. However, it is used linearly
|
||||
w.r.t. tangents in multiply_add_value_and_jvp:
|
||||
output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
|
||||
|
||||
|
||||
Always one of the first two multiplicative arguments is a constant.
|
||||
|
||||
Args:
|
||||
@ -674,12 +674,12 @@ assert api.grad(square_add_prim)(2., 10.) == 4.
|
||||
+++ {"id": "8M1xLCXW4fK7"}
|
||||
|
||||
Notice the two calls to `multiply_add_transpose`. They correspond to the two
|
||||
uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the
|
||||
uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the
|
||||
last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0.
|
||||
|
||||
+++ {"id": "EIJs6FYmPg6c"}
|
||||
|
||||
#### JIT of reverse differentiation
|
||||
#### JIT of reverse differentiation
|
||||
|
||||
Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only
|
||||
abstract values, while in the absence of JIT we used `ConcreteArray`.
|
||||
@ -721,20 +721,20 @@ from jax.interpreters import batching
|
||||
@trace("multiply_add_batch")
|
||||
def multiply_add_batch(vector_arg_values, batch_axes):
|
||||
"""Computes the batched version of the primitive.
|
||||
|
||||
|
||||
This must be a JAX-traceable function.
|
||||
|
||||
|
||||
Since the multiply_add primitive already operates pointwise on arbitrary
|
||||
dimension tensors, to batch it we can use the primitive itself. This works as
|
||||
long as both the inputs have the same dimensions and are batched along the
|
||||
same axes. The result is batched along the axis that the inputs are batched.
|
||||
|
||||
|
||||
Args:
|
||||
vector_arg_values: a tuple of two arguments, each being a tensor of matching
|
||||
shape.
|
||||
batch_axes: the axes that are being batched. See vmap documentation.
|
||||
Returns:
|
||||
a tuple of the result, and the result axis that was batched.
|
||||
a tuple of the result, and the result axis that was batched.
|
||||
"""
|
||||
assert batch_axes[0] == batch_axes[1]
|
||||
assert batch_axes[0] == batch_axes[2]
|
||||
|
@ -119,7 +119,7 @@
|
||||
" for w, b in params[:-1]:\n",
|
||||
" outputs = jnp.dot(w, activations) + b\n",
|
||||
" activations = relu(outputs)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" final_w, final_b = params[-1]\n",
|
||||
" logits = jnp.dot(final_w, activations) + final_b\n",
|
||||
" return logits - logsumexp(logits)"
|
||||
@ -238,7 +238,7 @@
|
||||
"def one_hot(x, k, dtype=jnp.float32):\n",
|
||||
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
|
||||
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"def accuracy(params, images, targets):\n",
|
||||
" target_class = jnp.argmax(targets, axis=1)\n",
|
||||
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
|
||||
|
@@ -96,7 +96,7 @@ def predict(params, image):
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
@@ -156,7 +156,7 @@ At this point, we have all the ingredients we need to define our neural network
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
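Both hunks here are whitespace-only. As a quick sanity check, the `one_hot` helper runs on its own, with `jnp` standing for `jax.numpy` as in the notebook:

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

print(one_hot(jnp.array([0, 2, 1]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```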
@ -35,7 +35,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import jit, grad, vmap\n",
|
||||
@ -214,7 +213,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Importing Jax functions useful for tracing/interpreting.\n",
|
||||
"import numpy as np\n",
|
||||
"from functools import wraps\n",
|
||||
"\n",
|
||||
"from jax import core\n",
|
||||
@ -273,7 +271,7 @@
|
||||
"def eval_jaxpr(jaxpr, consts, *args):\n",
|
||||
" # Mapping from variable -> value\n",
|
||||
" env = {}\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def read(var):\n",
|
||||
" # Literals are values baked into the Jaxpr\n",
|
||||
" if type(var) is core.Literal:\n",
|
||||
@@ -290,16 +288,16 @@
" # Loop through equations and evaluate primitives using `bind`\n",
" for eqn in jaxpr.eqns:\n",
" # Read inputs to equation from environment\n",
" invals = safe_map(read, eqn.invars) \n",
" invals = safe_map(read, eqn.invars)\n",
" # `bind` is how a primitive is called\n",
" outvals = eqn.primitive.bind(*invals, **eqn.params)\n",
" # Primitives may return multiple outputs or not\n",
" if not eqn.primitive.multiple_results: \n",
" if not eqn.primitive.multiple_results:\n",
" outvals = [outvals]\n",
" # Write the results of the primitive into the environment\n",
" safe_map(write, eqn.outvars, outvals) \n",
" safe_map(write, eqn.outvars, outvals)\n",
" # Read the final result of the Jaxpr from the environment\n",
" return safe_map(read, jaxpr.outvars) "
" return safe_map(read, jaxpr.outvars)"
]
},
{
@ -417,7 +415,7 @@
|
||||
"source": [
|
||||
"def inverse_jaxpr(jaxpr, consts, *args):\n",
|
||||
" env = {}\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def read(var):\n",
|
||||
" if type(var) is core.Literal:\n",
|
||||
" return var.val\n",
|
||||
@ -431,12 +429,12 @@
|
||||
"\n",
|
||||
" # Looping backward\n",
|
||||
" for eqn in jaxpr.eqns[::-1]:\n",
|
||||
" # outvars are now invars \n",
|
||||
" # outvars are now invars\n",
|
||||
" invals = safe_map(read, eqn.outvars)\n",
|
||||
" if eqn.primitive not in inverse_registry:\n",
|
||||
" raise NotImplementedError(\n",
|
||||
" f\"{eqn.primitive} does not have registered inverse.\")\n",
|
||||
" # Assuming a unary function \n",
|
||||
" # Assuming a unary function\n",
|
||||
" outval = inverse_registry[eqn.primitive](*invals)\n",
|
||||
" safe_map(write, eqn.invars, [outval])\n",
|
||||
" return safe_map(read, jaxpr.invars)"
|
||||
|
@ -32,7 +32,6 @@ Here we show how to add your own function transformations to the system, by writ
```{code-cell} ipython3
:id: s27RDKvKXFL8

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap

@ -146,7 +145,6 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr.
:id: BHkg_3P1pXJj

# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps

from jax import core

@ -185,7 +183,7 @@ To do this, we first create an environment to store the values for each of the v
def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}

  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:

@ -202,16 +200,16 @@ def eval_jaxpr(jaxpr, consts, *args):
  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)
    invals = safe_map(read, eqn.invars)
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results:
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals)
    safe_map(write, eqn.outvars, outvals)
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars)
  return safe_map(read, jaxpr.outvars)
```

```{code-cell} ipython3
@ -279,7 +277,7 @@ Now we just need to define `inverse_jaxpr`, which will walk through the Jaxpr ba

def inverse_jaxpr(jaxpr, consts, *args):
  env = {}

  def read(var):
    if type(var) is core.Literal:
      return var.val

@ -293,12 +291,12 @@ def inverse_jaxpr(jaxpr, consts, *args):

  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    # outvars are now invars
    # outvars are now invars
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have registered inverse.")
    # Assuming a unary function
    # Assuming a unary function
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars)
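For context, a short sketch of how the interpreter edited above is driven; it assumes the `eval_jaxpr` defined in this tutorial is in scope and mirrors the tutorial's own usage rather than adding anything to the commit.

```python
import jax.numpy as jnp
from jax import make_jaxpr

def f(x):
  return jnp.exp(jnp.tanh(x))

closed_jaxpr = make_jaxpr(f)(jnp.ones(5))
# eval_jaxpr (assumed defined above) walks the equations and reproduces f's output.
out, = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
```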
@ -1148,7 +1148,7 @@
" y, vjp_fun = vjp(f, x)\n",
" # Use vmap to do a matrix-Jacobian product.\n",
" # Here, the matrix is the Euclidean basis, so we get all\n",
" # entries in the Jacobian at once. \n",
" # entries in the Jacobian at once.\n",
" J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n",
" return J\n",
" return jacfun\n",

@ -1169,7 +1169,7 @@
"def our_jacfwd(f):\n",
" def jacfun(x):\n",
" _jvp = lambda s: jvp(f, (x,), (s,))[1]\n",
" Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
" Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
" return jnp.transpose(Jt)\n",
" return jacfun\n",
"\n",
@ -675,7 +675,7 @@ def our_jacrev(f):
    y, vjp_fun = vjp(f, x)
    # Use vmap to do a matrix-Jacobian product.
    # Here, the matrix is the Euclidean basis, so we get all
    # entries in the Jacobian at once.
    # entries in the Jacobian at once.
    J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
    return J
  return jacfun

@ -691,7 +691,7 @@ from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
  def jacfun(x):
    _jvp = lambda s: jvp(f, (x,), (s,))[1]
    Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
    Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
    return jnp.transpose(Jt)
  return jacfun
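A quick consistency check for the hand-rolled `our_jacfwd` above, assuming it is in scope; the test function is illustrative and compares against the built-in `jax.jacfwd`.

```python
import jax.numpy as jnp
from jax import jacfwd

f = lambda x: jnp.sin(x) * x
x = jnp.arange(1.0, 4.0)
# our_jacfwd (assumed defined above) should agree with the built-in Jacobian.
assert jnp.allclose(our_jacfwd(f)(x), jacfwd(f)(x))
```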
@ -739,8 +739,6 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.ad_checkpoint import checkpoint_name\n",
"\n",
"def predict(params, x):\n",
" *Ws, Wlast = params\n",
" for i, W in enumerate(Ws):\n",
@ -370,8 +370,6 @@ Notice also that by providing a policy, we didn't need to edit the code defining
Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`:

```{code-cell}
from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
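For context, a self-contained sketch of the `checkpoint_name` plus policy mechanism referenced in the hunk above; the layer, the `'hidden'` name, and the shapes are illustrative only, not the notebook's model.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(w, x):
  h = checkpoint_name(jnp.dot(w, x), 'hidden')  # give this intermediate a name
  return jnp.tanh(h)

# Save only the value named 'hidden' on the forward pass; rematerialize the rest.
policy = jax.checkpoint_policies.save_only_these_names('hidden')
layer_remat = jax.checkpoint(layer, policy=policy)

w = jnp.ones((4, 3))
x = jnp.ones(3)
grads = jax.grad(lambda w: layer_remat(w, x).sum())(w)
```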
@ -410,7 +410,7 @@
],
"source": [
"dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n",
" kernel.shape, # only ndim matters, not shape \n",
" kernel.shape, # only ndim matters, not shape\n",
" ('NHWC', 'HWIO', 'NHWC')) # the important bit\n",
"print(dn)"
]

@ -806,8 +806,8 @@
],
"source": [
"# 1D kernel - WIO layout\n",
"kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], \n",
" [[1, 1, 1], [-1, -1, -1]]], \n",
"kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],\n",
" [[1, 1, 1], [-1, -1, -1]]],\n",
" dtype=jnp.float32).transpose([2,1,0])\n",
"# 1D data - NWC layout\n",
"data = np.zeros((1, 200, 2), dtype=jnp.float32)\n",

@ -895,8 +895,8 @@
"# Random 3D kernel - HWDIO layout\n",
"kernel = jnp.array([\n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],\n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],\n",
" dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]\n",
"\n",
"# 3D data - NHWDC layout\n",

@ -919,7 +919,6 @@
"print(\"out shape: \", out.shape)\n",
"\n",
"# Make some simple 3d density plots:\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"def make_alpha(cmap):\n",
" my_cmap = cmap(jnp.arange(cmap.N))\n",
" my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n",
@ -210,7 +210,7 @@ The important argument is the 3-tuple of axis layout arguments:
:outputId: d5a569b3-febc-4832-f725-1d5e8fd31b9b

dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape
                                kernel.shape, # only ndim matters, not shape
                                kernel.shape, # only ndim matters, not shape
                                ('NHWC', 'HWIO', 'NHWC')) # the important bit
print(dn)
```

@ -363,8 +363,8 @@ You aren't limited to 2D convolutions, a simple 1D demo is below:
:outputId: 67c46ace-6adc-4c47-c1c7-1f185be5fd4b

# 1D kernel - WIO layout
kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
                    [[1, 1, 1], [-1, -1, -1]]],
kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
                    [[1, 1, 1], [-1, -1, -1]]],
                   dtype=jnp.float32).transpose([2,1,0])
# 1D data - NWC layout
data = np.zeros((1, 200, 2), dtype=jnp.float32)

@ -406,8 +406,8 @@ import matplotlib as mpl
# Random 3D kernel - HWDIO layout
kernel = jnp.array([
  [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
  [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
  [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
  [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
  [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
  dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]

# 3D data - NHWDC layout

@ -430,7 +430,6 @@ out = lax.conv_general_dilated(data, # lhs = image tensor
print("out shape: ", out.shape)

# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
  my_cmap = cmap(jnp.arange(cmap.N))
  my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3
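As a reading aid for the dimension-number code touched above, a minimal runnable 1-D convolution using made-up shapes rather than the notebook's data:

```python
import jax.numpy as jnp
from jax import lax

data = jnp.zeros((1, 200, 2), dtype=jnp.float32)   # NWC layout (illustrative shape)
kernel = jnp.ones((3, 2, 2), dtype=jnp.float32)    # WIO layout (illustrative shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,  # only ndim matters
                                ('NWC', 'WIO', 'NWC'))
out = lax.conv_general_dilated(data, kernel,
                               window_strides=(1,),
                               padding='SAME',
                               dimension_numbers=dn)
assert out.shape == (1, 200, 2)
```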
@ -840,7 +840,6 @@
},
"outputs": [],
"source": [
"from functools import partial\n",
"j1 = partial(jv, 1)\n",
"z = jnp.arange(5.0)"
]
@ -410,7 +410,6 @@ This lets us call into `scipy.special.jv` from transformed JAX code, including w
```{code-cell}
:id: f4e46670f4e4

from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
```
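For context only, one hedged way to expose `scipy.special.jv` to JAX transformations via `jax.pure_callback`, in the spirit of the tutorial cell above; the tutorial's own wrapper (with custom derivative rules) is more complete, and this sketch is not part of the commit.

```python
import numpy as np
import scipy.special
import jax
import jax.numpy as jnp

def jv(v, z):
  # Host callback: runs the NumPy/SciPy Bessel function outside the trace.
  v, z = jnp.asarray(v), jnp.asarray(z)
  result_shape = jax.ShapeDtypeStruct(np.broadcast_shapes(v.shape, z.shape), z.dtype)
  return jax.pure_callback(lambda v, z: scipy.special.jv(v, z).astype(z.dtype),
                           result_shape, v, z)

j1 = lambda z: jv(1, z)
print(jax.jit(j1)(jnp.arange(5.0)))
```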
@ -132,7 +132,7 @@
" for w, b in params[:-1]:\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = relu(outputs)\n",
" \n",
"\n",
" final_w, final_b = params[-1]\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits - logsumexp(logits)"

@ -251,7 +251,7 @@
"def one_hot(x, k, dtype=jnp.float32):\n",
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
" \n",
"\n",
"def accuracy(params, images, targets):\n",
" target_class = jnp.argmax(targets, axis=1)\n",
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
@ -104,7 +104,7 @@ def predict(params, image):
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

@ -164,7 +164,7 @@ At this point, we have all the ingredients we need to define our neural network
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
@ -25,17 +25,10 @@
},
"outputs": [],
"source": [
"import functools\n",
"import itertools\n",
"import re\n",
"import sys\n",
"import time\n",
"\n",
"from matplotlib.pyplot import *\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import jax\n",
"\n",
"from jax import lax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",

@ -348,7 +341,7 @@
"def elbo(beta_loc, beta_log_scale, epsilon):\n",
" beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n",
" return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n",
" \n",
"\n",
"elbo = jax.jit(elbo)\n",
"elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))"
]

@ -548,25 +541,16 @@
}
],
"source": [
"figure(figsize=(7, 7))\n",
"plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
"plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
"plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plt.figure(figsize=(7, 7))\n",
"plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
"plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\\sigma$ Error Bars')\n",
"plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plot_scale = 3\n",
"plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
"xlabel('True beta')\n",
"ylabel('Estimated beta')\n",
"legend(loc='best')"
"plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
"plt.xlabel('True beta')\n",
"plt.ylabel('Estimated beta')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "_bXdOlvUEJl0"
},
"outputs": [],
"source": []
}
],
"metadata": {
@ -27,17 +27,10 @@ Inspired by a notebook by @davmre.
```{code-cell} ipython3
:id: 8RZDkfbV3zdR

import functools
import itertools
import re
import sys
import time

from matplotlib.pyplot import *
import matplotlib.pyplot as plt

import jax

from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

@ -192,7 +185,7 @@ batched_log_joint = jax.jit(jax.vmap(log_joint))
def elbo(beta_loc, beta_log_scale, epsilon):
  beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
  return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
```

@ -240,19 +233,13 @@ Coverage isn't quite as good as we might like, but it's not bad, and nobody said
:id: zt1NBLoVHtOG
:outputId: fb159795-e6e7-497c-e501-9933ec761af4

figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plt.figure(figsize=(7, 7))
plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
```

```{code-cell} ipython3
:id: _bXdOlvUEJl0

plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
plt.xlabel('True beta')
plt.ylabel('Estimated beta')
plt.legend(loc='best')
```
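For context on the `elbo` cell changed above, a toy reparameterized objective in the same spirit; the names and the density are illustrative, not the notebook's model.

```python
import jax
import jax.numpy as jnp

def objective(loc, log_scale, eps):
  sample = loc + jnp.exp(log_scale) * eps        # reparameterization trick
  return jnp.mean(-0.5 * sample**2) + jnp.sum(log_scale)

val_and_grad = jax.jit(jax.value_and_grad(objective, argnums=(0, 1)))
val, (g_loc, g_log_scale) = val_and_grad(jnp.zeros(3), jnp.zeros(3), jnp.ones(3))
```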
@ -189,9 +189,6 @@
],
"source": [
"# Pardon the boilerplate; constructing a sharding will become easier in future!\n",
"from jax.sharding import Mesh\n",
"from jax.sharding import PartitionSpec\n",
"from jax.sharding import NamedSharding\n",
"from jax.experimental import mesh_utils\n",
"\n",
"P = jax.sharding.PartitionSpec\n",
@ -73,9 +73,6 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5

# Pardon the boilerplate; constructing a sharding will become easier in future!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = jax.sharding.PartitionSpec
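A condensed version of the sharding boilerplate trimmed above, kept as a sketch; like the tutorial, it assumes 8 visible devices, so the mesh shape here is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((4, 2))   # assumes 8 devices are available
mesh = Mesh(devices, axis_names=('a', 'b'))
sharding = NamedSharding(mesh, P('a', 'b'))
x = jax.device_put(jnp.arange(32.).reshape(8, 4), sharding)
jax.debug.visualize_array_sharding(x)
```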
@ -1268,7 +1268,7 @@ def new_base_main(trace_type: type[Trace],
@contextmanager
def pop_level(level: int):
  if level == 0:
    return (yield)
    return (yield)  # noqa: B901
  prev, thread_local_state.trace_state.trace_stack.stack = \
      thread_local_state.trace_state.trace_stack.stack, \
      thread_local_state.trace_state.trace_stack.stack[:level]
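For readers unfamiliar with the `return (yield)` idiom that the `# noqa: B901` above silences, here it is in isolation; `maybe_trace` is a made-up name, not JAX code.

```python
from contextlib import contextmanager

@contextmanager
def maybe_trace(enabled: bool):
  if not enabled:
    # Yield once and return immediately: a no-op context manager on this branch.
    # flake8-bugbear's B901 flags the `return <value>` inside a generator, hence the noqa.
    return (yield)
  print("start tracing")
  try:
    yield
  finally:
    print("stop tracing")

with maybe_trace(False):
  pass  # nothing is printed
```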
@ -760,7 +760,7 @@ def _dot_product_attention_fwd_partition(
    scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training,
    mesh, arg_shapes, result_shape):
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes)
  out_shardings = _infer_fwd_output_sharding(
      mesh, arg_shapes, variadic_args, is_training)
  impl = functools.partial(

@ -810,7 +810,7 @@ def _dot_product_attention_bwd_partition(
    arg_shapes, result_shape):
  out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes)
  def sharded_impl(*args):
    impl = functools.partial(
        _dot_product_attention_bwd_impl,
@ -238,7 +238,7 @@ class BufferedRef:
    Returns:
      Initialized BufferedRef
    """
    block_shape = tuple([1 if x is None else x for x in spec.block_shape])
    block_shape = tuple(1 if x is None else x for x in spec.block_shape)
    if buffer_type is BufferType.ACCUMULATOR:
      accum_ref = VMEM(block_shape, dtype)
    else:

@ -310,7 +310,7 @@ class BufferedRef:
  @property
  def current_ref(self):
    buffer_slice = tuple(
        [0 if x is None else slice(None) for x in self.block_shape])
        0 if x is None else slice(None) for x in self.block_shape)
    if self.memory_space == VMEM:
      return self.vmem_ref.at[buffer_slice]
    else:

@ -349,7 +349,7 @@ class BufferedRef:

  def compute_slice(self, grid_indices):
    """Compute DMA slice from grid indices."""
    block_shape = tuple([1 if x is None else x for x in self.block_shape])
    block_shape = tuple(1 if x is None else x for x in self.block_shape)
    indices = self.compute_index(*grid_indices)
    return jax.tree.map(_make_ds, indices, block_shape)
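The comprehension-to-generator cleanup applied in the hunks above, shown in isolation; the block spec here is a made-up example, not Pallas code.

```python
block_shape_spec = (None, 128, 128)   # None marks a squeezed dimension (illustrative)
# A generator expression inside tuple() avoids building an intermediate list.
block_shape = tuple(1 if x is None else x for x in block_shape_spec)
assert block_shape == (1, 128, 128)
```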
@ -172,7 +172,7 @@ def _normalize_tolerance(tol):
  if isinstance(tol, dict):
    return {np.dtype(k): v for k, v in tol.items()}
  else:
    return {k: tol for k in _default_tolerance}
    return dict.fromkeys(_default_tolerance, tol)

def join_tolerance(tol1, tol2):
  tol1 = _normalize_tolerance(tol1)
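The `dict.fromkeys` rewrite above, shown with made-up keys; both forms build the same mapping when every key shares one value.

```python
_default_tolerance = {'float32': 1e-6, 'float64': 1e-15}   # illustrative keys only
tol = 1e-3
assert dict.fromkeys(_default_tolerance, tol) == {k: tol for k in _default_tolerance}
```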
@ -115,6 +115,8 @@ ignore = [
    "C408",
    # Unnecessary map usage
    "C417",
    # Unnecessary dict comprehension for iterable
    "C420",
    # Object names too complex
    "C901",
    # Local variable is assigned to but never used

@ -141,7 +143,16 @@ max-complexity = 18

[tool.ruff.lint.per-file-ignores]
# F811: Redefinition of unused name.
# F821: Undefined name.
"docs/autodidax.py" = ["F811"]
"docs/pallas/tpu/matmul.ipynb" = ["F811"]
"docs/pallas/tpu/distributed.ipynb" = ["F811"]
"docs/pallas/quickstart.ipynb" = ["F811"]
"docs/notebooks/autodiff_cookbook.ipynb" = ["F811", "F821"]
"docs/notebooks/autodiff_remat.ipynb" = ["F811", "F821"]
"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"]
"docs/jep/9407-type-promotion.ipynb" = ["F811"]
"docs/autodidax.ipynb" = ["F811"]
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
@ -308,7 +308,7 @@ class MemRefTest(TestCase):
    expanded_shape = get_packed_shape(strides, shape)
    total_size = np.prod(expanded_shape)
    np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape)
    index = tuple([slice(0, s) for s in shape])
    index = tuple(slice(0, s) for s in shape)

    # Reference implementation
    def np_fold(inp, dim, fold_rank):
@ -88,15 +88,6 @@
"height": 68
}
},
"source": [
"from jaxlib import xla_extension\n",
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
],
"execution_count": 2,
"outputs": [
{
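As a side note on the notebook cell above: in recent JAX versions the `arr.device()` call it uses has been superseded; one hedged, present-day way to ask where an array lives is sketched below (the `devices()` accessor returns a set on `jax.Array`).

```python
import jax

key = jax.random.PRNGKey(1701)
arr = jax.random.normal(key, (1000,))
device = next(iter(arr.devices()))   # pick the (single) device holding this array
print(f"JAX device type: {device}")
```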
@ -500,11 +500,11 @@ class IndexerOpsTest(PallasBaseTest):
    def body(x_ref, y_ref1, y_ref2):
      if slice_type == "slice":
        slices = tuple(
            [slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)]
            slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)
        )
      else:
        slices = tuple(
            [pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)]
            pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)
        )
      if indexer_type == "state":
        y_ref1[...] = x_ref[slices]