CI: update ruff to v0.6.1

Jake VanderPlas 2024-08-27 14:54:11 -07:00
parent afff0e09aa
commit 68be5b5085
41 changed files with 241 additions and 314 deletions

View File

@ -26,7 +26,7 @@ repos:
files: \.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.6.1
hooks:
- id: ruff

View File

@ -510,8 +510,8 @@
"outputs": [],
"source": [
"image_partitions = P(1, 1, 4, 2)\n",
"sharded_conv = sharded_jit(conv, \n",
" in_parts=(image_partitions, None), \n",
"sharded_conv = sharded_jit(conv,\n",
" in_parts=(image_partitions, None),\n",
" out_parts=image_partitions)\n",
"\n",
"sharded_conv(image, kernel)"

View File

@ -877,7 +877,7 @@
" def g(z):\n",
" return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
" \n",
"\n",
"f(x)"
]
},
@ -950,17 +950,6 @@
"per_example_hess = pmap(input_hess) # pmap!\n",
"per_example_hess(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "u3ggM_WYZ8QC"
},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@ -67,7 +67,6 @@
"source": [
"from functools import partial\n",
"import jax\n",
"from jax import jit, pmap\n",
"from jax import lax\n",
"from jax import tree_util\n",
"import jax.numpy as jnp\n",

View File

@ -640,7 +640,7 @@ def our_jacrev(f):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
@ -654,7 +654,7 @@ from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
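For context, a standalone sketch of how these hand-rolled Jacobian helpers are used; the test function `f` and the evaluation point below are illustrative, and the results are checked against the built-in `jax.jacrev` / `jax.jacfwd`.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev, jvp, vjp, vmap

def f(x):
    # Small R^3 -> R^3 test function (arbitrary choice).
    return jnp.array([x[0] * x[1], jnp.sin(x[2]), x[0] + x[2] ** 2])

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Pull back the Euclidean basis: one VJP per output row, done in one vmap.
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        # Push forward the Euclidean basis: one JVP per input column, done in one vmap.
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

x = jnp.array([1.0, 2.0, 3.0])
assert jnp.allclose(our_jacrev(f)(x), jacrev(f)(x))
assert jnp.allclose(our_jacfwd(f)(x), jacfwd(f)(x))
```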

View File

@ -3317,7 +3317,6 @@
],
"source": [
"# @title\n",
"from jax import dtypes\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",
@ -3802,7 +3801,6 @@
],
"source": [
"# @title\n",
"from jax import dtypes\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",

View File

@ -908,7 +908,6 @@ display.HTML(table.to_html())
:tags: [hide-input]
# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
@ -963,7 +962,6 @@ display.HTML(table.to_html())
:tags: [hide-input]
# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd

View File

@ -226,7 +226,6 @@
],
"source": [
"import jax.numpy as jnp\n",
"import jax.lax as lax\n",
"from jax import make_jaxpr\n",
"\n",
"# lax.fori_loop\n",
@ -1031,7 +1030,6 @@
}
],
"source": [
"from jax import random\n",
"key = random.key(0)\n",
"key"
]
@ -1105,8 +1103,8 @@
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(\" \\---SPLIT --> new key \", key)\n",
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
@ -1140,8 +1138,8 @@
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(\" \\---SPLIT --> new key \", key)\n",
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
@ -1701,7 +1699,7 @@
],
"source": [
"init_val = 0\n",
"cond_fun = lambda x: x<10\n",
"cond_fun = lambda x: x < 10\n",
"body_fun = lambda x: x+1\n",
"lax.while_loop(cond_fun, body_fun, init_val)\n",
"# --> array(10, dtype=int32)"

View File

@ -130,7 +130,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or
:outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr
# lax.fori_loop
@ -471,7 +470,6 @@ The random state is described by a special array element that we call a __key__:
:id: yPHE7KTWgAWs
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
from jax import random
key = random.key(0)
key
```
@ -504,8 +502,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
+++ {"id": "tqtFVE4MthO3"}
@ -519,8 +517,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
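A standalone sketch of the splitting discipline described above, assuming the new-style `jax.random.key` API used in this notebook: every draw consumes a fresh subkey, and the carried key is never passed to a sampler directly.

```python
from jax import random

key = random.key(0)

draws = []
for _ in range(3):
    # Split off a fresh subkey for each draw; never reuse `key` itself.
    key, subkey = random.split(key)
    draws.append(random.normal(subkey, shape=(1,)))

# Equivalently, split once into several subkeys up front.
subkeys = random.split(key, num=3)
more_draws = [random.normal(subkeys[i], shape=(1,)) for i in range(3)]
print(draws, more_draws)
```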
+++ {"id": "0KLYUluz3lN3"}
@ -805,7 +803,7 @@ def while_loop(cond_fun, body_fun, init_val):
:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
init_val = 0
cond_fun = lambda x: x<10
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)

View File

@ -247,7 +247,6 @@
}
],
"source": [
"import jax.numpy as jnp\n",
"\n",
"def log1pexp(x):\n",
" return jnp.log(1. + jnp.exp(x))\n",
@ -984,7 +983,7 @@
" (a, x_star, x_star_bar),\n",
" x_star_bar))\n",
" return a_bar, jnp.zeros_like(x_star)\n",
" \n",
"\n",
"def rev_iter(f, packed, u):\n",
" a, x_star, x_star_bar = packed\n",
" _, vjp_x = vjp(lambda x: f(a, x), x_star)\n",
@ -1884,7 +1883,6 @@
}
],
"source": [
"from jax import vjp\n",
"\n",
"y, f_vjp = vjp(f, 3.)\n",
"print(y)"
@ -1983,7 +1981,7 @@
" return x, x\n",
"\n",
"def debug_bwd(x, g):\n",
" import pdb; pdb.set_trace()\n",
" pdb.set_trace()\n",
" return g\n",
"\n",
"debug.defvjp(debug_fwd, debug_bwd)"

View File

@ -145,7 +145,6 @@ Say we want to write a function called `log1pexp`, which computes $x \mapsto \lo
:id: 6lWbTvs40ET-
:outputId: 8caff99e-add1-4c70-ace3-212c0c5c6f4e
import jax.numpy as jnp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
@ -524,7 +523,7 @@ def fixed_point_rev(f, res, x_star_bar):
(a, x_star, x_star_bar),
x_star_bar))
return a_bar, jnp.zeros_like(x_star)
def rev_iter(f, packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: f(a, x), x_star)
@ -965,7 +964,6 @@ print(grad(f)(3.))
:id: s1Pn_qCIODcF
:outputId: 423d34e0-35b8-4b57-e89d-f70f20e28ea9
from jax import vjp
y, f_vjp = vjp(f, 3.)
print(y)
@ -1015,7 +1013,7 @@ def debug_fwd(x):
return x, x
def debug_bwd(x, g):
import pdb; pdb.set_trace()
pdb.set_trace()
return g
debug.defvjp(debug_fwd, debug_bwd)

View File

@ -30,9 +30,7 @@
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import functools\n",
"from typing import Optional\n",
"\n",
"import numpy as np\n",

View File

@ -26,9 +26,7 @@ This tutorial discusses parallelism via `jax.Array`, the unified array object mo
```{code-cell}
:id: FNxScTfq3vGF
import os
import functools
from typing import Optional
import numpy as np

View File

@ -15,12 +15,12 @@
"*necula@google.com*, October 2019.\n",
"\n",
"JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n",
"`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, \n",
"`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,\n",
"which means that as the Python function executes\n",
"the only operations it applies to the data are either inspections of data\n",
"attributes such as shape or type, or special operations called JAX primitives.\n",
"In particular, a JAX-traceable function is sometimes invoked by JAX with\n",
"abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, \n",
"abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,\n",
"which captures the type and the shape of values, but not the concrete data values.\n",
"JAX primitives know how to operate on both concrete data\n",
"values and on the JAX abstract values.\n",
@ -30,7 +30,7 @@
"to ensure that these transformations\n",
"can be composed, e.g., `jit(jacfwd(grad(f)))`.\n",
"\n",
"There are pre-defined JAX primitives corresponding to most XLA operations, \n",
"There are pre-defined JAX primitives corresponding to most XLA operations,\n",
"e.g., add, matmul, sin, cos, indexing.\n",
"JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n",
"using JAXs implementation of numpy are JAX-traceable and therefore transformable.\n",
@ -42,8 +42,8 @@
"**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n",
"\n",
"Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n",
"as \"multiply_add(x, y, z) = x * y + z\". \n",
"This function operates on 3 identically-shaped tensors of floating point \n",
"as \"multiply_add(x, y, z) = x * y + z\".\n",
"This function operates on 3 identically-shaped tensors of floating point\n",
"values and performs the operations pointwise."
]
},
@ -56,7 +56,7 @@
"## Using existing primitives\n",
"\n",
"The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n",
"functions that are themselves written using JAX primitives, e.g., those \n",
"functions that are themselves written using JAX primitives, e.g., those\n",
"defined in the `jax.lax` module:"
]
},
@ -165,7 +165,7 @@
" return str(v)\n",
" def pp_values(args):\n",
" return \", \".join([pp(arg) for arg in args])\n",
" \n",
"\n",
" @functools.wraps(func)\n",
" def func_wrapper(*args):\n",
" _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n",
@ -199,7 +199,7 @@
"id": "Qf4eLrLCFYDl"
},
"source": [
"Instead of using `jax.lax` primitives directly, we can use other functions \n",
"Instead of using `jax.lax` primitives directly, we can use other functions\n",
"that are already written in terms of those primitives, such as those in `jax.numpy`:"
]
},
@ -244,7 +244,7 @@
"def square_add_numpy(a, b):\n",
" return multiply_add_numpy(a, a, b)\n",
"\n",
"print(\"\\nNormal evaluation:\") \n",
"print(\"\\nNormal evaluation:\")\n",
"print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n",
"print(\"\\nGradient evaluation:\")\n",
"print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))"
@ -257,13 +257,13 @@
},
"source": [
"Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n",
"`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further \n",
"below in this colab). \n",
"It is important to remember that a JAX-traceable function must be able to \n",
"`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further\n",
"below in this colab).\n",
"It is important to remember that a JAX-traceable function must be able to\n",
"operate not only on concrete arguments but also on special abstract arguments\n",
"that JAX may use to abstract the function execution.\n",
"\n",
"The JAX traceability property is satisfied as long as the function is written \n",
"The JAX traceability property is satisfied as long as the function is written\n",
"in terms of JAX primitives."
]
},
@ -277,7 +277,7 @@
"\n",
"The right way to add support for multiply-add is in terms of existing\n",
"JAX primitives, as shown above. However, in order to demonstrate how JAX\n",
"primitives work let us pretend that we want to add a new primitive to \n",
"primitives work let us pretend that we want to add a new primitive to\n",
"JAX for the multiply-add functionality."
]
},
@ -295,9 +295,9 @@
"@trace(\"multiply_add_prim\")\n",
"def multiply_add_prim(x, y, z):\n",
" \"\"\"The JAX-traceable way to use the JAX primitive.\n",
" \n",
"\n",
" Note that the traced arguments must be passed as positional arguments\n",
" to `bind`. \n",
" to `bind`.\n",
" \"\"\"\n",
" return multiply_add_p.bind(x, y, z)\n",
"\n",
@ -392,7 +392,7 @@
"\n",
" This function does not need to be JAX traceable.\n",
" Args:\n",
" x, y, z: the concrete arguments of the primitive. Will only be called with \n",
" x, y, z: the concrete arguments of the primitive. Will only be called with\n",
" concrete values.\n",
" Returns:\n",
" the concrete result of the primitive.\n",
@ -485,17 +485,17 @@
},
"source": [
"#### Abstract evaluation rules\n",
"In order to JIT the function, and for other transformations as well, \n",
"JAX first evaluates it abstractly using only the \n",
"In order to JIT the function, and for other transformations as well,\n",
"JAX first evaluates it abstractly using only the\n",
"shape and type of the arguments. This abstract evaluation serves multiple\n",
"purposes:\n",
"\n",
" * Gets the sequence of JAX primitives that are used in the computation. This \n",
" sequence will be compiled. \n",
" * Computes the shape and type of all vectors and operations used in the computation. \n",
" * Gets the sequence of JAX primitives that are used in the computation. This\n",
" sequence will be compiled.\n",
" * Computes the shape and type of all vectors and operations used in the computation.\n",
"\n",
"\n",
"For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. \n",
"For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.\n",
"In the latter case, JAX uses the actual concrete value wrapped as an abstract value."
]
},
@ -527,7 +527,7 @@
" \"\"\"Abstract evaluation of the primitive.\n",
"\n",
" This function does not need to be JAX traceable. It will be invoked with\n",
" abstractions of the actual arguments. \n",
" abstractions of the actual arguments.\n",
" Args:\n",
" xs, ys, zs: abstractions of the arguments.\n",
" Result:\n",
@ -603,7 +603,7 @@
"\n",
"JAX compilation works by compiling each primitive into a graph of XLA operations.\n",
"\n",
"This is the biggest hurdle to adding new functionality to JAX, because the \n",
"This is the biggest hurdle to adding new functionality to JAX, because the\n",
"set of XLA operations is limited, and JAX already has pre-defined primitives\n",
"for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++."
]
@ -642,7 +642,7 @@
},
"source": [
"Now we succeed to JIT. Notice below that JAX first evaluates the function\n",
"abstractly, which triggers the `multiply_add_abstract_eval` function, and \n",
"abstractly, which triggers the `multiply_add_abstract_eval` function, and\n",
"then compiles the set of primitives it has encountered, including `multiply_add`.\n",
"At this point JAX invokes `multiply_add_xla_translation`."
]
@ -682,7 +682,7 @@
"source": [
"Below is another use of `jit` where we compile only\n",
"with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n",
"in the third argument to `multiply_add_abstract_eval` being \n",
"in the third argument to `multiply_add_abstract_eval` being\n",
"`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n",
"both `ShapedArray` and `ConcreteArray`."
]
@ -711,7 +711,7 @@
}
],
"source": [
"assert api.jit(lambda x, y: square_add_prim(x, y), \n",
"assert api.jit(lambda x, y: square_add_prim(x, y),\n",
" static_argnums=1)(2., 10.) == 14."
]
},
@ -794,16 +794,16 @@
"def multiply_add_value_and_jvp(arg_values, arg_tangents):\n",
" \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n",
"\n",
" Given values of the arguments and perturbation of the arguments (tangents), \n",
" Given values of the arguments and perturbation of the arguments (tangents),\n",
" compute the output of the primitive and the perturbation of the output.\n",
"\n",
" This method must be JAX-traceable. JAX may invoke it with abstract values \n",
" This method must be JAX-traceable. JAX may invoke it with abstract values\n",
" for the arguments and tangents.\n",
"\n",
" Args:\n",
" arg_values: a tuple of arguments\n",
" arg_tangents: a tuple with the tangents of the arguments. The tuple has \n",
" the same length as the arg_values. Some of the tangents may also be the \n",
" arg_tangents: a tuple with the tangents of the arguments. The tuple has\n",
" the same length as the arg_values. Some of the tangents may also be the\n",
" special value ad.Zero to specify a zero tangent.\n",
" Returns:\n",
" a pair of the primal output and the tangent.\n",
@ -811,26 +811,26 @@
" x, y, z = arg_values\n",
" xt, yt, zt = arg_tangents\n",
" _trace(\"Primal evaluation:\")\n",
" # Now we have a JAX-traceable computation of the output. \n",
" # Normally, we can use the ma primitive itself to compute the primal output. \n",
" # Now we have a JAX-traceable computation of the output.\n",
" # Normally, we can use the ma primitive itself to compute the primal output.\n",
" primal_out = multiply_add_prim(x, y, z)\n",
" \n",
"\n",
" _trace(\"Tangent evaluation:\")\n",
" # We must use a JAX-traceable way to compute the tangent. It turns out that \n",
" # We must use a JAX-traceable way to compute the tangent. It turns out that\n",
" # the output tangent can be computed as (xt * y + x * yt + zt),\n",
" # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n",
" \n",
" # We do need to deal specially with Zero. Here we just turn it into a \n",
" # proper tensor of 0s (of the same shape as 'x'). \n",
" # An alternative would be to check for Zero and perform algebraic \n",
"\n",
" # We do need to deal specially with Zero. Here we just turn it into a\n",
" # proper tensor of 0s (of the same shape as 'x').\n",
" # An alternative would be to check for Zero and perform algebraic\n",
" # simplification of the output tangent computation.\n",
" def make_zero(tan):\n",
" return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan \n",
" \n",
" return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan\n",
"\n",
" output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n",
" return (primal_out, output_tangent)\n",
"\n",
"# Register the forward differentiation rule with JAX \n",
"# Register the forward differentiation rule with JAX\n",
"ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp"
]
},
@ -880,7 +880,7 @@
"id": "69QsEcu-lP4u"
},
"source": [
"TO EXPLAIN: \n",
"TO EXPLAIN:\n",
"\n",
" * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n",
" * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n",
@ -941,7 +941,7 @@
}
],
"source": [
"assert api.jit(lambda arg_values, arg_tangents: \n",
"assert api.jit(lambda arg_values, arg_tangents:\n",
" api.jvp(square_add_prim, arg_values, arg_tangents))(\n",
" (2., 10.), (1., 1.)) == (14., 5.)"
]
@ -953,7 +953,7 @@
},
"source": [
"Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n",
"evaluates abstractly both the primal and the tangent evaluation (a total of \n",
"evaluates abstractly both the primal and the tangent evaluation (a total of\n",
"3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n",
"of the primitive."
]
@ -967,21 +967,21 @@
"### Reverse differentiation\n",
"\n",
"If we attempt now to use reverse differentiation we\n",
"see that JAX starts by using the `multiply_add_value_and_jvp` to \n",
"see that JAX starts by using the `multiply_add_value_and_jvp` to\n",
"compute the forward differentiation for abstract values, but then runs\n",
"into a `NotImplementedError`. \n",
"into a `NotImplementedError`.\n",
"\n",
"When computing the reverse differentiation JAX first does abstract evaluation\n",
"of the forward differentiation code `multiply_add_value_and_jvp` to obtain a \n",
"trace of primitives that compute the output tangent. \n",
"of the forward differentiation code `multiply_add_value_and_jvp` to obtain a\n",
"trace of primitives that compute the output tangent.\n",
"Observe that JAX performs this abstract evaluation with concrete values\n",
"for the differentiation point, and abstract values for the tangents. \n",
"for the differentiation point, and abstract values for the tangents.\n",
"Observe also that JAX uses the special abstract tangent value `Zero` for\n",
"the tangent corresponding to the 3rd argument of `ma`. This reflects the \n",
"the tangent corresponding to the 3rd argument of `ma`. This reflects the\n",
"fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n",
"which flows to the 3rd argument to `multiply_add_prim`.\n",
"\n",
"Observe also that during the abstract evaluation of the tangent we pass the \n",
"Observe also that during the abstract evaluation of the tangent we pass the\n",
"value 0.0 as the tangent for the 3rd argument. This is due to the use\n",
"of the `make_zero` function in the definition of `multiply_add_value_and_jvp`."
]
@ -1071,7 +1071,7 @@
"\n",
"As explained above, when computing reverse differentiation JAX obtains\n",
"a trace of primitives that compute the tangent using forward differentiation.\n",
"Then, **JAX interprets this trace abstractly backwards** and for each \n",
"Then, **JAX interprets this trace abstractly backwards** and for each\n",
"primitive it applies a **transposition** rule.\n",
"\n",
"To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n",
@ -1082,7 +1082,7 @@
" ft = c + yt\n",
"```\n",
"\n",
"By construction, the tangent calculation is always linear in the input tangents. \n",
"By construction, the tangent calculation is always linear in the input tangents.\n",
"The only non-linear operator that may arise in the tangent calculation is multiplication,\n",
"but then one of the operands is constant.\n",
"\n",
@ -1108,8 +1108,8 @@
" xct += act * 4.\n",
"```\n",
"\n",
"One can verify that this computation produces `xct = 4.` and `yct = 3.`, which \n",
"are the partial derivatives of the function `f`. \n",
"One can verify that this computation produces `xct = 4.` and `yct = 3.`, which\n",
"are the partial derivatives of the function `f`.\n",
"\n",
"JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n",
"```\n",
@ -1117,10 +1117,10 @@
"```\n",
"\n",
"Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n",
"arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned \n",
"arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned\n",
"for the constant arguments.\n",
"\n",
"In particular, \n",
"In particular,\n",
"```\n",
" add_transpose(out_ct, _, _) = (out_ct, out_ct)\n",
" mult_transpose(out_ct, x, _) = (None, x * out_ct)\n",
@ -1140,16 +1140,16 @@
"def multiply_add_transpose(ct, x, y, z):\n",
" \"\"\"Evaluates the transpose of a linear primitive.\n",
"\n",
" This method is only used when computing the backward gradient following \n",
" value_and_jvp, and is only needed for primitives that are used in the JVP \n",
" calculation for some other primitive. We need transposition for multiply_add_prim, \n",
" because we have used multiply_add_prim in the computation of the output_tangent in \n",
" This method is only used when computing the backward gradient following\n",
" value_and_jvp, and is only needed for primitives that are used in the JVP\n",
" calculation for some other primitive. We need transposition for multiply_add_prim,\n",
" because we have used multiply_add_prim in the computation of the output_tangent in\n",
" multiply_add_value_and_jvp.\n",
"\n",
" In our case, multiply_add is not a linear primitive. However, it is used linearly \n",
" In our case, multiply_add is not a linear primitive. However, it is used linearly\n",
" w.r.t. tangents in multiply_add_value_and_jvp:\n",
" output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n",
" \n",
"\n",
" Always one of the first two multiplicative arguments is a constant.\n",
"\n",
" Args:\n",
@ -1244,7 +1244,7 @@
},
"source": [
"Notice the two calls to `multiply_add_transpose`. They correspond to the two\n",
"uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the \n",
"uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the\n",
"last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0."
]
},
@ -1254,7 +1254,7 @@
"id": "EIJs6FYmPg6c"
},
"source": [
"#### JIT of reverse differentiation \n",
"#### JIT of reverse differentiation\n",
"\n",
"Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n",
"abstract values, while in the absence of JIT we used `ConcreteArray`."
@ -1397,20 +1397,20 @@
"@trace(\"multiply_add_batch\")\n",
"def multiply_add_batch(vector_arg_values, batch_axes):\n",
" \"\"\"Computes the batched version of the primitive.\n",
" \n",
"\n",
" This must be a JAX-traceable function.\n",
" \n",
"\n",
" Since the multiply_add primitive already operates pointwise on arbitrary\n",
" dimension tensors, to batch it we can use the primitive itself. This works as\n",
" long as both the inputs have the same dimensions and are batched along the\n",
" same axes. The result is batched along the axis that the inputs are batched.\n",
" \n",
"\n",
" Args:\n",
" vector_arg_values: a tuple of two arguments, each being a tensor of matching\n",
" shape.\n",
" batch_axes: the axes that are being batched. See vmap documentation.\n",
" Returns:\n",
" a tuple of the result, and the result axis that was batched. \n",
" a tuple of the result, and the result axis that was batched.\n",
" \"\"\"\n",
" assert batch_axes[0] == batch_axes[1]\n",
" assert batch_axes[0] == batch_axes[2]\n",

View File

@ -22,12 +22,12 @@ kernelspec:
*necula@google.com*, October 2019.
JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,
`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,
`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,
which means that as the Python function executes
the only operations it applies to the data are either inspections of data
attributes such as shape or type, or special operations called JAX primitives.
In particular, a JAX-traceable function is sometimes invoked by JAX with
abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,
abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,
which captures the type and the shape of values, but not the concrete data values.
JAX primitives know how to operate on both concrete data
values and on the JAX abstract values.
@ -37,7 +37,7 @@ The JAX-transformed functions must themselves be JAX-traceable functions,
to ensure that these transformations
can be composed, e.g., `jit(jacfwd(grad(f)))`.
There are pre-defined JAX primitives corresponding to most XLA operations,
There are pre-defined JAX primitives corresponding to most XLA operations,
e.g., add, matmul, sin, cos, indexing.
JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs
using JAX's implementation of numpy are JAX-traceable and therefore transformable.
@ -49,8 +49,8 @@ one can define a new primitive that encapsulates the behavior of the function.
**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically
as "multiply_add(x, y, z) = x * y + z".
This function operates on 3 identically-shaped tensors of floating point
as "multiply_add(x, y, z) = x * y + z".
This function operates on 3 identically-shaped tensors of floating point
values and performs the operations pointwise.
+++ {"id": "HIJYIHNTD1yI"}
@ -58,7 +58,7 @@ values and performs the operations pointwise.
## Using existing primitives
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other
functions that are themselves written using JAX primitives, e.g., those
functions that are themselves written using JAX primitives, e.g., those
defined in the `jax.lax` module:
```{code-cell} ipython3
@ -134,7 +134,7 @@ def trace(name):
return str(v)
def pp_values(args):
return ", ".join([pp(arg) for arg in args])
@functools.wraps(func)
def func_wrapper(*args):
_trace_indent("call {}({})".format(name, pp_values(args)))
@ -164,7 +164,7 @@ class expectNotImplementedError(object):
+++ {"id": "Qf4eLrLCFYDl"}
Instead of using `jax.lax` primitives directly, we can use other functions
Instead of using `jax.lax` primitives directly, we can use other functions
that are already written in terms of those primitives, such as those in `jax.numpy`:
```{code-cell} ipython3
@ -182,7 +182,7 @@ def multiply_add_numpy(x, y, z):
def square_add_numpy(a, b):
return multiply_add_numpy(a, a, b)
print("\nNormal evaluation:")
print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
@ -191,13 +191,13 @@ print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
+++ {"id": "Sg-D8EdeFn4a"}
Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and
`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further
below in this colab).
It is important to remember that a JAX-traceable function must be able to
`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further
below in this colab).
It is important to remember that a JAX-traceable function must be able to
operate not only on concrete arguments but also on special abstract arguments
that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written
The JAX traceability property is satisfied as long as the function is written
in terms of JAX primitives.
+++ {"id": "WxrQO7-XGLcg"}
@ -206,7 +206,7 @@ in terms of JAX primitives.
The right way to add support for multiply-add is in terms of existing
JAX primitives, as shown above. However, in order to demonstrate how JAX
primitives work let us pretend that we want to add a new primitive to
primitives work let us pretend that we want to add a new primitive to
JAX for the multiply-add functionality.
```{code-cell} ipython3
@ -218,9 +218,9 @@ multiply_add_p = core.Primitive("multiply_add") # Create the primitive
@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
"""The JAX-traceable way to use the JAX primitive.
Note that the traced arguments must be passed as positional arguments
to `bind`.
to `bind`.
"""
return multiply_add_p.bind(x, y, z)
@ -257,7 +257,7 @@ def multiply_add_impl(x, y, z):
This function does not need to be JAX traceable.
Args:
x, y, z: the concrete arguments of the primitive. Will only be called with
x, y, z: the concrete arguments of the primitive. Will only be called with
concrete values.
Returns:
the concrete result of the primitive.
@ -293,17 +293,17 @@ with expectNotImplementedError():
+++ {"id": "rHS1bAGHH44E"}
#### Abstract evaluation rules
In order to JIT the function, and for other transformations as well,
JAX first evaluates it abstractly using only the
In order to JIT the function, and for other transformations as well,
JAX first evaluates it abstractly using only the
shape and type of the arguments. This abstract evaluation serves multiple
purposes:
* Gets the sequence of JAX primitives that are used in the computation. This
sequence will be compiled.
* Computes the shape and type of all vectors and operations used in the computation.
* Gets the sequence of JAX primitives that are used in the computation. This
sequence will be compiled.
* Computes the shape and type of all vectors and operations used in the computation.
For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
```{code-cell} ipython3
@ -316,7 +316,7 @@ def multiply_add_abstract_eval(xs, ys, zs):
"""Abstract evaluation of the primitive.
This function does not need to be JAX traceable. It will be invoked with
abstractions of the actual arguments.
abstractions of the actual arguments.
Args:
xs, ys, zs: abstractions of the arguments.
Result:
@ -349,7 +349,7 @@ with expectNotImplementedError():
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is the biggest hurdle to adding new functionality to JAX, because the
This is the biggest hurdle to adding new functionality to JAX, because the
set of XLA operations is limited, and JAX already has pre-defined primitives
for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.
@ -378,7 +378,7 @@ mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
+++ {"id": "K98LX-VaJkFu"}
Now the JIT succeeds. Notice below that JAX first evaluates the function
abstractly, which triggers the `multiply_add_abstract_eval` function, and
abstractly, which triggers the `multiply_add_abstract_eval` function, and
then compiles the set of primitives it has encountered, including `multiply_add`.
At this point JAX invokes `multiply_add_xla_translation`.
@ -393,7 +393,7 @@ assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
Below is another use of `jit` where we compile only
with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads
in the third argument to `multiply_add_abstract_eval` being
in the third argument to `multiply_add_abstract_eval` being
`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with
both `ShapedArray` and `ConcreteArray`.
@ -401,7 +401,7 @@ both `ShapedArray` and `ConcreteArray`.
:id: mPfTwIBoKOEK
:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b
assert api.jit(lambda x, y: square_add_prim(x, y),
assert api.jit(lambda x, y: square_add_prim(x, y),
static_argnums=1)(2., 10.) == 14.
```
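As a standalone illustration of `static_argnums` with the public `jax.jit` API (the notebook's `square_add_prim` is not used here; `scale` is an illustrative helper): a static argument is fixed at trace time, so inside the traced function it behaves like a concrete value rather than an abstract one.

```python
import jax

def scale(x, factor):
    return x * factor

scale_traced = jax.jit(scale)
# With static_argnums, `factor` is treated as a compile-time constant: each
# new value of `factor` triggers a fresh trace, but within that trace it is
# a plain Python number.
scale_static = jax.jit(scale, static_argnums=1)

print(scale_traced(2.0, 10.0))  # factor traced as an abstract value
print(scale_static(2.0, 10.0))  # factor fixed to 10.0 at trace time
```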
@ -437,16 +437,16 @@ from jax.interpreters import ad
def multiply_add_value_and_jvp(arg_values, arg_tangents):
"""Evaluates the primal output and the tangents (Jacobian-vector product).
Given values of the arguments and perturbation of the arguments (tangents),
Given values of the arguments and perturbation of the arguments (tangents),
compute the output of the primitive and the perturbation of the output.
This method must be JAX-traceable. JAX may invoke it with abstract values
This method must be JAX-traceable. JAX may invoke it with abstract values
for the arguments and tangents.
Args:
arg_values: a tuple of arguments
arg_tangents: a tuple with the tangents of the arguments. The tuple has
the same length as the arg_values. Some of the tangents may also be the
arg_tangents: a tuple with the tangents of the arguments. The tuple has
the same length as the arg_values. Some of the tangents may also be the
special value ad.Zero to specify a zero tangent.
Returns:
a pair of the primal output and the tangent.
@ -454,26 +454,26 @@ def multiply_add_value_and_jvp(arg_values, arg_tangents):
x, y, z = arg_values
xt, yt, zt = arg_tangents
_trace("Primal evaluation:")
# Now we have a JAX-traceable computation of the output.
# Normally, we can use the ma primitive itself to compute the primal output.
# Now we have a JAX-traceable computation of the output.
# Normally, we can use the ma primitive itself to compute the primal output.
primal_out = multiply_add_prim(x, y, z)
_trace("Tangent evaluation:")
# We must use a JAX-traceable way to compute the tangent. It turns out that
# We must use a JAX-traceable way to compute the tangent. It turns out that
# the output tangent can be computed as (xt * y + x * yt + zt),
# which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.
# We do need to deal specially with Zero. Here we just turn it into a
# proper tensor of 0s (of the same shape as 'x').
# An alternative would be to check for Zero and perform algebraic
# We do need to deal specially with Zero. Here we just turn it into a
# proper tensor of 0s (of the same shape as 'x').
# An alternative would be to check for Zero and perform algebraic
# simplification of the output tangent computation.
def make_zero(tan):
return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
return (primal_out, output_tangent)
# Register the forward differentiation rule with JAX
# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
```
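The same forward-mode rule can also be expressed with the public `jax.custom_jvp` decorator instead of registering a rule on a hand-built primitive; this standalone sketch mirrors the tangent formula used in `multiply_add_value_and_jvp` above.

```python
import jax

@jax.custom_jvp
def multiply_add(x, y, z):
    return x * y + z

@multiply_add.defjvp
def multiply_add_jvp(primals, tangents):
    x, y, z = primals
    xt, yt, zt = tangents
    primal_out = x * y + z
    # Same tangent formula as above: d(x*y + z) = xt*y + x*yt + zt.
    tangent_out = xt * y + x * yt + zt
    return primal_out, tangent_out

# square_add(a, b) = multiply_add(a, a, b); its JVP at (2., 10.) with
# tangents (1., 1.) is 2*a*at + bt = 4. + 1. = 5.
print(jax.jvp(lambda a, b: multiply_add(a, a, b), (2., 10.), (1., 1.)))  # (14.0, 5.0)
```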
@ -487,7 +487,7 @@ assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
+++ {"id": "69QsEcu-lP4u"}
TO EXPLAIN:
TO EXPLAIN:
* Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
* Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet
@ -504,7 +504,7 @@ We can apply JIT to the forward differentiation function:
:id: hg-hzVu-N-hv
:outputId: 38d32067-e152-4046-ad80-7f95a31ba628
assert api.jit(lambda arg_values, arg_tangents:
assert api.jit(lambda arg_values, arg_tangents:
api.jvp(square_add_prim, arg_values, arg_tangents))(
(2., 10.), (1., 1.)) == (14., 5.)
```
@ -512,7 +512,7 @@ assert api.jit(lambda arg_values, arg_tangents:
+++ {"id": "jlZt1_v2mU88"}
Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn
evaluates abstractly both the primal and the tangent evaluation (a total of
evaluates abstractly both the primal and the tangent evaluation (a total of
3 invocations of the `ma` primitive). Then we compile the 3 occurrences
of the primitive.
@ -521,21 +521,21 @@ of the primitive.
### Reverse differentiation
If we attempt now to use reverse differentiation we
see that JAX starts by using the `multiply_add_value_and_jvp` to
see that JAX starts by using the `multiply_add_value_and_jvp` to
compute the forward differentiation for abstract values, but then runs
into a `NotImplementedError`.
into a `NotImplementedError`.
When computing the reverse differentiation JAX first does abstract evaluation
of the forward differentiation code `multiply_add_value_and_jvp` to obtain a
trace of primitives that compute the output tangent.
of the forward differentiation code `multiply_add_value_and_jvp` to obtain a
trace of primitives that compute the output tangent.
Observe that JAX performs this abstract evaluation with concrete values
for the differentiation point, and abstract values for the tangents.
for the differentiation point, and abstract values for the tangents.
Observe also that JAX uses the special abstract tangent value `Zero` for
the tangent corresponding to the 3rd argument of `ma`. This reflects the
the tangent corresponding to the 3rd argument of `ma`. This reflects the
fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,
which flows to the 3rd argument to `multiply_add_prim`.
Observe also that during the abstract evaluation of the tangent we pass the
Observe also that during the abstract evaluation of the tangent we pass the
value 0.0 as the tangent for the 3rd argument. This is due to the use
of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.
@ -560,7 +560,7 @@ to use the forward differentiation code to compute reverse differentiation.
As explained above, when computing reverse differentiation JAX obtains
a trace of primitives that compute the tangent using forward differentiation.
Then, **JAX interprets this trace abstractly backwards** and for each
Then, **JAX interprets this trace abstractly backwards** and for each
primitive it applies a **transposition** rule.
To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:
@ -571,7 +571,7 @@ To understand what is going on, consider for now a simpler example of the functi
ft = c + yt
```
By construction, the tangent calculation is always linear in the input tangents.
By construction, the tangent calculation is always linear in the input tangents.
The only non-linear operator that may arise in the tangent calculation is multiplication,
but then one of the operands is constant.
@ -597,8 +597,8 @@ of the operation:
xct += act * 4.
```
One can verify that this computation produces `xct = 4.` and `yct = 3.`, which
are the partial derivatives of the function `f`.
One can verify that this computation produces `xct = 4.` and `yct = 3.`, which
are the partial derivatives of the function `f`.
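These hand-computed cotangents can be checked directly with reverse-mode `jax.grad` (a quick standalone check):

```python
import jax

f = lambda x, y: x * y + y

# Reverse-mode AD at (2., 4.) recovers the cotangents derived by hand above:
# df/dx = y = 4. and df/dy = x + 1 = 3.
print(jax.grad(f, argnums=(0, 1))(2., 4.))  # (4.0, 3.0)
```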
JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:
```
@ -606,10 +606,10 @@ p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
```
Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other
arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned
arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned
for the constant arguments.
In particular,
In particular,
```
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
@ -623,16 +623,16 @@ In particular,
def multiply_add_transpose(ct, x, y, z):
"""Evaluates the transpose of a linear primitive.
This method is only used when computing the backward gradient following
value_and_jvp, and is only needed for primitives that are used in the JVP
calculation for some other primitive. We need transposition for multiply_add_prim,
because we have used multiply_add_prim in the computation of the output_tangent in
This method is only used when computing the backward gradient following
value_and_jvp, and is only needed for primitives that are used in the JVP
calculation for some other primitive. We need transposition for multiply_add_prim,
because we have used multiply_add_prim in the computation of the output_tangent in
multiply_add_value_and_jvp.
In our case, multiply_add is not a linear primitive. However, it is used linearly
In our case, multiply_add is not a linear primitive. However, it is used linearly
w.r.t. tangents in multiply_add_value_and_jvp:
output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
Always one of the first two multiplicative arguments is a constant.
Args:
@ -674,12 +674,12 @@ assert api.grad(square_add_prim)(2., 10.) == 4.
+++ {"id": "8M1xLCXW4fK7"}
Notice the two calls to `multiply_add_transpose`. They correspond to the two
uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the
uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the
last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0.
+++ {"id": "EIJs6FYmPg6c"}
#### JIT of reverse differentiation
#### JIT of reverse differentiation
Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only
abstract values, while in the absence of JIT we used `ConcreteArray`.
@ -721,20 +721,20 @@ from jax.interpreters import batching
@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
"""Computes the batched version of the primitive.
This must be a JAX-traceable function.
Since the multiply_add primitive already operates pointwise on arbitrary
dimension tensors, to batch it we can use the primitive itself. This works as
long as both the inputs have the same dimensions and are batched along the
same axes. The result is batched along the axis that the inputs are batched.
Args:
vector_arg_values: a tuple of two arguments, each being a tensor of matching
shape.
batch_axes: the axes that are being batched. See vmap documentation.
Returns:
a tuple of the result, and the result axis that was batched.
a tuple of the result, and the result axis that was batched.
"""
assert batch_axes[0] == batch_axes[1]
assert batch_axes[0] == batch_axes[2]
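To see why the batching rule can simply reuse the primitive, here is a standalone sketch with an ordinary pointwise function standing in for the registered primitive: when every argument is batched along the same axis, `jax.vmap` of a pointwise operation coincides with the operation applied to the stacked inputs, batched along that same axis.

```python
import jax
import jax.numpy as jnp

def multiply_add(x, y, z):
    # Ordinary pointwise multiply-add, standing in for the primitive above.
    return x * y + z

xs = jnp.arange(3.0)
ys = jnp.arange(3.0) + 10.0
zs = jnp.ones(3)

# All arguments batched along axis 0: the batched result equals the pointwise
# operation applied to the stacked inputs -- which is what multiply_add_batch
# returns, together with the shared batch axis.
batched = jax.vmap(multiply_add, in_axes=0)(xs, ys, zs)
assert jnp.allclose(batched, multiply_add(xs, ys, zs))
```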

View File

@ -119,7 +119,7 @@
" for w, b in params[:-1]:\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = relu(outputs)\n",
" \n",
"\n",
" final_w, final_b = params[-1]\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits - logsumexp(logits)"
@ -238,7 +238,7 @@
"def one_hot(x, k, dtype=jnp.float32):\n",
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
" \n",
"\n",
"def accuracy(params, images, targets):\n",
" target_class = jnp.argmax(targets, axis=1)\n",
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",

View File

@ -96,7 +96,7 @@ def predict(params, image):
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
@ -156,7 +156,7 @@ At this point, we have all the ingredients we need to define our neural network
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
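A standalone illustration of the broadcasting trick inside `one_hot`: comparing a column of labels against `jnp.arange(k)` broadcasts to an `(n, k)` boolean mask, which is then cast to the requested dtype. The labels below are arbitrary.

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    # x[:, None] has shape (n, 1); jnp.arange(k) has shape (k,).
    # Broadcasting the == comparison yields an (n, k) boolean mask.
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

print(one_hot(jnp.array([0, 2, 1]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```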

View File

@ -35,7 +35,6 @@
},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import jit, grad, vmap\n",
@ -214,7 +213,6 @@
"outputs": [],
"source": [
"# Importing Jax functions useful for tracing/interpreting.\n",
"import numpy as np\n",
"from functools import wraps\n",
"\n",
"from jax import core\n",
@ -273,7 +271,7 @@
"def eval_jaxpr(jaxpr, consts, *args):\n",
" # Mapping from variable -> value\n",
" env = {}\n",
" \n",
"\n",
" def read(var):\n",
" # Literals are values baked into the Jaxpr\n",
" if type(var) is core.Literal:\n",
@ -290,16 +288,16 @@
" # Loop through equations and evaluate primitives using `bind`\n",
" for eqn in jaxpr.eqns:\n",
" # Read inputs to equation from environment\n",
" invals = safe_map(read, eqn.invars) \n",
" invals = safe_map(read, eqn.invars)\n",
" # `bind` is how a primitive is called\n",
" outvals = eqn.primitive.bind(*invals, **eqn.params)\n",
" # Primitives may return multiple outputs or not\n",
" if not eqn.primitive.multiple_results: \n",
" if not eqn.primitive.multiple_results:\n",
" outvals = [outvals]\n",
" # Write the results of the primitive into the environment\n",
" safe_map(write, eqn.outvars, outvals) \n",
" safe_map(write, eqn.outvars, outvals)\n",
" # Read the final result of the Jaxpr from the environment\n",
" return safe_map(read, jaxpr.outvars) "
" return safe_map(read, jaxpr.outvars)"
]
},
{
@ -417,7 +415,7 @@
"source": [
"def inverse_jaxpr(jaxpr, consts, *args):\n",
" env = {}\n",
" \n",
"\n",
" def read(var):\n",
" if type(var) is core.Literal:\n",
" return var.val\n",
@ -431,12 +429,12 @@
"\n",
" # Looping backward\n",
" for eqn in jaxpr.eqns[::-1]:\n",
" # outvars are now invars \n",
" # outvars are now invars\n",
" invals = safe_map(read, eqn.outvars)\n",
" if eqn.primitive not in inverse_registry:\n",
" raise NotImplementedError(\n",
" f\"{eqn.primitive} does not have registered inverse.\")\n",
" # Assuming a unary function \n",
" # Assuming a unary function\n",
" outval = inverse_registry[eqn.primitive](*invals)\n",
" safe_map(write, eqn.invars, [outval])\n",
" return safe_map(read, jaxpr.invars)"

View File

@ -32,7 +32,6 @@ Here we show how to add your own function transformations to the system, by writ
```{code-cell} ipython3
:id: s27RDKvKXFL8
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
@ -146,7 +145,6 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr.
:id: BHkg_3P1pXJj
# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps
from jax import core
@ -185,7 +183,7 @@ To do this, we first create an environment to store the values for each of the v
def eval_jaxpr(jaxpr, consts, *args):
# Mapping from variable -> value
env = {}
def read(var):
# Literals are values baked into the Jaxpr
if type(var) is core.Literal:
@ -202,16 +200,16 @@ def eval_jaxpr(jaxpr, consts, *args):
# Loop through equations and evaluate primitives using `bind`
for eqn in jaxpr.eqns:
# Read inputs to equation from environment
invals = safe_map(read, eqn.invars)
invals = safe_map(read, eqn.invars)
# `bind` is how a primitive is called
outvals = eqn.primitive.bind(*invals, **eqn.params)
# Primitives may return multiple outputs or not
if not eqn.primitive.multiple_results:
if not eqn.primitive.multiple_results:
outvals = [outvals]
# Write the results of the primitive into the environment
safe_map(write, eqn.outvars, outvals)
safe_map(write, eqn.outvars, outvals)
# Read the final result of the Jaxpr from the environment
return safe_map(read, jaxpr.outvars)
return safe_map(read, jaxpr.outvars)
```
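A condensed, standalone version of the interpreter above together with a usage sketch; the test function `f` and its input are arbitrary, and plain `map` stands in here for the notebook's `safe_map`. It uses the same `jax.core` attributes the notebook imports.

```python
import jax
import jax.numpy as jnp
from jax import core

def eval_jaxpr(jaxpr, consts, *args):
    # Environment mapping jaxpr variables to concrete values.
    env = {}

    def read(var):
        # Literals carry their value directly.
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    # Bind constants and arguments to the jaxpr's variables.
    list(map(write, jaxpr.constvars, consts))
    list(map(write, jaxpr.invars, args))

    # Evaluate each equation in order by `bind`-ing its primitive.
    for eqn in jaxpr.eqns:
        invals = list(map(read, eqn.invars))
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        list(map(write, eqn.outvars, outvals))
    return list(map(read, jaxpr.outvars))

def f(x):
    return jnp.exp(jnp.tanh(x))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5)))
```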
```{code-cell} ipython3
@ -279,7 +277,7 @@ Now we just need to define `inverse_jaxpr`, which will walk through the Jaxpr ba
def inverse_jaxpr(jaxpr, consts, *args):
env = {}
def read(var):
if type(var) is core.Literal:
return var.val
@ -293,12 +291,12 @@ def inverse_jaxpr(jaxpr, consts, *args):
# Looping backward
for eqn in jaxpr.eqns[::-1]:
# outvars are now invars
# outvars are now invars
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
# Assuming a unary function
# Assuming a unary function
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])
return safe_map(read, jaxpr.invars)

View File

@ -1148,7 +1148,7 @@
" y, vjp_fun = vjp(f, x)\n",
" # Use vmap to do a matrix-Jacobian product.\n",
" # Here, the matrix is the Euclidean basis, so we get all\n",
" # entries in the Jacobian at once. \n",
" # entries in the Jacobian at once.\n",
" J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n",
" return J\n",
" return jacfun\n",
@ -1169,7 +1169,7 @@
"def our_jacfwd(f):\n",
" def jacfun(x):\n",
" _jvp = lambda s: jvp(f, (x,), (s,))[1]\n",
" Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
" Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
" return jnp.transpose(Jt)\n",
" return jacfun\n",
"\n",

View File

@ -675,7 +675,7 @@ def our_jacrev(f):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
@ -691,7 +691,7 @@ from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun

View File

@ -739,8 +739,6 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.ad_checkpoint import checkpoint_name\n",
"\n",
"def predict(params, x):\n",
" *Ws, Wlast = params\n",
" for i, W in enumerate(Ws):\n",

View File

@ -370,8 +370,6 @@ Notice also that by providing a policy, we didn't need to edit the code defining
Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`:
```{code-cell}
from jax.ad_checkpoint import checkpoint_name
def predict(params, x):
*Ws, Wlast = params
for i, W in enumerate(Ws):

View File

@ -410,7 +410,7 @@
],
"source": [
"dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n",
" kernel.shape, # only ndim matters, not shape \n",
" kernel.shape, # only ndim matters, not shape\n",
" ('NHWC', 'HWIO', 'NHWC')) # the important bit\n",
"print(dn)"
]
@ -806,8 +806,8 @@
],
"source": [
"# 1D kernel - WIO layout\n",
"kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], \n",
" [[1, 1, 1], [-1, -1, -1]]], \n",
"kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],\n",
" [[1, 1, 1], [-1, -1, -1]]],\n",
" dtype=jnp.float32).transpose([2,1,0])\n",
"# 1D data - NWC layout\n",
"data = np.zeros((1, 200, 2), dtype=jnp.float32)\n",
@ -895,8 +895,8 @@
"# Random 3D kernel - HWDIO layout\n",
"kernel = jnp.array([\n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],\n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],\n",
" dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]\n",
"\n",
"# 3D data - NHWDC layout\n",
@ -919,7 +919,6 @@
"print(\"out shape: \", out.shape)\n",
"\n",
"# Make some simple 3d density plots:\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"def make_alpha(cmap):\n",
" my_cmap = cmap(jnp.arange(cmap.N))\n",
" my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n",

View File

@ -210,7 +210,7 @@ The important argument is the 3-tuple of axis layout arguments:
:outputId: d5a569b3-febc-4832-f725-1d5e8fd31b9b
dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape
kernel.shape, # only ndim matters, not shape
kernel.shape, # only ndim matters, not shape
('NHWC', 'HWIO', 'NHWC')) # the important bit
print(dn)
```
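A standalone sketch of how the resulting dimension-numbers object is consumed by `lax.conv_general_dilated`; the NHWC image and HWIO kernel shapes below are illustrative.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative NHWC image batch and HWIO kernel.
img = jnp.zeros((1, 200, 198, 3), dtype=jnp.float32)
kernel = jnp.zeros((3, 3, 3, 10), dtype=jnp.float32)

dn = lax.conv_dimension_numbers(img.shape, kernel.shape,
                                ('NHWC', 'HWIO', 'NHWC'))

out = lax.conv_general_dilated(img,      # lhs = image tensor
                               kernel,   # rhs = conv kernel tensor
                               (1, 1),   # window strides
                               'SAME',   # padding mode
                               dimension_numbers=dn)
print(out.shape)  # (1, 200, 198, 10)
```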
@ -363,8 +363,8 @@ You aren't limited to 2D convolutions, a simple 1D demo is below:
:outputId: 67c46ace-6adc-4c47-c1c7-1f185be5fd4b
# 1D kernel - WIO layout
kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
dtype=jnp.float32).transpose([2,1,0])
# 1D data - NWC layout
data = np.zeros((1, 200, 2), dtype=jnp.float32)
@ -406,8 +406,8 @@ import matplotlib as mpl
# Random 3D kernel - HWDIO layout
kernel = jnp.array([
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
[[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
[[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
[[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
[[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]
# 3D data - NHWDC layout
@ -430,7 +430,6 @@ out = lax.conv_general_dilated(data, # lhs = image tensor
print("out shape: ", out.shape)
# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
my_cmap = cmap(jnp.arange(cmap.N))
my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3

View File

@ -840,7 +840,6 @@
},
"outputs": [],
"source": [
"from functools import partial\n",
"j1 = partial(jv, 1)\n",
"z = jnp.arange(5.0)"
]

View File

@ -410,7 +410,6 @@ This lets us call into `scipy.special.jv` from transformed JAX code, including w
```{code-cell}
:id: f4e46670f4e4
from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
```

View File

@ -132,7 +132,7 @@
" for w, b in params[:-1]:\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = relu(outputs)\n",
" \n",
"\n",
" final_w, final_b = params[-1]\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits - logsumexp(logits)"
@ -251,7 +251,7 @@
"def one_hot(x, k, dtype=jnp.float32):\n",
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
" \n",
"\n",
"def accuracy(params, images, targets):\n",
" target_class = jnp.argmax(targets, axis=1)\n",
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",

View File

@ -104,7 +104,7 @@ def predict(params, image):
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
@ -164,7 +164,7 @@ At this point, we have all the ingredients we need to define our neural network
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)

View File

@ -25,17 +25,10 @@
},
"outputs": [],
"source": [
"import functools\n",
"import itertools\n",
"import re\n",
"import sys\n",
"import time\n",
"\n",
"from matplotlib.pyplot import *\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import jax\n",
"\n",
"from jax import lax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
@ -348,7 +341,7 @@
"def elbo(beta_loc, beta_log_scale, epsilon):\n",
" beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n",
" return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n",
" \n",
"\n",
"elbo = jax.jit(elbo)\n",
"elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))"
]
@ -548,25 +541,16 @@
}
],
"source": [
"figure(figsize=(7, 7))\n",
"plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
"plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
"plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plt.figure(figsize=(7, 7))\n",
"plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
"plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\\sigma$ Error Bars')\n",
"plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plot_scale = 3\n",
"plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
"xlabel('True beta')\n",
"ylabel('Estimated beta')\n",
"legend(loc='best')"
"plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
"plt.xlabel('True beta')\n",
"plt.ylabel('Estimated beta')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "_bXdOlvUEJl0"
},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@ -27,17 +27,10 @@ Inspired by a notebook by @davmre.
```{code-cell} ipython3
:id: 8RZDkfbV3zdR
import functools
import itertools
import re
import sys
import time
from matplotlib.pyplot import *
import matplotlib.pyplot as plt
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
@ -192,7 +185,7 @@ batched_log_joint = jax.jit(jax.vmap(log_joint))
def elbo(beta_loc, beta_log_scale, epsilon):
beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
```
@ -240,19 +233,13 @@ Coverage isn't quite as good as we might like, but it's not bad, and nobody said
:id: zt1NBLoVHtOG
:outputId: fb159795-e6e7-497c-e501-9933ec761af4
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plt.figure(figsize=(7, 7))
plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
```
```{code-cell} ipython3
:id: _bXdOlvUEJl0
plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
plt.xlabel('True beta')
plt.ylabel('Estimated beta')
plt.legend(loc='best')
```

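Two mechanical fixes recur in the plotting hunk above: pyplot calls now go through the `plt` namespace instead of a star import, and the `$2\sigma$` label gains an `r` prefix so the backslash is no longer parsed as a string escape (the W605-style warning). A hedged, stand-alone illustration with made-up data:

```python
import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
# Raw string: '\s' stays a literal backslash for mathtext instead of
# triggering an invalid-escape warning.
plt.plot([0, 1, 2], [0, 1, 2], 'r.', label=r'$2\sigma$ error bars')
plt.legend(loc='best')
plt.show()
```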
View File

@ -189,9 +189,6 @@
],
"source": [
"# Pardon the boilerplate; constructing a sharding will become easier in future!\n",
"from jax.sharding import Mesh\n",
"from jax.sharding import PartitionSpec\n",
"from jax.sharding import NamedSharding\n",
"from jax.experimental import mesh_utils\n",
"\n",
"P = jax.sharding.PartitionSpec\n",

View File

@ -73,9 +73,6 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5
# Pardon the boilerplate; constructing a sharding will become easier in future!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils
P = jax.sharding.PartitionSpec

View File

@ -1268,7 +1268,7 @@ def new_base_main(trace_type: type[Trace],
@contextmanager
def pop_level(level: int):
if level == 0:
return (yield)
return (yield) # noqa: B901
prev, thread_local_state.trace_state.trace_stack.stack = \
thread_local_state.trace_state.trace_stack.stack, \
thread_local_state.trace_state.trace_stack.stack[:level]

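The `# noqa: B901` added above silences flake8-bugbear's warning about `return <value>` inside a generator; in a generator-based context manager the `return (yield)` idiom is a deliberate early exit that still yields exactly once. A hedged sketch of the pattern (names invented, not autodidax code):

```python
from contextlib import contextmanager

@contextmanager
def maybe_scope(enabled: bool):
    if not enabled:
        return (yield)   # yield once, then exit early; this is the shape B901 flags
    print("enter")
    try:
        yield
    finally:
        print("exit")

with maybe_scope(False):
    pass   # prints nothing
with maybe_scope(True):
    pass   # prints "enter" then "exit"
```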
View File

@ -760,7 +760,7 @@ def _dot_product_attention_fwd_partition(
scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training,
mesh, arg_shapes, result_shape):
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes)
out_shardings = _infer_fwd_output_sharding(
mesh, arg_shapes, variadic_args, is_training)
impl = functools.partial(
@ -810,7 +810,7 @@ def _dot_product_attention_bwd_partition(
arg_shapes, result_shape):
out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes)
def sharded_impl(*args):
impl = functools.partial(
_dot_product_attention_bwd_impl,

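The `arg_shardings` changes above (and the similar ones in the Pallas pipelining code below) drop the wrapping list comprehension so a generator feeds `tuple()` directly; the result is identical, but no intermediate list is built. A hedged illustration with placeholder data:

```python
shapes = [(4, 8), (8, 16), (16, 2)]

before = tuple([rows * cols for rows, cols in shapes])  # builds a throwaway list first
after = tuple(rows * cols for rows, cols in shapes)     # generator consumed directly by tuple()

assert before == after == (32, 128, 32)
```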
View File

@ -238,7 +238,7 @@ class BufferedRef:
Returns:
Initialized BufferedRef
"""
block_shape = tuple([1 if x is None else x for x in spec.block_shape])
block_shape = tuple(1 if x is None else x for x in spec.block_shape)
if buffer_type is BufferType.ACCUMULATOR:
accum_ref = VMEM(block_shape, dtype)
else:
@ -310,7 +310,7 @@ class BufferedRef:
@property
def current_ref(self):
buffer_slice = tuple(
[0 if x is None else slice(None) for x in self.block_shape])
0 if x is None else slice(None) for x in self.block_shape)
if self.memory_space == VMEM:
return self.vmem_ref.at[buffer_slice]
else:
@ -349,7 +349,7 @@ class BufferedRef:
def compute_slice(self, grid_indices):
"""Compute DMA slice from grid indices."""
block_shape = tuple([1 if x is None else x for x in self.block_shape])
block_shape = tuple(1 if x is None else x for x in self.block_shape)
indices = self.compute_index(*grid_indices)
return jax.tree.map(_make_ds, indices, block_shape)

View File

@ -172,7 +172,7 @@ def _normalize_tolerance(tol):
if isinstance(tol, dict):
return {np.dtype(k): v for k, v in tol.items()}
else:
return {k: tol for k in _default_tolerance}
return dict.fromkeys(_default_tolerance, tol)
def join_tolerance(tol1, tol2):
tol1 = _normalize_tolerance(tol1)

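The `_normalize_tolerance` rewrite above relies on `dict.fromkeys(keys, value)` producing the same mapping as a comprehension that assigns one constant to every key (the shared value only matters if it is mutable, which a tolerance float is not). A hedged sketch with stand-in keys:

```python
_default_tolerance = {'float32': 1e-6, 'float64': 1e-15}  # stand-in for the real dtype-keyed table
tol = 1e-3

via_comprehension = {k: tol for k in _default_tolerance}
via_fromkeys = dict.fromkeys(_default_tolerance, tol)     # same keys, same constant value

assert via_comprehension == via_fromkeys == {'float32': 1e-3, 'float64': 1e-3}
```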
View File

@ -115,6 +115,8 @@ ignore = [
"C408",
# Unnecessary map usage
"C417",
# Unnecessary dict comprehension for iterable
"C420",
# Object names too complex
"C901",
# Local variable is assigned to but never used
@ -141,7 +143,16 @@ max-complexity = 18
[tool.ruff.lint.per-file-ignores]
# F811: Redefinition of unused name.
# F821: Undefined name.
"docs/autodidax.py" = ["F811"]
"docs/pallas/tpu/matmul.ipynb" = ["F811"]
"docs/pallas/tpu/distributed.ipynb" = ["F811"]
"docs/pallas/quickstart.ipynb" = ["F811"]
"docs/notebooks/autodiff_cookbook.ipynb" = ["F811", "F821"]
"docs/notebooks/autodiff_remat.ipynb" = ["F811", "F821"]
"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"]
"docs/jep/9407-type-promotion.ipynb" = ["F811"]
"docs/autodidax.ipynb" = ["F811"]
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]

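The new per-file-ignores above exempt tutorial notebooks from F811 (and, where noted, F821), since teaching material routinely redefines the same name as it refines an implementation. A hedged, self-contained example of the pattern F811 would otherwise flag (names invented):

```python
def loss(params, x):
    return (x - params) ** 2

def loss(params, x):  # F811: `loss` redefined while the first definition was never used
    return abs(x - params)

print(loss(1.0, 3.0))  # uses the latest definition, as a notebook reader expects
```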
View File

@ -308,7 +308,7 @@ class MemRefTest(TestCase):
expanded_shape = get_packed_shape(strides, shape)
total_size = np.prod(expanded_shape)
np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape)
index = tuple([slice(0, s) for s in shape])
index = tuple(slice(0, s) for s in shape)
# Reference implementation
def np_fold(inp, dim, fold_rank):

View File

@ -88,15 +88,6 @@
"height": 68
}
},
"source": [
"from jaxlib import xla_extension\n",
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
],
"execution_count": 2,
"outputs": [
{

View File

@ -500,11 +500,11 @@ class IndexerOpsTest(PallasBaseTest):
def body(x_ref, y_ref1, y_ref2):
if slice_type == "slice":
slices = tuple(
[slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)]
slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)
)
else:
slices = tuple(
[pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)]
pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)
)
if indexer_type == "state":
y_ref1[...] = x_ref[slices]