diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 355f134f0..79a0df6e9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.4 + rev: v0.6.1 hooks: - id: ruff diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index 279aef3e9..cb5a42ced 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -510,8 +510,8 @@ "outputs": [], "source": [ "image_partitions = P(1, 1, 4, 2)\n", - "sharded_conv = sharded_jit(conv, \n", - " in_parts=(image_partitions, None), \n", + "sharded_conv = sharded_jit(conv,\n", + " in_parts=(image_partitions, None),\n", " out_parts=image_partitions)\n", "\n", "sharded_conv(image, kernel)" diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index 9acb1971c..4952cdbe9 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -877,7 +877,7 @@ " def g(z):\n", " return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n", " return grad(lambda w: jnp.sum(g(w)))(x)\n", - " \n", + "\n", "f(x)" ] }, @@ -950,17 +950,6 @@ "per_example_hess = pmap(input_hess) # pmap!\n", "per_example_hess(inputs)" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "u3ggM_WYZ8QC" - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/cloud_tpu_colabs/Wave_Equation.ipynb b/cloud_tpu_colabs/Wave_Equation.ipynb index 059173919..16f675a76 100644 --- a/cloud_tpu_colabs/Wave_Equation.ipynb +++ b/cloud_tpu_colabs/Wave_Equation.ipynb @@ -67,7 +67,6 @@ "source": [ "from functools import partial\n", "import jax\n", - "from jax import jit, pmap\n", "from jax import lax\n", "from jax import tree_util\n", "import jax.numpy as jnp\n", diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 0449b82e9..20affa8cf 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -640,7 +640,7 @@ def our_jacrev(f): y, vjp_fun = vjp(f, x) # Use vmap to do a matrix-Jacobian product. # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. + # entries in the Jacobian at once. J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) return J return jacfun @@ -654,7 +654,7 @@ from jax import jacfwd as builtin_jacfwd def our_jacfwd(f): def jacfun(x): _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) return jnp.transpose(Jt) return jacfun diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index 2aef17681..3e99daabe 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -3317,7 +3317,6 @@ ], "source": [ "# @title\n", - "from jax import dtypes\n", "import jax\n", "import jax.numpy as jnp\n", "import pandas as pd\n", @@ -3802,7 +3801,6 @@ ], "source": [ "# @title\n", - "from jax import dtypes\n", "import jax\n", "import jax.numpy as jnp\n", "import pandas as pd\n", diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index 107bcd8c9..2d12944f1 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -908,7 +908,6 @@ display.HTML(table.to_html()) :tags: [hide-input] # @title -from jax import dtypes import jax import jax.numpy as jnp import pandas as pd @@ -963,7 +962,6 @@ display.HTML(table.to_html()) :tags: [hide-input] # @title -from jax import dtypes import jax import jax.numpy as jnp import pandas as pd diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index d76914440..c143b520a 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -226,7 +226,6 @@ ], "source": [ "import jax.numpy as jnp\n", - "import jax.lax as lax\n", "from jax import make_jaxpr\n", "\n", "# lax.fori_loop\n", @@ -1031,7 +1030,6 @@ } ], "source": [ - "from jax import random\n", "key = random.key(0)\n", "key" ] @@ -1105,8 +1103,8 @@ "print(\"old key\", key)\n", "key, subkey = random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(\" \\---SPLIT --> new key \", key)\n", - "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" + "print(r\" \\---SPLIT --> new key \", key)\n", + "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ] }, { @@ -1140,8 +1138,8 @@ "print(\"old key\", key)\n", "key, subkey = random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(\" \\---SPLIT --> new key \", key)\n", - "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" + "print(r\" \\---SPLIT --> new key \", key)\n", + "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ] }, { @@ -1701,7 +1699,7 @@ ], "source": [ "init_val = 0\n", - "cond_fun = lambda x: x<10\n", + "cond_fun = lambda x: x < 10\n", "body_fun = lambda x: x+1\n", "lax.while_loop(cond_fun, body_fun, init_val)\n", "# --> array(10, dtype=int32)" diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index edf5c9446..0b21a57e3 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -130,7 +130,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or :outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903 import jax.numpy as jnp -import jax.lax as lax from jax import make_jaxpr # lax.fori_loop @@ -471,7 +470,6 @@ The random state is described by a special array element that we call a __key__: :id: yPHE7KTWgAWs :outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 -from jax import random key = random.key(0) key ``` @@ -504,8 +502,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(" \---SPLIT --> new key ", key) -print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom) +print(r" \---SPLIT --> new key ", key) +print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) ``` +++ {"id": "tqtFVE4MthO3"} @@ -519,8 +517,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(" \---SPLIT --> new key ", key) -print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom) +print(r" \---SPLIT --> new key ", key) +print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) ``` +++ {"id": "0KLYUluz3lN3"} @@ -805,7 +803,7 @@ def while_loop(cond_fun, body_fun, init_val): :outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e init_val = 0 -cond_fun = lambda x: x<10 +cond_fun = lambda x: x < 10 body_fun = lambda x: x+1 lax.while_loop(cond_fun, body_fun, init_val) # --> array(10, dtype=int32) diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 3abb6d9cb..ec85f6e63 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -247,7 +247,6 @@ } ], "source": [ - "import jax.numpy as jnp\n", "\n", "def log1pexp(x):\n", " return jnp.log(1. + jnp.exp(x))\n", @@ -984,7 +983,7 @@ " (a, x_star, x_star_bar),\n", " x_star_bar))\n", " return a_bar, jnp.zeros_like(x_star)\n", - " \n", + "\n", "def rev_iter(f, packed, u):\n", " a, x_star, x_star_bar = packed\n", " _, vjp_x = vjp(lambda x: f(a, x), x_star)\n", @@ -1884,7 +1883,6 @@ } ], "source": [ - "from jax import vjp\n", "\n", "y, f_vjp = vjp(f, 3.)\n", "print(y)" @@ -1983,7 +1981,7 @@ " return x, x\n", "\n", "def debug_bwd(x, g):\n", - " import pdb; pdb.set_trace()\n", + " pdb.set_trace()\n", " return g\n", "\n", "debug.defvjp(debug_fwd, debug_bwd)" diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index ad577d55c..3c60cce0c 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -145,7 +145,6 @@ Say we want to write a function called `log1pexp`, which computes $x \mapsto \lo :id: 6lWbTvs40ET- :outputId: 8caff99e-add1-4c70-ace3-212c0c5c6f4e -import jax.numpy as jnp def log1pexp(x): return jnp.log(1. + jnp.exp(x)) @@ -524,7 +523,7 @@ def fixed_point_rev(f, res, x_star_bar): (a, x_star, x_star_bar), x_star_bar)) return a_bar, jnp.zeros_like(x_star) - + def rev_iter(f, packed, u): a, x_star, x_star_bar = packed _, vjp_x = vjp(lambda x: f(a, x), x_star) @@ -965,7 +964,6 @@ print(grad(f)(3.)) :id: s1Pn_qCIODcF :outputId: 423d34e0-35b8-4b57-e89d-f70f20e28ea9 -from jax import vjp y, f_vjp = vjp(f, 3.) print(y) @@ -1015,7 +1013,7 @@ def debug_fwd(x): return x, x def debug_bwd(x, g): - import pdb; pdb.set_trace() + pdb.set_trace() return g debug.defvjp(debug_fwd, debug_bwd) diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 3d8c5b020..8bc0e0a52 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -30,9 +30,7 @@ }, "outputs": [], "source": [ - "import os\n", "\n", - "import functools\n", "from typing import Optional\n", "\n", "import numpy as np\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index cb5d4602c..c5f3c08ed 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -26,9 +26,7 @@ This tutorial discusses parallelism via `jax.Array`, the unified array object mo ```{code-cell} :id: FNxScTfq3vGF -import os -import functools from typing import Optional import numpy as np diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index f42e3f74b..0c20fc47d 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -15,12 +15,12 @@ "*necula@google.com*, October 2019.\n", "\n", "JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n", - "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, \n", + "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,\n", "which means that as the Python function executes\n", "the only operations it applies to the data are either inspections of data\n", "attributes such as shape or type, or special operations called JAX primitives.\n", "In particular, a JAX-traceable function is sometimes invoked by JAX with\n", - "abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, \n", + "abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,\n", "which captures the type and the shape of values, but not the concrete data values.\n", "JAX primitives know how to operate on both concrete data\n", "values and on the JAX abstract values.\n", @@ -30,7 +30,7 @@ "to ensure that these transformations\n", "can be composed, e.g., `jit(jacfwd(grad(f)))`.\n", "\n", - "There are pre-defined JAX primitives corresponding to most XLA operations, \n", + "There are pre-defined JAX primitives corresponding to most XLA operations,\n", "e.g., add, matmul, sin, cos, indexing.\n", "JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n", "using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n", @@ -42,8 +42,8 @@ "**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n", "\n", "Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n", - "as \"multiply_add(x, y, z) = x * y + z\". \n", - "This function operates on 3 identically-shaped tensors of floating point \n", + "as \"multiply_add(x, y, z) = x * y + z\".\n", + "This function operates on 3 identically-shaped tensors of floating point\n", "values and performs the operations pointwise." ] }, @@ -56,7 +56,7 @@ "## Using existing primitives\n", "\n", "The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n", - "functions that are themselves written using JAX primitives, e.g., those \n", + "functions that are themselves written using JAX primitives, e.g., those\n", "defined in the `jax.lax` module:" ] }, @@ -165,7 +165,7 @@ " return str(v)\n", " def pp_values(args):\n", " return \", \".join([pp(arg) for arg in args])\n", - " \n", + "\n", " @functools.wraps(func)\n", " def func_wrapper(*args):\n", " _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n", @@ -199,7 +199,7 @@ "id": "Qf4eLrLCFYDl" }, "source": [ - "Instead of using `jax.lax` primitives directly, we can use other functions \n", + "Instead of using `jax.lax` primitives directly, we can use other functions\n", "that are already written in terms of those primitives, such as those in `jax.numpy`:" ] }, @@ -244,7 +244,7 @@ "def square_add_numpy(a, b):\n", " return multiply_add_numpy(a, a, b)\n", "\n", - "print(\"\\nNormal evaluation:\") \n", + "print(\"\\nNormal evaluation:\")\n", "print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n", "print(\"\\nGradient evaluation:\")\n", "print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))" @@ -257,13 +257,13 @@ }, "source": [ "Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n", - "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further \n", - "below in this colab). \n", - "It is important to remember that a JAX-traceable function must be able to \n", + "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further\n", + "below in this colab).\n", + "It is important to remember that a JAX-traceable function must be able to\n", "operate not only on concrete arguments but also on special abstract arguments\n", "that JAX may use to abstract the function execution.\n", "\n", - "The JAX traceability property is satisfied as long as the function is written \n", + "The JAX traceability property is satisfied as long as the function is written\n", "in terms of JAX primitives." ] }, @@ -277,7 +277,7 @@ "\n", "The right way to add support for multiply-add is in terms of existing\n", "JAX primitives, as shown above. However, in order to demonstrate how JAX\n", - "primitives work let us pretend that we want to add a new primitive to \n", + "primitives work let us pretend that we want to add a new primitive to\n", "JAX for the multiply-add functionality." ] }, @@ -295,9 +295,9 @@ "@trace(\"multiply_add_prim\")\n", "def multiply_add_prim(x, y, z):\n", " \"\"\"The JAX-traceable way to use the JAX primitive.\n", - " \n", + "\n", " Note that the traced arguments must be passed as positional arguments\n", - " to `bind`. \n", + " to `bind`.\n", " \"\"\"\n", " return multiply_add_p.bind(x, y, z)\n", "\n", @@ -392,7 +392,7 @@ "\n", " This function does not need to be JAX traceable.\n", " Args:\n", - " x, y, z: the concrete arguments of the primitive. Will only be called with \n", + " x, y, z: the concrete arguments of the primitive. Will only be called with\n", " concrete values.\n", " Returns:\n", " the concrete result of the primitive.\n", @@ -485,17 +485,17 @@ }, "source": [ "#### Abstract evaluation rules\n", - "In order to JIT the function, and for other transformations as well, \n", - "JAX first evaluates it abstractly using only the \n", + "In order to JIT the function, and for other transformations as well,\n", + "JAX first evaluates it abstractly using only the\n", "shape and type of the arguments. This abstract evaluation serves multiple\n", "purposes:\n", "\n", - " * Gets the sequence of JAX primitives that are used in the computation. This \n", - " sequence will be compiled. \n", - " * Computes the shape and type of all vectors and operations used in the computation. \n", + " * Gets the sequence of JAX primitives that are used in the computation. This\n", + " sequence will be compiled.\n", + " * Computes the shape and type of all vectors and operations used in the computation.\n", "\n", "\n", - "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. \n", + "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.\n", "In the latter case, JAX uses the actual concrete value wrapped as an abstract value." ] }, @@ -527,7 +527,7 @@ " \"\"\"Abstract evaluation of the primitive.\n", "\n", " This function does not need to be JAX traceable. It will be invoked with\n", - " abstractions of the actual arguments. \n", + " abstractions of the actual arguments.\n", " Args:\n", " xs, ys, zs: abstractions of the arguments.\n", " Result:\n", @@ -603,7 +603,7 @@ "\n", "JAX compilation works by compiling each primitive into a graph of XLA operations.\n", "\n", - "This is the biggest hurdle to adding new functionality to JAX, because the \n", + "This is the biggest hurdle to adding new functionality to JAX, because the\n", "set of XLA operations is limited, and JAX already has pre-defined primitives\n", "for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++." ] @@ -642,7 +642,7 @@ }, "source": [ "Now we succeed to JIT. Notice below that JAX first evaluates the function\n", - "abstractly, which triggers the `multiply_add_abstract_eval` function, and \n", + "abstractly, which triggers the `multiply_add_abstract_eval` function, and\n", "then compiles the set of primitives it has encountered, including `multiply_add`.\n", "At this point JAX invokes `multiply_add_xla_translation`." ] @@ -682,7 +682,7 @@ "source": [ "Below is another use of `jit` where we compile only\n", "with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n", - "in the third argument to `multiply_add_abstract_eval` being \n", + "in the third argument to `multiply_add_abstract_eval` being\n", "`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n", "both `ShapedArray` and `ConcreteArray`." ] @@ -711,7 +711,7 @@ } ], "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y), \n", + "assert api.jit(lambda x, y: square_add_prim(x, y),\n", " static_argnums=1)(2., 10.) == 14." ] }, @@ -794,16 +794,16 @@ "def multiply_add_value_and_jvp(arg_values, arg_tangents):\n", " \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n", "\n", - " Given values of the arguments and perturbation of the arguments (tangents), \n", + " Given values of the arguments and perturbation of the arguments (tangents),\n", " compute the output of the primitive and the perturbation of the output.\n", "\n", - " This method must be JAX-traceable. JAX may invoke it with abstract values \n", + " This method must be JAX-traceable. JAX may invoke it with abstract values\n", " for the arguments and tangents.\n", "\n", " Args:\n", " arg_values: a tuple of arguments\n", - " arg_tangents: a tuple with the tangents of the arguments. The tuple has \n", - " the same length as the arg_values. Some of the tangents may also be the \n", + " arg_tangents: a tuple with the tangents of the arguments. The tuple has\n", + " the same length as the arg_values. Some of the tangents may also be the\n", " special value ad.Zero to specify a zero tangent.\n", " Returns:\n", " a pair of the primal output and the tangent.\n", @@ -811,26 +811,26 @@ " x, y, z = arg_values\n", " xt, yt, zt = arg_tangents\n", " _trace(\"Primal evaluation:\")\n", - " # Now we have a JAX-traceable computation of the output. \n", - " # Normally, we can use the ma primitive itself to compute the primal output. \n", + " # Now we have a JAX-traceable computation of the output.\n", + " # Normally, we can use the ma primitive itself to compute the primal output.\n", " primal_out = multiply_add_prim(x, y, z)\n", - " \n", + "\n", " _trace(\"Tangent evaluation:\")\n", - " # We must use a JAX-traceable way to compute the tangent. It turns out that \n", + " # We must use a JAX-traceable way to compute the tangent. It turns out that\n", " # the output tangent can be computed as (xt * y + x * yt + zt),\n", " # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n", - " \n", - " # We do need to deal specially with Zero. Here we just turn it into a \n", - " # proper tensor of 0s (of the same shape as 'x'). \n", - " # An alternative would be to check for Zero and perform algebraic \n", + "\n", + " # We do need to deal specially with Zero. Here we just turn it into a\n", + " # proper tensor of 0s (of the same shape as 'x').\n", + " # An alternative would be to check for Zero and perform algebraic\n", " # simplification of the output tangent computation.\n", " def make_zero(tan):\n", - " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan \n", - " \n", + " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan\n", + "\n", " output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n", " return (primal_out, output_tangent)\n", "\n", - "# Register the forward differentiation rule with JAX \n", + "# Register the forward differentiation rule with JAX\n", "ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp" ] }, @@ -880,7 +880,7 @@ "id": "69QsEcu-lP4u" }, "source": [ - "TO EXPLAIN: \n", + "TO EXPLAIN:\n", "\n", " * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n", " * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n", @@ -941,7 +941,7 @@ } ], "source": [ - "assert api.jit(lambda arg_values, arg_tangents: \n", + "assert api.jit(lambda arg_values, arg_tangents:\n", " api.jvp(square_add_prim, arg_values, arg_tangents))(\n", " (2., 10.), (1., 1.)) == (14., 5.)" ] @@ -953,7 +953,7 @@ }, "source": [ "Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n", - "evaluates abstractly both the primal and the tangent evaluation (a total of \n", + "evaluates abstractly both the primal and the tangent evaluation (a total of\n", "3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n", "of the primitive." ] @@ -967,21 +967,21 @@ "### Reverse differentiation\n", "\n", "If we attempt now to use reverse differentiation we\n", - "see that JAX starts by using the `multiply_add_value_and_jvp` to \n", + "see that JAX starts by using the `multiply_add_value_and_jvp` to\n", "compute the forward differentiation for abstract values, but then runs\n", - "into a `NotImplementedError`. \n", + "into a `NotImplementedError`.\n", "\n", "When computing the reverse differentiation JAX first does abstract evaluation\n", - "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a \n", - "trace of primitives that compute the output tangent. \n", + "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a\n", + "trace of primitives that compute the output tangent.\n", "Observe that JAX performs this abstract evaluation with concrete values\n", - "for the differentiation point, and abstract values for the tangents. \n", + "for the differentiation point, and abstract values for the tangents.\n", "Observe also that JAX uses the special abstract tangent value `Zero` for\n", - "the tangent corresponding to the 3rd argument of `ma`. This reflects the \n", + "the tangent corresponding to the 3rd argument of `ma`. This reflects the\n", "fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n", "which flows to the 3rd argument to `multiply_add_prim`.\n", "\n", - "Observe also that during the abstract evaluation of the tangent we pass the \n", + "Observe also that during the abstract evaluation of the tangent we pass the\n", "value 0.0 as the tangent for the 3rd argument. This is due to the use\n", "of the `make_zero` function in the definition of `multiply_add_value_and_jvp`." ] @@ -1071,7 +1071,7 @@ "\n", "As explained above, when computing reverse differentiation JAX obtains\n", "a trace of primitives that compute the tangent using forward differentiation.\n", - "Then, **JAX interprets this trace abstractly backwards** and for each \n", + "Then, **JAX interprets this trace abstractly backwards** and for each\n", "primitive it applies a **transposition** rule.\n", "\n", "To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n", @@ -1082,7 +1082,7 @@ " ft = c + yt\n", "```\n", "\n", - "By construction, the tangent calculation is always linear in the input tangents. \n", + "By construction, the tangent calculation is always linear in the input tangents.\n", "The only non-linear operator that may arise in the tangent calculation is multiplication,\n", "but then one of the operands is constant.\n", "\n", @@ -1108,8 +1108,8 @@ " xct += act * 4.\n", "```\n", "\n", - "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which \n", - "are the partial derivatives of the function `f`. \n", + "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which\n", + "are the partial derivatives of the function `f`.\n", "\n", "JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n", "```\n", @@ -1117,10 +1117,10 @@ "```\n", "\n", "Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n", - "arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned \n", + "arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned\n", "for the constant arguments.\n", "\n", - "In particular, \n", + "In particular,\n", "```\n", " add_transpose(out_ct, _, _) = (out_ct, out_ct)\n", " mult_transpose(out_ct, x, _) = (None, x * out_ct)\n", @@ -1140,16 +1140,16 @@ "def multiply_add_transpose(ct, x, y, z):\n", " \"\"\"Evaluates the transpose of a linear primitive.\n", "\n", - " This method is only used when computing the backward gradient following \n", - " value_and_jvp, and is only needed for primitives that are used in the JVP \n", - " calculation for some other primitive. We need transposition for multiply_add_prim, \n", - " because we have used multiply_add_prim in the computation of the output_tangent in \n", + " This method is only used when computing the backward gradient following\n", + " value_and_jvp, and is only needed for primitives that are used in the JVP\n", + " calculation for some other primitive. We need transposition for multiply_add_prim,\n", + " because we have used multiply_add_prim in the computation of the output_tangent in\n", " multiply_add_value_and_jvp.\n", "\n", - " In our case, multiply_add is not a linear primitive. However, it is used linearly \n", + " In our case, multiply_add is not a linear primitive. However, it is used linearly\n", " w.r.t. tangents in multiply_add_value_and_jvp:\n", " output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n", - " \n", + "\n", " Always one of the first two multiplicative arguments is a constant.\n", "\n", " Args:\n", @@ -1244,7 +1244,7 @@ }, "source": [ "Notice the two calls to `multiply_add_transpose`. They correspond to the two\n", - "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the \n", + "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the\n", "last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0." ] }, @@ -1254,7 +1254,7 @@ "id": "EIJs6FYmPg6c" }, "source": [ - "#### JIT of reverse differentiation \n", + "#### JIT of reverse differentiation\n", "\n", "Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n", "abstract values, while in the absence of JIT we used `ConcreteArray`." @@ -1397,20 +1397,20 @@ "@trace(\"multiply_add_batch\")\n", "def multiply_add_batch(vector_arg_values, batch_axes):\n", " \"\"\"Computes the batched version of the primitive.\n", - " \n", + "\n", " This must be a JAX-traceable function.\n", - " \n", + "\n", " Since the multiply_add primitive already operates pointwise on arbitrary\n", " dimension tensors, to batch it we can use the primitive itself. This works as\n", " long as both the inputs have the same dimensions and are batched along the\n", " same axes. The result is batched along the axis that the inputs are batched.\n", - " \n", + "\n", " Args:\n", " vector_arg_values: a tuple of two arguments, each being a tensor of matching\n", " shape.\n", " batch_axes: the axes that are being batched. See vmap documentation.\n", " Returns:\n", - " a tuple of the result, and the result axis that was batched. \n", + " a tuple of the result, and the result axis that was batched.\n", " \"\"\"\n", " assert batch_axes[0] == batch_axes[1]\n", " assert batch_axes[0] == batch_axes[2]\n", diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index 0ebf202f2..656cd0e59 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -22,12 +22,12 @@ kernelspec: *necula@google.com*, October 2019. JAX implements certain transformations of Python functions, e.g., `jit`, `grad`, -`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, +`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with -abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, +abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on the JAX abstract values. @@ -37,7 +37,7 @@ The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., `jit(jacfwd(grad(f)))`. -There are pre-defined JAX primitives corresponding to most XLA operations, +There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX’s implementation of numpy are JAX-traceable and therefore transformable. @@ -49,8 +49,8 @@ one can define a new primitive that encapsulates the behavior of the function. **The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.** Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically -as "multiply_add(x, y, z) = x * y + z". -This function operates on 3 identically-shaped tensors of floating point +as "multiply_add(x, y, z) = x * y + z". +This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise. +++ {"id": "HIJYIHNTD1yI"} @@ -58,7 +58,7 @@ values and performs the operations pointwise. ## Using existing primitives The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other -functions that are themselves written using JAX primitives, e.g., those +functions that are themselves written using JAX primitives, e.g., those defined in the `jax.lax` module: ```{code-cell} ipython3 @@ -134,7 +134,7 @@ def trace(name): return str(v) def pp_values(args): return ", ".join([pp(arg) for arg in args]) - + @functools.wraps(func) def func_wrapper(*args): _trace_indent("call {}({})".format(name, pp_values(args))) @@ -164,7 +164,7 @@ class expectNotImplementedError(object): +++ {"id": "Qf4eLrLCFYDl"} -Instead of using `jax.lax` primitives directly, we can use other functions +Instead of using `jax.lax` primitives directly, we can use other functions that are already written in terms of those primitives, such as those in `jax.numpy`: ```{code-cell} ipython3 @@ -182,7 +182,7 @@ def multiply_add_numpy(x, y, z): def square_add_numpy(a, b): return multiply_add_numpy(a, a, b) -print("\nNormal evaluation:") +print("\nNormal evaluation:") print("square_add_numpy = ", square_add_numpy(2., 10.)) print("\nGradient evaluation:") print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) @@ -191,13 +191,13 @@ print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) +++ {"id": "Sg-D8EdeFn4a"} Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and -`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further -below in this colab). -It is important to remember that a JAX-traceable function must be able to +`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further +below in this colab). +It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on special abstract arguments that JAX may use to abstract the function execution. -The JAX traceability property is satisfied as long as the function is written +The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives. +++ {"id": "WxrQO7-XGLcg"} @@ -206,7 +206,7 @@ in terms of JAX primitives. The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX -primitives work let us pretend that we want to add a new primitive to +primitives work let us pretend that we want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} ipython3 @@ -218,9 +218,9 @@ multiply_add_p = core.Primitive("multiply_add") # Create the primitive @trace("multiply_add_prim") def multiply_add_prim(x, y, z): """The JAX-traceable way to use the JAX primitive. - + Note that the traced arguments must be passed as positional arguments - to `bind`. + to `bind`. """ return multiply_add_p.bind(x, y, z) @@ -257,7 +257,7 @@ def multiply_add_impl(x, y, z): This function does not need to be JAX traceable. Args: - x, y, z: the concrete arguments of the primitive. Will only be called with + x, y, z: the concrete arguments of the primitive. Will only be called with concrete values. Returns: the concrete result of the primitive. @@ -293,17 +293,17 @@ with expectNotImplementedError(): +++ {"id": "rHS1bAGHH44E"} #### Abstract evaluation rules -In order to JIT the function, and for other transformations as well, -JAX first evaluates it abstractly using only the +In order to JIT the function, and for other transformations as well, +JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes: - * Gets the sequence of JAX primitives that are used in the computation. This - sequence will be compiled. - * Computes the shape and type of all vectors and operations used in the computation. + * Gets the sequence of JAX primitives that are used in the computation. This + sequence will be compiled. + * Computes the shape and type of all vectors and operations used in the computation. -For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. +For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. In the latter case, JAX uses the actual concrete value wrapped as an abstract value. ```{code-cell} ipython3 @@ -316,7 +316,7 @@ def multiply_add_abstract_eval(xs, ys, zs): """Abstract evaluation of the primitive. This function does not need to be JAX traceable. It will be invoked with - abstractions of the actual arguments. + abstractions of the actual arguments. Args: xs, ys, zs: abstractions of the arguments. Result: @@ -349,7 +349,7 @@ with expectNotImplementedError(): JAX compilation works by compiling each primitive into a graph of XLA operations. -This is the biggest hurdle to adding new functionality to JAX, because the +This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++. @@ -378,7 +378,7 @@ mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') +++ {"id": "K98LX-VaJkFu"} Now we succeed to JIT. Notice below that JAX first evaluates the function -abstractly, which triggers the `multiply_add_abstract_eval` function, and +abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`. @@ -393,7 +393,7 @@ assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. Below is another use of `jit` where we compile only with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads -in the third argument to `multiply_add_abstract_eval` being +in the third argument to `multiply_add_abstract_eval` being `ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with both `ShapedArray` and `ConcreteArray`. @@ -401,7 +401,7 @@ both `ShapedArray` and `ConcreteArray`. :id: mPfTwIBoKOEK :outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b -assert api.jit(lambda x, y: square_add_prim(x, y), +assert api.jit(lambda x, y: square_add_prim(x, y), static_argnums=1)(2., 10.) == 14. ``` @@ -437,16 +437,16 @@ from jax.interpreters import ad def multiply_add_value_and_jvp(arg_values, arg_tangents): """Evaluates the primal output and the tangents (Jacobian-vector product). - Given values of the arguments and perturbation of the arguments (tangents), + Given values of the arguments and perturbation of the arguments (tangents), compute the output of the primitive and the perturbation of the output. - This method must be JAX-traceable. JAX may invoke it with abstract values + This method must be JAX-traceable. JAX may invoke it with abstract values for the arguments and tangents. Args: arg_values: a tuple of arguments - arg_tangents: a tuple with the tangents of the arguments. The tuple has - the same length as the arg_values. Some of the tangents may also be the + arg_tangents: a tuple with the tangents of the arguments. The tuple has + the same length as the arg_values. Some of the tangents may also be the special value ad.Zero to specify a zero tangent. Returns: a pair of the primal output and the tangent. @@ -454,26 +454,26 @@ def multiply_add_value_and_jvp(arg_values, arg_tangents): x, y, z = arg_values xt, yt, zt = arg_tangents _trace("Primal evaluation:") - # Now we have a JAX-traceable computation of the output. - # Normally, we can use the ma primitive itself to compute the primal output. + # Now we have a JAX-traceable computation of the output. + # Normally, we can use the ma primitive itself to compute the primal output. primal_out = multiply_add_prim(x, y, z) - + _trace("Tangent evaluation:") - # We must use a JAX-traceable way to compute the tangent. It turns out that + # We must use a JAX-traceable way to compute the tangent. It turns out that # the output tangent can be computed as (xt * y + x * yt + zt), # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive. - - # We do need to deal specially with Zero. Here we just turn it into a - # proper tensor of 0s (of the same shape as 'x'). - # An alternative would be to check for Zero and perform algebraic + + # We do need to deal specially with Zero. Here we just turn it into a + # proper tensor of 0s (of the same shape as 'x'). + # An alternative would be to check for Zero and perform algebraic # simplification of the output tangent computation. def make_zero(tan): - return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan - + return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan + output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt))) return (primal_out, output_tangent) -# Register the forward differentiation rule with JAX +# Register the forward differentiation rule with JAX ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp ``` @@ -487,7 +487,7 @@ assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) +++ {"id": "69QsEcu-lP4u"} -TO EXPLAIN: +TO EXPLAIN: * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here. * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet @@ -504,7 +504,7 @@ We can apply JIT to the forward differentiation function: :id: hg-hzVu-N-hv :outputId: 38d32067-e152-4046-ad80-7f95a31ba628 -assert api.jit(lambda arg_values, arg_tangents: +assert api.jit(lambda arg_values, arg_tangents: api.jvp(square_add_prim, arg_values, arg_tangents))( (2., 10.), (1., 1.)) == (14., 5.) ``` @@ -512,7 +512,7 @@ assert api.jit(lambda arg_values, arg_tangents: +++ {"id": "jlZt1_v2mU88"} Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn -evaluates abstractly both the primal and the tangent evaluation (a total of +evaluates abstractly both the primal and the tangent evaluation (a total of 3 invocations of the `ma` primitive). Then we compile the 3 occurrences of the primitive. @@ -521,21 +521,21 @@ of the primitive. ### Reverse differentiation If we attempt now to use reverse differentiation we -see that JAX starts by using the `multiply_add_value_and_jvp` to +see that JAX starts by using the `multiply_add_value_and_jvp` to compute the forward differentiation for abstract values, but then runs -into a `NotImplementedError`. +into a `NotImplementedError`. When computing the reverse differentiation JAX first does abstract evaluation -of the forward differentiation code `multiply_add_value_and_jvp` to obtain a -trace of primitives that compute the output tangent. +of the forward differentiation code `multiply_add_value_and_jvp` to obtain a +trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values -for the differentiation point, and abstract values for the tangents. +for the differentiation point, and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value `Zero` for -the tangent corresponding to the 3rd argument of `ma`. This reflects the +the tangent corresponding to the 3rd argument of `ma`. This reflects the fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`, which flows to the 3rd argument to `multiply_add_prim`. -Observe also that during the abstract evaluation of the tangent we pass the +Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the 3rd argument. This is due to the use of the `make_zero` function in the definition of `multiply_add_value_and_jvp`. @@ -560,7 +560,7 @@ to use the forward differentiation code to compute reverse differentiation. As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. -Then, **JAX interprets this trace abstractly backwards** and for each +Then, **JAX interprets this trace abstractly backwards** and for each primitive it applies a **transposition** rule. To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`: @@ -571,7 +571,7 @@ To understand what is going on, consider for now a simpler example of the functi ft = c + yt ``` -By construction, the tangent calculation is always linear in the input tangents. +By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant. @@ -597,8 +597,8 @@ of the operation: xct += act * 4. ``` -One can verify that this computation produces `xct = 4.` and `yct = 3.`, which -are the partial derivatives of the function `f`. +One can verify that this computation produces `xct = 4.` and `yct = 3.`, which +are the partial derivatives of the function `f`. JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is: ``` @@ -606,10 +606,10 @@ p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz) ``` Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other -arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned +arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned for the constant arguments. -In particular, +In particular, ``` add_transpose(out_ct, _, _) = (out_ct, out_ct) mult_transpose(out_ct, x, _) = (None, x * out_ct) @@ -623,16 +623,16 @@ In particular, def multiply_add_transpose(ct, x, y, z): """Evaluates the transpose of a linear primitive. - This method is only used when computing the backward gradient following - value_and_jvp, and is only needed for primitives that are used in the JVP - calculation for some other primitive. We need transposition for multiply_add_prim, - because we have used multiply_add_prim in the computation of the output_tangent in + This method is only used when computing the backward gradient following + value_and_jvp, and is only needed for primitives that are used in the JVP + calculation for some other primitive. We need transposition for multiply_add_prim, + because we have used multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. - In our case, multiply_add is not a linear primitive. However, it is used linearly + In our case, multiply_add is not a linear primitive. However, it is used linearly w.r.t. tangents in multiply_add_value_and_jvp: output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt)) - + Always one of the first two multiplicative arguments is a constant. Args: @@ -674,12 +674,12 @@ assert api.grad(square_add_prim)(2., 10.) == 4. +++ {"id": "8M1xLCXW4fK7"} Notice the two calls to `multiply_add_transpose`. They correspond to the two -uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the +uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0. +++ {"id": "EIJs6FYmPg6c"} -#### JIT of reverse differentiation +#### JIT of reverse differentiation Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only abstract values, while in the absence of JIT we used `ConcreteArray`. @@ -721,20 +721,20 @@ from jax.interpreters import batching @trace("multiply_add_batch") def multiply_add_batch(vector_arg_values, batch_axes): """Computes the batched version of the primitive. - + This must be a JAX-traceable function. - + Since the multiply_add primitive already operates pointwise on arbitrary dimension tensors, to batch it we can use the primitive itself. This works as long as both the inputs have the same dimensions and are batched along the same axes. The result is batched along the axis that the inputs are batched. - + Args: vector_arg_values: a tuple of two arguments, each being a tensor of matching shape. batch_axes: the axes that are being batched. See vmap documentation. Returns: - a tuple of the result, and the result axis that was batched. + a tuple of the result, and the result axis that was batched. """ assert batch_axes[0] == batch_axes[1] assert batch_axes[0] == batch_axes[2] diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index 16e623d0f..a4a4d7d16 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -119,7 +119,7 @@ " for w, b in params[:-1]:\n", " outputs = jnp.dot(w, activations) + b\n", " activations = relu(outputs)\n", - " \n", + "\n", " final_w, final_b = params[-1]\n", " logits = jnp.dot(final_w, activations) + final_b\n", " return logits - logsumexp(logits)" @@ -238,7 +238,7 @@ "def one_hot(x, k, dtype=jnp.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - " \n", + "\n", "def accuracy(params, images, targets):\n", " target_class = jnp.argmax(targets, axis=1)\n", " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 87533117e..d234700e4 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -96,7 +96,7 @@ def predict(params, image): for w, b in params[:-1]: outputs = jnp.dot(w, activations) + b activations = relu(outputs) - + final_w, final_b = params[-1] logits = jnp.dot(final_w, activations) + final_b return logits - logsumexp(logits) @@ -156,7 +156,7 @@ At this point, we have all the ingredients we need to define our neural network def one_hot(x, k, dtype=jnp.float32): """Create a one-hot encoding of x of size k.""" return jnp.array(x[:, None] == jnp.arange(k), dtype) - + def accuracy(params, images, targets): target_class = jnp.argmax(targets, axis=1) predicted_class = jnp.argmax(batched_predict(params, images), axis=1) diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 7e65aefe3..2c231bf99 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -35,7 +35,6 @@ }, "outputs": [], "source": [ - "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import jit, grad, vmap\n", @@ -214,7 +213,6 @@ "outputs": [], "source": [ "# Importing Jax functions useful for tracing/interpreting.\n", - "import numpy as np\n", "from functools import wraps\n", "\n", "from jax import core\n", @@ -273,7 +271,7 @@ "def eval_jaxpr(jaxpr, consts, *args):\n", " # Mapping from variable -> value\n", " env = {}\n", - " \n", + "\n", " def read(var):\n", " # Literals are values baked into the Jaxpr\n", " if type(var) is core.Literal:\n", @@ -290,16 +288,16 @@ " # Loop through equations and evaluate primitives using `bind`\n", " for eqn in jaxpr.eqns:\n", " # Read inputs to equation from environment\n", - " invals = safe_map(read, eqn.invars) \n", + " invals = safe_map(read, eqn.invars)\n", " # `bind` is how a primitive is called\n", " outvals = eqn.primitive.bind(*invals, **eqn.params)\n", " # Primitives may return multiple outputs or not\n", - " if not eqn.primitive.multiple_results: \n", + " if not eqn.primitive.multiple_results:\n", " outvals = [outvals]\n", " # Write the results of the primitive into the environment\n", - " safe_map(write, eqn.outvars, outvals) \n", + " safe_map(write, eqn.outvars, outvals)\n", " # Read the final result of the Jaxpr from the environment\n", - " return safe_map(read, jaxpr.outvars) " + " return safe_map(read, jaxpr.outvars)" ] }, { @@ -417,7 +415,7 @@ "source": [ "def inverse_jaxpr(jaxpr, consts, *args):\n", " env = {}\n", - " \n", + "\n", " def read(var):\n", " if type(var) is core.Literal:\n", " return var.val\n", @@ -431,12 +429,12 @@ "\n", " # Looping backward\n", " for eqn in jaxpr.eqns[::-1]:\n", - " # outvars are now invars \n", + " # outvars are now invars\n", " invals = safe_map(read, eqn.outvars)\n", " if eqn.primitive not in inverse_registry:\n", " raise NotImplementedError(\n", " f\"{eqn.primitive} does not have registered inverse.\")\n", - " # Assuming a unary function \n", + " # Assuming a unary function\n", " outval = inverse_registry[eqn.primitive](*invals)\n", " safe_map(write, eqn.invars, [outval])\n", " return safe_map(read, jaxpr.invars)" diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index e52c6a5f8..883d64c37 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -32,7 +32,6 @@ Here we show how to add your own function transformations to the system, by writ ```{code-cell} ipython3 :id: s27RDKvKXFL8 -import numpy as np import jax import jax.numpy as jnp from jax import jit, grad, vmap @@ -146,7 +145,6 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. :id: BHkg_3P1pXJj # Importing Jax functions useful for tracing/interpreting. -import numpy as np from functools import wraps from jax import core @@ -185,7 +183,7 @@ To do this, we first create an environment to store the values for each of the v def eval_jaxpr(jaxpr, consts, *args): # Mapping from variable -> value env = {} - + def read(var): # Literals are values baked into the Jaxpr if type(var) is core.Literal: @@ -202,16 +200,16 @@ def eval_jaxpr(jaxpr, consts, *args): # Loop through equations and evaluate primitives using `bind` for eqn in jaxpr.eqns: # Read inputs to equation from environment - invals = safe_map(read, eqn.invars) + invals = safe_map(read, eqn.invars) # `bind` is how a primitive is called outvals = eqn.primitive.bind(*invals, **eqn.params) # Primitives may return multiple outputs or not - if not eqn.primitive.multiple_results: + if not eqn.primitive.multiple_results: outvals = [outvals] # Write the results of the primitive into the environment - safe_map(write, eqn.outvars, outvals) + safe_map(write, eqn.outvars, outvals) # Read the final result of the Jaxpr from the environment - return safe_map(read, jaxpr.outvars) + return safe_map(read, jaxpr.outvars) ``` ```{code-cell} ipython3 @@ -279,7 +277,7 @@ Now we just need to define `inverse_jaxpr`, which will walk through the Jaxpr ba def inverse_jaxpr(jaxpr, consts, *args): env = {} - + def read(var): if type(var) is core.Literal: return var.val @@ -293,12 +291,12 @@ def inverse_jaxpr(jaxpr, consts, *args): # Looping backward for eqn in jaxpr.eqns[::-1]: - # outvars are now invars + # outvars are now invars invals = safe_map(read, eqn.outvars) if eqn.primitive not in inverse_registry: raise NotImplementedError( f"{eqn.primitive} does not have registered inverse.") - # Assuming a unary function + # Assuming a unary function outval = inverse_registry[eqn.primitive](*invals) safe_map(write, eqn.invars, [outval]) return safe_map(read, jaxpr.invars) diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 86c8bfea8..478d84935 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -1148,7 +1148,7 @@ " y, vjp_fun = vjp(f, x)\n", " # Use vmap to do a matrix-Jacobian product.\n", " # Here, the matrix is the Euclidean basis, so we get all\n", - " # entries in the Jacobian at once. \n", + " # entries in the Jacobian at once.\n", " J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n", " return J\n", " return jacfun\n", @@ -1169,7 +1169,7 @@ "def our_jacfwd(f):\n", " def jacfun(x):\n", " _jvp = lambda s: jvp(f, (x,), (s,))[1]\n", - " Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n", + " Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n", " return jnp.transpose(Jt)\n", " return jacfun\n", "\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index bc2d803f1..6dcba2470 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -675,7 +675,7 @@ def our_jacrev(f): y, vjp_fun = vjp(f, x) # Use vmap to do a matrix-Jacobian product. # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. + # entries in the Jacobian at once. J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) return J return jacfun @@ -691,7 +691,7 @@ from jax import jacfwd as builtin_jacfwd def our_jacfwd(f): def jacfun(x): _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) return jnp.transpose(Jt) return jacfun diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index f0552e526..041cf6531 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -739,8 +739,6 @@ "metadata": {}, "outputs": [], "source": [ - "from jax.ad_checkpoint import checkpoint_name\n", - "\n", "def predict(params, x):\n", " *Ws, Wlast = params\n", " for i, W in enumerate(Ws):\n", diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index b31e093b6..077a8b6b1 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -370,8 +370,6 @@ Notice also that by providing a policy, we didn't need to edit the code defining Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`: ```{code-cell} -from jax.ad_checkpoint import checkpoint_name - def predict(params, x): *Ws, Wlast = params for i, W in enumerate(Ws): diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index 5246e810d..f628625bd 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -410,7 +410,7 @@ ], "source": [ "dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n", - " kernel.shape, # only ndim matters, not shape \n", + " kernel.shape, # only ndim matters, not shape\n", " ('NHWC', 'HWIO', 'NHWC')) # the important bit\n", "print(dn)" ] @@ -806,8 +806,8 @@ ], "source": [ "# 1D kernel - WIO layout\n", - "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], \n", - " [[1, 1, 1], [-1, -1, -1]]], \n", + "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],\n", + " [[1, 1, 1], [-1, -1, -1]]],\n", " dtype=jnp.float32).transpose([2,1,0])\n", "# 1D data - NWC layout\n", "data = np.zeros((1, 200, 2), dtype=jnp.float32)\n", @@ -895,8 +895,8 @@ "# Random 3D kernel - HWDIO layout\n", "kernel = jnp.array([\n", " [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n", - " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n", - " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n", + " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],\n", + " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],\n", " dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]\n", "\n", "# 3D data - NHWDC layout\n", @@ -919,7 +919,6 @@ "print(\"out shape: \", out.shape)\n", "\n", "# Make some simple 3d density plots:\n", - "from mpl_toolkits.mplot3d import Axes3D\n", "def make_alpha(cmap):\n", " my_cmap = cmap(jnp.arange(cmap.N))\n", " my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 2dec35847..467deeec2 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -210,7 +210,7 @@ The important argument is the 3-tuple of axis layout arguments: :outputId: d5a569b3-febc-4832-f725-1d5e8fd31b9b dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape - kernel.shape, # only ndim matters, not shape + kernel.shape, # only ndim matters, not shape ('NHWC', 'HWIO', 'NHWC')) # the important bit print(dn) ``` @@ -363,8 +363,8 @@ You aren't limited to 2D convolutions, a simple 1D demo is below: :outputId: 67c46ace-6adc-4c47-c1c7-1f185be5fd4b # 1D kernel - WIO layout -kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], - [[1, 1, 1], [-1, -1, -1]]], +kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], + [[1, 1, 1], [-1, -1, -1]]], dtype=jnp.float32).transpose([2,1,0]) # 1D data - NWC layout data = np.zeros((1, 200, 2), dtype=jnp.float32) @@ -406,8 +406,8 @@ import matplotlib as mpl # Random 3D kernel - HWDIO layout kernel = jnp.array([ [[0, 0, 0], [0, 1, 0], [0, 0, 0]], - [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], - [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], + [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], + [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis] # 3D data - NHWDC layout @@ -430,7 +430,6 @@ out = lax.conv_general_dilated(data, # lhs = image tensor print("out shape: ", out.shape) # Make some simple 3d density plots: -from mpl_toolkits.mplot3d import Axes3D def make_alpha(cmap): my_cmap = cmap(jnp.arange(cmap.N)) my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3 diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index 25c551c98..3c022124e 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -840,7 +840,6 @@ }, "outputs": [], "source": [ - "from functools import partial\n", "j1 = partial(jv, 1)\n", "z = jnp.arange(5.0)" ] diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index c93139e16..be76f9913 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -410,7 +410,6 @@ This lets us call into `scipy.special.jv` from transformed JAX code, including w ```{code-cell} :id: f4e46670f4e4 -from functools import partial j1 = partial(jv, 1) z = jnp.arange(5.0) ``` diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 7d353c924..91f2ee571 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -132,7 +132,7 @@ " for w, b in params[:-1]:\n", " outputs = jnp.dot(w, activations) + b\n", " activations = relu(outputs)\n", - " \n", + "\n", " final_w, final_b = params[-1]\n", " logits = jnp.dot(final_w, activations) + final_b\n", " return logits - logsumexp(logits)" @@ -251,7 +251,7 @@ "def one_hot(x, k, dtype=jnp.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - " \n", + "\n", "def accuracy(params, images, targets):\n", " target_class = jnp.argmax(targets, axis=1)\n", " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 2f7ba3271..0c0c4bc5c 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -104,7 +104,7 @@ def predict(params, image): for w, b in params[:-1]: outputs = jnp.dot(w, activations) + b activations = relu(outputs) - + final_w, final_b = params[-1] logits = jnp.dot(final_w, activations) + final_b return logits - logsumexp(logits) @@ -164,7 +164,7 @@ At this point, we have all the ingredients we need to define our neural network def one_hot(x, k, dtype=jnp.float32): """Create a one-hot encoding of x of size k.""" return jnp.array(x[:, None] == jnp.arange(k), dtype) - + def accuracy(params, images, targets): target_class = jnp.argmax(targets, axis=1) predicted_class = jnp.argmax(batched_predict(params, images), axis=1) diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index a355959ba..9aef1a8eb 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -25,17 +25,10 @@ }, "outputs": [], "source": [ - "import functools\n", - "import itertools\n", - "import re\n", - "import sys\n", - "import time\n", - "\n", - "from matplotlib.pyplot import *\n", + "import matplotlib.pyplot as plt\n", "\n", "import jax\n", "\n", - "from jax import lax\n", "import jax.numpy as jnp\n", "import jax.scipy as jsp\n", "from jax import random\n", @@ -348,7 +341,7 @@ "def elbo(beta_loc, beta_log_scale, epsilon):\n", " beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n", " return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n", - " \n", + "\n", "elbo = jax.jit(elbo)\n", "elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))" ] @@ -548,25 +541,16 @@ } ], "source": [ - "figure(figsize=(7, 7))\n", - "plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", - "plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n", - "plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n", + "plt.figure(figsize=(7, 7))\n", + "plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", + "plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\\sigma$ Error Bars')\n", + "plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n", "plot_scale = 3\n", - "plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", - "xlabel('True beta')\n", - "ylabel('Estimated beta')\n", - "legend(loc='best')" + "plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", + "plt.xlabel('True beta')\n", + "plt.ylabel('Estimated beta')\n", + "plt.legend(loc='best')" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "id": "_bXdOlvUEJl0" - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index 9ecbd9d23..f8cfc3553 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -27,17 +27,10 @@ Inspired by a notebook by @davmre. ```{code-cell} ipython3 :id: 8RZDkfbV3zdR -import functools -import itertools -import re -import sys -import time - -from matplotlib.pyplot import * +import matplotlib.pyplot as plt import jax -from jax import lax import jax.numpy as jnp import jax.scipy as jsp from jax import random @@ -192,7 +185,7 @@ batched_log_joint = jax.jit(jax.vmap(log_joint)) def elbo(beta_loc, beta_log_scale, epsilon): beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi)) - + elbo = jax.jit(elbo) elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1))) ``` @@ -240,19 +233,13 @@ Coverage isn't quite as good as we might like, but it's not bad, and nobody said :id: zt1NBLoVHtOG :outputId: fb159795-e6e7-497c-e501-9933ec761af4 -figure(figsize=(7, 7)) -plot(true_beta, beta_loc, '.', label='Approximated Posterior Means') -plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars') -plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.') +plt.figure(figsize=(7, 7)) +plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means') +plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars') +plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.') plot_scale = 3 -plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k') -xlabel('True beta') -ylabel('Estimated beta') -legend(loc='best') -``` - -```{code-cell} ipython3 -:id: _bXdOlvUEJl0 - - +plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k') +plt.xlabel('True beta') +plt.ylabel('Estimated beta') +plt.legend(loc='best') ``` diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 8fa210779..b7bc919eb 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -189,9 +189,6 @@ ], "source": [ "# Pardon the boilerplate; constructing a sharding will become easier in future!\n", - "from jax.sharding import Mesh\n", - "from jax.sharding import PartitionSpec\n", - "from jax.sharding import NamedSharding\n", "from jax.experimental import mesh_utils\n", "\n", "P = jax.sharding.PartitionSpec\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 345ca7987..85dfcdc17 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -73,9 +73,6 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens :outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5 # Pardon the boilerplate; constructing a sharding will become easier in future! -from jax.sharding import Mesh -from jax.sharding import PartitionSpec -from jax.sharding import NamedSharding from jax.experimental import mesh_utils P = jax.sharding.PartitionSpec diff --git a/jax/_src/core.py b/jax/_src/core.py index 61ed81cde..f80cd0418 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1268,7 +1268,7 @@ def new_base_main(trace_type: type[Trace], @contextmanager def pop_level(level: int): if level == 0: - return (yield) + return (yield) # noqa: B901 prev, thread_local_state.trace_state.trace_stack.stack = \ thread_local_state.trace_state.trace_stack.stack, \ thread_local_state.trace_state.trace_stack.stack[:level] diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 51a86fdcb..7ceac8940 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -760,7 +760,7 @@ def _dot_product_attention_fwd_partition( scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, mesh, arg_shapes, result_shape): # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) out_shardings = _infer_fwd_output_sharding( mesh, arg_shapes, variadic_args, is_training) impl = functools.partial( @@ -810,7 +810,7 @@ def _dot_product_attention_bwd_partition( arg_shapes, result_shape): out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) def sharded_impl(*args): impl = functools.partial( _dot_product_attention_bwd_impl, diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index bc548a69c..af3d55f58 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -238,7 +238,7 @@ class BufferedRef: Returns: Initialized BufferedRef """ - block_shape = tuple([1 if x is None else x for x in spec.block_shape]) + block_shape = tuple(1 if x is None else x for x in spec.block_shape) if buffer_type is BufferType.ACCUMULATOR: accum_ref = VMEM(block_shape, dtype) else: @@ -310,7 +310,7 @@ class BufferedRef: @property def current_ref(self): buffer_slice = tuple( - [0 if x is None else slice(None) for x in self.block_shape]) + 0 if x is None else slice(None) for x in self.block_shape) if self.memory_space == VMEM: return self.vmem_ref.at[buffer_slice] else: @@ -349,7 +349,7 @@ class BufferedRef: def compute_slice(self, grid_indices): """Compute DMA slice from grid indices.""" - block_shape = tuple([1 if x is None else x for x in self.block_shape]) + block_shape = tuple(1 if x is None else x for x in self.block_shape) indices = self.compute_index(*grid_indices) return jax.tree.map(_make_ds, indices, block_shape) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index b0f06624b..72533e619 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -172,7 +172,7 @@ def _normalize_tolerance(tol): if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: - return {k: tol for k in _default_tolerance} + return dict.fromkeys(_default_tolerance, tol) def join_tolerance(tol1, tol2): tol1 = _normalize_tolerance(tol1) diff --git a/pyproject.toml b/pyproject.toml index 193c6b9fd..fb706dbbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,8 @@ ignore = [ "C408", # Unnecessary map usage "C417", + # Unnecessary dict comprehension for iterable + "C420", # Object names too complex "C901", # Local variable is assigned to but never used @@ -141,7 +143,16 @@ max-complexity = 18 [tool.ruff.lint.per-file-ignores] # F811: Redefinition of unused name. +# F821: Undefined name. "docs/autodidax.py" = ["F811"] +"docs/pallas/tpu/matmul.ipynb" = ["F811"] +"docs/pallas/tpu/distributed.ipynb" = ["F811"] +"docs/pallas/quickstart.ipynb" = ["F811"] +"docs/notebooks/autodiff_cookbook.ipynb" = ["F811", "F821"] +"docs/notebooks/autodiff_remat.ipynb" = ["F811", "F821"] +"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"] +"docs/jep/9407-type-promotion.ipynb" = ["F811"] +"docs/autodidax.ipynb" = ["F811"] # Note: we don't use jax/*.py because this matches contents of jax/_src "__init__.py" = ["F401"] "jax/abstract_arrays.py" = ["F401"] diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ce1c02f5a..ec9a7cd8b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -308,7 +308,7 @@ class MemRefTest(TestCase): expanded_shape = get_packed_shape(strides, shape) total_size = np.prod(expanded_shape) np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape) - index = tuple([slice(0, s) for s in shape]) + index = tuple(slice(0, s) for s in shape) # Reference implementation def np_fold(inp, dim, fold_rank): diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index 1540b3d20..e8cd40a67 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -88,15 +88,6 @@ "height": 68 } }, - "source": [ - "from jaxlib import xla_extension\n", - "import jax\n", - "key = jax.random.PRNGKey(1701)\n", - "arr = jax.random.normal(key, (1000,))\n", - "device = arr.device()\n", - "print(f\"JAX device type: {device}\")\n", - "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" - ], "execution_count": 2, "outputs": [ { diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 2cad1d064..2818712c0 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -500,11 +500,11 @@ class IndexerOpsTest(PallasBaseTest): def body(x_ref, y_ref1, y_ref2): if slice_type == "slice": slices = tuple( - [slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)] + slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides) ) else: slices = tuple( - [pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)] + pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides) ) if indexer_type == "state": y_ref1[...] = x_ref[slices]