mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 23:06:05 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vfxqky4PCUnh"
},
"source": [
"# How JAX primitives work\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
"\n",
"*necula@google.com*, October 2019.\n",
"\n",
"JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n",
"`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, \n",
"which means that as the Python function executes\n",
"the only operations it applies to the data are either inspections of data\n",
"attributes such as shape or type, or special operations called JAX primitives.\n",
"In particular, a JAX-traceable function is sometimes invoked by JAX with\n",
"abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, \n",
"which captures the type and the shape of values, but not the concrete data values.\n",
"JAX primitives know how to operate on both concrete data\n",
"values and on the JAX abstract values.\n",
"\n",
"\n",
"The JAX-transformed functions must themselves be JAX-traceable functions,\n",
"to ensure that these transformations\n",
"can be composed, e.g., `jit(jacfwd(grad(f)))`.\n",
"\n",
"There are pre-defined JAX primitives corresponding to most XLA operations, \n",
"e.g., add, matmul, sin, cos, indexing.\n",
"JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n",
"using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n",
"Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.\n",
"\n",
"The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,\n",
"one can define a new primitive that encapsulates the behavior of the function.\n",
"\n",
"**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n",
"\n",
"Suppose that we want to add support to JAX for a multiply-add function with three arguments, defined mathematically\n",
"as \"multiply_add(x, y, z) = x * y + z\". \n",
"This function operates on 3 identically-shaped tensors of floating point \n",
"values and performs the operations pointwise."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HIJYIHNTD1yI"
},
"source": [
"## Using existing primitives\n",
"\n",
"The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n",
"functions that are themselves written using JAX primitives, e.g., those \n",
"defined in the `jax.lax` module:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "tbOF0LB0EMne",
"outputId": "3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"square_add_lax = 14.0\n",
"grad(square_add_lax) = 4.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
]
}
],
"source": [
"from jax import lax\n",
"from jax._src import api\n",
"\n",
"def multiply_add_lax(x, y, z):\n",
" \"\"\"Implementation of multiply-add using the jax.lax primitives.\"\"\"\n",
" return lax.add(lax.mul(x, y), z)\n",
"\n",
"\n",
"def square_add_lax(a, b):\n",
" \"\"\"A square-add function using the newly defined multiply-add.\"\"\"\n",
" return multiply_add_lax(a, a, b)\n",
"\n",
"print(\"square_add_lax = \", square_add_lax(2., 10.))\n",
"# Differentiate w.r.t. the first argument\n",
"print(\"grad(square_add_lax) = \", api.grad(square_add_lax, argnums=0)(2.0, 10.))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Cgv60Wm3E_D5"
},
"source": [
"In order to understand how JAX is internally using the primitives,\n",
"we add some helpers for tracing function calls."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"id": "mQRQGEGiE53K"
},
"outputs": [],
"source": [
"#@title Helper functions (execute this cell)\n",
"import functools\n",
"import traceback\n",
"\n",
"_indentation = 0\n",
"def _trace(msg=None):\n",
" \"\"\"Print a message at current indentation.\"\"\"\n",
" if msg is not None:\n",
" print(\" \" * _indentation + msg)\n",
"\n",
"def _trace_indent(msg=None):\n",
" \"\"\"Print a message and then indent the rest.\"\"\"\n",
" global _indentation\n",
" _trace(msg)\n",
" _indentation = 1 + _indentation\n",
"\n",
"def _trace_unindent(msg=None):\n",
" \"\"\"Unindent then print a message.\"\"\"\n",
" global _indentation\n",
" _indentation = _indentation - 1\n",
" _trace(msg)\n",
"\n",
"def trace(name):\n",
" \"\"\"A decorator for functions to trace arguments and results.\"\"\"\n",
"\n",
" def trace_func(func): # pylint: disable=missing-docstring\n",
" def pp(v):\n",
" \"\"\"Print certain values more succinctly\"\"\"\n",
" vtype = str(type(v))\n",
" if \"jax._src.lib.xla_bridge._JaxComputationBuilder\" in vtype:\n",
" return \"<JaxComputationBuilder>\"\n",
" elif \"jaxlib.xla_extension.XlaOp\" in vtype:\n",
" return \"<XlaOp at 0x{:x}>\".format(id(v))\n",
" elif (\"partial_eval.JaxprTracer\" in vtype or\n",
" \"batching.BatchTracer\" in vtype or\n",
" \"ad.JVPTracer\" in vtype):\n",
" return \"Traced<{}>\".format(v.aval)\n",
" elif isinstance(v, tuple):\n",
" return \"({})\".format(pp_values(v))\n",
" else:\n",
" return str(v)\n",
" def pp_values(args):\n",
" return \", \".join([pp(arg) for arg in args])\n",
" \n",
" @functools.wraps(func)\n",
" def func_wrapper(*args):\n",
" _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n",
" res = func(*args)\n",
" _trace_unindent(\"|<- {} = {}\".format(name, pp(res)))\n",
" return res\n",
"\n",
" return func_wrapper\n",
"\n",
" return trace_func\n",
"\n",
"class expectNotImplementedError(object):\n",
" \"\"\"Context manager to check for NotImplementedError.\"\"\"\n",
" def __enter__(self): pass\n",
" def __exit__(self, type, value, tb):\n",
" global _indentation\n",
" _indentation = 0\n",
" if type is NotImplementedError:\n",
" print(\"\\nFound expected exception:\")\n",
" traceback.print_exc(limit=3)\n",
" return True\n",
" elif type is None: # No exception\n",
" assert False, \"Expected NotImplementedError\"\n",
" else:\n",
" return False"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qf4eLrLCFYDl"
},
"source": [
"Instead of using `jax.lax` primitives directly, we can use other functions \n",
"that are already written in terms of those primitives, such as those in `jax.numpy`:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "QhKorz6cFRJb",
"outputId": "aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Normal evaluation:\n",
"call square_add_numpy(2.0, 10.0)\n",
" call multiply_add_numpy(2.0, 2.0, 10.0)\n",
" |<- multiply_add_numpy = 14.0\n",
"|<- square_add_numpy = 14.0\n",
"square_add_numpy = 14.0\n",
"\n",
"Gradient evaluation:\n",
"call square_add_numpy(Traced<ConcreteArray(2.0)>, 10.0)\n",
" call multiply_add_numpy(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, 10.0)\n",
" |<- multiply_add_numpy = Traced<ConcreteArray(14.0)>\n",
"|<- square_add_numpy = Traced<ConcreteArray(14.0)>\n",
"grad(square_add_numpy) = 4.0\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"@trace(\"multiply_add_numpy\")\n",
"def multiply_add_numpy(x, y, z):\n",
" return jnp.add(jnp.multiply(x, y), z)\n",
"\n",
"@trace(\"square_add_numpy\")\n",
"def square_add_numpy(a, b):\n",
" return multiply_add_numpy(a, a, b)\n",
"\n",
"print(\"\\nNormal evaluation:\") \n",
"print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n",
"print(\"\\nGradient evaluation:\")\n",
"print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sg-D8EdeFn4a"
},
"source": [
"Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n",
"`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further \n",
"below in this colab). \n",
"It is important to remember that a JAX-traceable function must be able to \n",
"operate not only on concrete arguments but also on special abstract arguments\n",
"that JAX may use to abstract the function execution.\n",
"\n",
"The JAX traceability property is satisfied as long as the function is written \n",
"in terms of JAX primitives."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WxrQO7-XGLcg"
},
"source": [
"## Defining new JAX primitives\n",
"\n",
"The right way to add support for multiply-add is in terms of existing\n",
"JAX primitives, as shown above. However, in order to demonstrate how JAX\n",
"primitives work let us pretend that we want to add a new primitive to \n",
"JAX for the multiply-add functionality."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "cPqAH1XOGTN4"
},
"outputs": [],
"source": [
"from jax import core\n",
"multiply_add_p = core.Primitive(\"multiply_add\") # Create the primitive\n",
"\n",
"@trace(\"multiply_add_prim\")\n",
"def multiply_add_prim(x, y, z):\n",
" \"\"\"The JAX-traceable way to use the JAX primitive.\n",
" \n",
" Note that the traced arguments must be passed as positional arguments\n",
" to `bind`. \n",
" \"\"\"\n",
" return multiply_add_p.bind(x, y, z)\n",
"\n",
"@trace(\"square_add_prim\")\n",
"def square_add_prim(a, b):\n",
" \"\"\"A square-add function implemented using the new JAX-primitive.\"\"\"\n",
" return multiply_add_prim(a, a, b)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LMzs5PAKGr-4"
},
"source": [
"If we try to call the newly defined functions we get an error, because\n",
"we have not yet told JAX anything about the semantics of the new primitive."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "_X3PAYxhGpWd",
"outputId": "90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(2.0, 10.0)\n",
" call multiply_add_prim(2.0, 2.0, 10.0)\n",
"\n",
"Found expected exception:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"<ipython-input-5-acee329b29d0>\", line 2, in <module>\n",
" square_add_prim(2., 10.)\n",
" File \"<ipython-input-2-0ffadd93fbdc>\", line 47, in func_wrapper\n",
" res = func(*args)\n",
" File \"<ipython-input-4-c5402c1795f0>\", line 16, in square_add_prim\n",
" return multiply_add_prim(a, a, b)\n",
"NotImplementedError: Evaluation rule for 'multiply_add' not implemented\n"
]
}
],
"source": [
"with expectNotImplementedError():\n",
" square_add_prim(2., 10.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "elha0FdgHSEF"
},
"source": [
"### Primal evaluation rules"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "FT34FFAGHARU",
"outputId": "4c54f1c2-8a50-4788-90e1-06aee412c43b"
},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.multiply_add_impl>"
]
},
"execution_count": 6,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"@trace(\"multiply_add_impl\")\n",
"def multiply_add_impl(x, y, z):\n",
" \"\"\"Concrete implementation of the primitive.\n",
"\n",
" This function does not need to be JAX traceable.\n",
" Args:\n",
" x, y, z: the concrete arguments of the primitive. Will only be called with \n",
" concrete values.\n",
" Returns:\n",
" the concrete result of the primitive.\n",
" \"\"\"\n",
" # Note that we can use the original numpy, which is not JAX traceable\n",
" return np.add(np.multiply(x, y), z)\n",
"\n",
"# Now we register the primal implementation with JAX\n",
"multiply_add_p.def_impl(multiply_add_impl)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "G5bstKaeNAVV",
"outputId": "deb94d5b-dfea-4e6f-9ec2-70b416c996c5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(2.0, 10.0)\n",
" call multiply_add_prim(2.0, 2.0, 10.0)\n",
" call multiply_add_impl(2.0, 2.0, 10.0)\n",
" |<- multiply_add_impl = 14.0\n",
" |<- multiply_add_prim = 14.0\n",
"|<- square_add_prim = 14.0\n"
]
}
],
"source": [
"assert square_add_prim(2., 10.) == 14."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "upBf-uAuHhPJ"
},
"source": [
"### JIT\n",
"\n",
"If we now try to use `jit` we get a `NotImplementedError`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "QG-LULjiHk4b",
"outputId": "d4ef4406-8dae-4c96-97ca-b662340474ee"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
"\n",
"Found expected exception:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"<ipython-input-8-d4853f4fcae2>\", line 2, in <module>\n",
" api.jit(square_add_prim)(2., 10.)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n",
" out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n",
" outs = primitive.impl(f, *args, **params)\n",
"NotImplementedError: Abstract evaluation for 'multiply_add' not implemented\n"
]
}
],
"source": [
"with expectNotImplementedError():\n",
" api.jit(square_add_prim)(2., 10.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rHS1bAGHH44E"
},
"source": [
"#### Abstract evaluation rules\n",
"In order to JIT the function, and for other transformations as well, \n",
"JAX first evaluates it abstractly using only the \n",
"shape and type of the arguments. This abstract evaluation serves multiple\n",
"purposes:\n",
"\n",
" * Gets the sequence of JAX primitives that are used in the computation. This \n",
" sequence will be compiled. \n",
" * Computes the shape and type of all vectors and operations used in the computation. \n",
"\n",
"\n",
"For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. \n",
"In the latter case, JAX uses the actual concrete value wrapped as an abstract value."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "ctQmEeckIbdo",
"outputId": "e751d0cc-460e-4ffd-df2e-fdabf9cffdc2"
},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.multiply_add_abstract_eval>"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from jax._src import abstract_arrays\n",
"@trace(\"multiply_add_abstract_eval\")\n",
"def multiply_add_abstract_eval(xs, ys, zs):\n",
" \"\"\"Abstract evaluation of the primitive.\n",
"\n",
" This function does not need to be JAX traceable. It will be invoked with\n",
" abstractions of the actual arguments. \n",
" Args:\n",
" xs, ys, zs: abstractions of the arguments.\n",
" Result:\n",
" a ShapedArray for the result of the primitive.\n",
" \"\"\"\n",
" assert xs.shape == ys.shape\n",
" assert xs.shape == zs.shape\n",
" return abstract_arrays.ShapedArray(xs.shape, xs.dtype)\n",
"\n",
"# Now we register the abstract evaluation with JAX\n",
"multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RPN88X6YI43A"
},
"source": [
"If we re-attempt to JIT, we see how the abstract evaluation proceeds, but\n",
"we get another error, about missing the actual XLA compilation rule:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "eOcNR92SI2h-",
"outputId": "356ef229-3703-4696-cc3d-7c05de405fb0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- square_add_prim = Traced<ShapedArray(float32[])>\n",
"\n",
"Found expected exception:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"<ipython-input-10-d4853f4fcae2>\", line 2, in <module>\n",
" api.jit(square_add_prim)(2., 10.)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n",
" out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n",
" outs = primitive.impl(f, *args, **params)\n",
"NotImplementedError: XLA translation rule for primitive 'multiply_add' not found\n"
]
}
],
"source": [
"with expectNotImplementedError():\n",
" api.jit(square_add_prim)(2., 10.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9IOV1R-fJMHp"
},
"source": [
"#### XLA Compilation rules\n",
"\n",
"JAX compilation works by compiling each primitive into a graph of XLA operations.\n",
"\n",
"This is the biggest hurdle to adding new functionality to JAX, because the \n",
"set of XLA operations is limited, and JAX already has pre-defined primitives\n",
"for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "FYQWSSjKJaWP"
},
"outputs": [],
"source": [
"from jax._src.lib import xla_client\n",
"@trace(\"multiply_add_xla_translation\")\n",
"def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):\n",
" \"\"\"The compilation to XLA of the primitive.\n",
"\n",
" Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the\n",
" result of the function.\n",
"\n",
" Does not need to be a JAX-traceable function.\n",
" \"\"\"\n",
" return [xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)]\n",
"\n",
"# Now we register the XLA compilation rule with JAX\n",
"# TODO: for GPU? and TPU?\n",
"from jax.interpreters import xla\n",
"xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='cpu')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K98LX-VaJkFu"
},
"source": [
"Now the JIT compilation succeeds. Notice below that JAX first evaluates the function\n",
"abstractly, which triggers the `multiply_add_abstract_eval` function, and \n",
"then compiles the set of primitives it has encountered, including `multiply_add`.\n",
"At this point JAX invokes `multiply_add_xla_translation`."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "rj3TLsolJgEc",
"outputId": "e384bee4-1e9c-4344-f49c-d3b5ec08eb32"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- square_add_prim = Traced<ShapedArray(float32[])>\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44d8b2a228>, <XlaOp at 0x7f44d8b2a228>, <XlaOp at 0x7f44d8b2a0d8>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44d8b2a880>\n"
]
}
],
"source": [
"assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Omrez-2_KFfo"
},
"source": [
"Below is another use of `jit` where we compile only\n",
"with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n",
"to the third argument to `multiply_add_abstract_eval` being \n",
"`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n",
"both `ShapedArray` and `ConcreteArray`."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "mPfTwIBoKOEK",
"outputId": "b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, 10.0)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, 10.0)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(10.0))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- square_add_prim = Traced<ShapedArray(float32[])>\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44d8e86df8>, <XlaOp at 0x7f44d8e86df8>, <XlaOp at 0x7f44d8e86c00>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44d8e867d8>\n"
]
}
],
"source": [
"assert api.jit(lambda x, y: square_add_prim(x, y), \n",
" static_argnums=1)(2., 10.) == 14."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Ya3B5l4J1VA"
},
"source": [
"### Forward differentiation\n",
"\n",
"JAX implements forward differentiation in the form of\n",
"a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).\n",
"\n",
"If we attempt now to compute the `jvp` function we get an\n",
"error because we have not yet told JAX how to differentiate\n",
"the `multiply_add` primitive."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "OxDx6NQnKwMI",
"outputId": "ce659ef3-c03c-4856-f252-49ec4b6eb964"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(10.0)>)\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(10.0)>)\n",
"\n",
"Found expected exception:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 217, in process_primitive\n",
" jvp = primitive_jvps[primitive]\n",
"KeyError: multiply_add\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"<ipython-input-14-7e56904baba2>\", line 2, in <module>\n",
" api.jvp(square_add_prim, (2., 10.), (1., 1.))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 978, in jvp\n",
" out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\", line 165, in call_wrapped\n",
" ans = self.f(*args, **dict(self.params, **kwargs))\n",
"NotImplementedError: Forward-mode differentiation rule for 'multiply_add' not implemented\n"
]
}
],
"source": [
"# The second argument `(2., 10.)` are the argument values\n",
"# where we evaluate the Jacobian, and the third `(1., 1.)`\n",
"# are the values of the tangents for the arguments.\n",
"with expectNotImplementedError():\n",
" api.jvp(square_add_prim, (2., 10.), (1., 1.))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "zxG24C1JMIMM"
},
"outputs": [],
"source": [
"from jax.interpreters import ad\n",
"\n",
"\n",
"@trace(\"multiply_add_value_and_jvp\")\n",
"def multiply_add_value_and_jvp(arg_values, arg_tangents):\n",
" \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n",
"\n",
" Given values of the arguments and perturbation of the arguments (tangents), \n",
" compute the output of the primitive and the perturbation of the output.\n",
"\n",
" This method must be JAX-traceable. JAX may invoke it with abstract values \n",
" for the arguments and tangents.\n",
"\n",
" Args:\n",
" arg_values: a tuple of arguments\n",
" arg_tangents: a tuple with the tangents of the arguments. The tuple has \n",
" the same length as the arg_values. Some of the tangents may also be the \n",
" special value ad.Zero to specify a zero tangent.\n",
" Returns:\n",
" a pair of the primal output and the tangent.\n",
" \"\"\"\n",
" x, y, z = arg_values\n",
" xt, yt, zt = arg_tangents\n",
" _trace(\"Primal evaluation:\")\n",
" # Now we have a JAX-traceable computation of the output. \n",
"  # Normally, we can use the multiply_add primitive itself to compute the primal output.\n",
" primal_out = multiply_add_prim(x, y, z)\n",
" \n",
" _trace(\"Tangent evaluation:\")\n",
" # We must use a JAX-traceable way to compute the tangent. It turns out that \n",
" # the output tangent can be computed as (xt * y + x * yt + zt),\n",
" # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n",
" \n",
" # We do need to deal specially with Zero. Here we just turn it into a \n",
" # proper tensor of 0s (of the same shape as 'x'). \n",
" # An alternative would be to check for Zero and perform algebraic \n",
" # simplification of the output tangent computation.\n",
" def make_zero(tan):\n",
" return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan \n",
" \n",
" output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n",
" return (primal_out, output_tangent)\n",
"\n",
"# Register the forward differentiation rule with JAX \n",
"ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "ma3KBkiAMfW1",
"outputId": "f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(10.0)>)\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(10.0)>)\n",
" call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))\n",
" Primal evaluation:\n",
" call multiply_add_prim(2.0, 2.0, 10.0)\n",
" call multiply_add_impl(2.0, 2.0, 10.0)\n",
" |<- multiply_add_impl = 14.0\n",
" |<- multiply_add_prim = 14.0\n",
" Tangent evaluation:\n",
" call multiply_add_prim(2.0, 1.0, 1.0)\n",
" call multiply_add_impl(2.0, 1.0, 1.0)\n",
" |<- multiply_add_impl = 3.0\n",
" |<- multiply_add_prim = 3.0\n",
" call multiply_add_prim(1.0, 2.0, 3.0)\n",
" call multiply_add_impl(1.0, 2.0, 3.0)\n",
" |<- multiply_add_impl = 5.0\n",
" |<- multiply_add_prim = 5.0\n",
" |<- multiply_add_value_and_jvp = (14.0, 5.0)\n",
" |<- multiply_add_prim = Traced<ConcreteArray(14.0)>\n",
"|<- square_add_prim = Traced<ConcreteArray(14.0)>\n"
]
}
],
"source": [
"# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.\n",
"assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "69QsEcu-lP4u"
},
"source": [
"TO EXPLAIN: \n",
"\n",
" * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n",
" * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n",
" we do not call the multiply_add_abstract_eval.\n",
|
||
" * I think it would be useful to show the jaxpr here"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "Sb6e3ZAHOPHv"
|
||
},
|
||
"source": [
|
||
"#### JIT of forward differentiation\n",
|
||
"\n",
|
||
"We can apply JIT to the forward differentiation function:"
|
||
]
|
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "hg-hzVu-N-hv",
"outputId": "38d32067-e152-4046-ad80-7f95a31ba628"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_value_and_jvp((Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>), (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>))\n",
" Primal evaluation:\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" Tangent evaluation:\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- square_add_prim = Traced<ShapedArray(float32[])>\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44cff26ae8>, <XlaOp at 0x7f44cff26ae8>, <XlaOp at 0x7f44cff268b8>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44cff268f0>\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44cff26ae8>, <XlaOp at 0x7f44cff26a08>, <XlaOp at 0x7f44cff26a40>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44cff26a78>\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44cff26a08>, <XlaOp at 0x7f44cff26ae8>, <XlaOp at 0x7f44cff26a78>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44cff269d0>\n"
]
}
],
"source": [
"assert api.jit(lambda arg_values, arg_tangents: \n",
" api.jvp(square_add_prim, arg_values, arg_tangents))(\n",
" (2., 10.), (1., 1.)) == (14., 5.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jlZt1_v2mU88"
},
"source": [
"Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n",
"evaluates abstractly both the primal and the tangent evaluation (a total of \n",
"3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n",
"of the primitive."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "555yt6ZIOePB"
},
"source": [
"### Reverse differentiation\n",
"\n",
"If we now attempt to use reverse differentiation, we\n",
"see that JAX starts by using the `multiply_add_value_and_jvp` to \n",
"compute the forward differentiation for abstract values, but then runs\n",
"into a `NotImplementedError`. \n",
"\n",
"When computing the reverse differentiation JAX first does abstract evaluation\n",
"of the forward differentiation code `multiply_add_value_and_jvp` to obtain a \n",
"trace of primitives that compute the output tangent. \n",
"Observe that JAX performs this abstract evaluation with concrete values\n",
"for the differentiation point, and abstract values for the tangents. \n",
"Observe also that JAX uses the special abstract tangent value `Zero` for\n",
"the tangent corresponding to the 3rd argument of `ma`. This reflects the \n",
"fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n",
"which flows to the 3rd argument to `multiply_add_prim`.\n",
"\n",
"Observe also that during the abstract evaluation of the tangent we pass the \n",
"value 0.0 as the tangent for the 3rd argument. This is due to the use\n",
"of the `make_zero` function in the definition of `multiply_add_value_and_jvp`."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "8eAVnexaOjBn",
"outputId": "e4ee89cf-ab4a-4505-9817-fa978a2865ab"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ConcreteArray(2.0)>, 10.0)\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, 10.0)\n",
" call multiply_add_value_and_jvp((Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, 10.0), (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Zero))\n",
" Primal evaluation:\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, 10.0)\n",
" call multiply_add_impl(2.0, 2.0, 10.0)\n",
" |<- multiply_add_impl = 14.0\n",
" |<- multiply_add_prim = 14.0\n",
" Tangent evaluation:\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ShapedArray(float32[])>, 0.0)\n",
" call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ConcreteArray(2.0)>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)\n",
" |<- multiply_add_prim = Traced<ConcreteArray(14.0)>\n",
"|<- square_add_prim = Traced<ConcreteArray(14.0)>\n",
"\n",
"Found expected exception:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 198, in get_primitive_transpose\n",
" return primitive_transposes[p]\n",
"KeyError: multiply_add\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"<ipython-input-18-48ff33d55c45>\", line 2, in <module>\n",
" api.grad(square_add_prim)(2., 10.)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 340, in grad_f\n",
" _, g = value_and_grad_f(*args, **kwargs)\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 398, in value_and_grad_f\n",
" g = vjp_py(np.ones((), dtype=dtype))\n",
"NotImplementedError: Reverse-mode differentiation rule for 'multiply_add' not implemented\n"
]
}
],
"source": [
"# This is reverse differentiation w.r.t. the first argument of square_add_prim\n",
"with expectNotImplementedError():\n",
" api.grad(square_add_prim)(2., 10.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fSHLUMDN26AY"
},
"source": [
"The above error is because there is a missing piece for JAX to be able\n",
"to use the forward differentiation code to compute reverse differentiation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3ibDbGF-PjK9"
},
"source": [
"#### Transposition\n",
"\n",
"\n",
"As explained above, when computing reverse differentiation JAX obtains\n",
"a trace of primitives that compute the tangent using forward differentiation.\n",
"Then, **JAX interprets this trace abstractly backwards** and for each \n",
"primitive it applies a **transposition** rule.\n",
"\n",
"To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n",
"```\n",
" a = xt * 4.\n",
" b = 2. * yt\n",
" c = a + b\n",
" ft = c + yt\n",
"```\n",
"\n",
"By construction, the tangent calculation is always linear in the input tangents. \n",
"The only non-linear operator that may arise in the tangent calculation is multiplication,\n",
"but then one of the operands is constant.\n",
"\n",
"JAX will produce the reverse differentiation computation by processing the\n",
"JVP computation backwards. For each operation in the tangent computation,\n",
"it accumulates the cotangents\n",
"of the variables used by the operation, using the cotangent of the result\n",
"of the operation:\n",
"```\n",
" # Initialize cotangents of inputs and intermediate vars\n",
" xct = yct = act = bct = cct = 0.\n",
" # Initialize cotangent of the output\n",
" fct = 1.\n",
" # Process \"ft = c + yt\"\n",
" cct += fct\n",
" yct += fct\n",
" # Process \"c = a + b\"\n",
" act += cct\n",
" bct += cct\n",
" # Process \"b = 2. * yt\"\n",
" yct += 2. * bct\n",
" # Process \"a = xt * 4.\"\n",
" xct += act * 4.\n",
"```\n",
"\n",
"One can verify that this computation produces `xct = 4.` and `yct = 3.`, which \n",
"are the partial derivatives of the function `f`. \n",
"\n",
"JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n",
"```\n",
"p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)\n",
"```\n",
"\n",
"Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n",
"arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned \n",
"for the constant arguments.\n",
"\n",
"In particular, \n",
"```\n",
" add_transpose(out_ct, _, _) = (out_ct, out_ct)\n",
" mult_transpose(out_ct, x, _) = (None, x * out_ct)\n",
" mult_transpose(out_ct, _, y) = (out_ct * y, None)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "JaHxFdkRO42r"
},
"outputs": [],
"source": [
"@trace(\"multiply_add_transpose\")\n",
"def multiply_add_transpose(ct, x, y, z):\n",
" \"\"\"Evaluates the transpose of a linear primitive.\n",
"\n",
" This method is only used when computing the backward gradient following \n",
" value_and_jvp, and is only needed for primitives that are used in the JVP \n",
" calculation for some other primitive. We need transposition for multiply_add_prim, \n",
" because we have used multiply_add_prim in the computation of the output_tangent in \n",
" multiply_add_value_and_jvp.\n",
"\n",
" In our case, multiply_add is not a linear primitive. However, it is used linearly \n",
" w.r.t. tangents in multiply_add_value_and_jvp:\n",
" output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n",
" \n",
" One of the first two multiplicative arguments is always a constant.\n",
"\n",
" Args:\n",
" ct: the cotangent of the output of the primitive.\n",
" x, y, z: values of the arguments. The arguments that are used linearly\n",
" get an ad.UndefinedPrimal value. The other arguments get a constant\n",
" value.\n",
" Returns:\n",
" a tuple with the cotangent of the inputs, with the value None\n",
" corresponding to the constant arguments.\n",
" \"\"\"\n",
" if not ad.is_undefined_primal(x):\n",
" # This use of multiply_add is with a constant \"x\"\n",
" assert ad.is_undefined_primal(y)\n",
" ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))\n",
" res = None, ct_y, ct\n",
" else:\n",
" # This use of multiply_add is with a constant \"y\"\n",
" assert ad.is_undefined_primal(x)\n",
" ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))\n",
" res = ct_x, None, ct\n",
" return res\n",
"\n",
"\n",
"ad.primitive_transposes[multiply_add_p] = multiply_add_transpose"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PpChox-Jp7wb"
},
"source": [
"Now we can complete the run of the `grad`:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "PogPKS4MPevd",
"outputId": "d33328d4-3e87-45b5-9b31-21ad624b67af"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ConcreteArray(2.0)>, 10.0)\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, 10.0)\n",
" call multiply_add_value_and_jvp((Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, 10.0), (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Zero))\n",
" Primal evaluation:\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ConcreteArray(2.0)>, 10.0)\n",
" call multiply_add_impl(2.0, 2.0, 10.0)\n",
" |<- multiply_add_impl = 14.0\n",
" |<- multiply_add_prim = 14.0\n",
" Tangent evaluation:\n",
" call multiply_add_prim(Traced<ConcreteArray(2.0)>, Traced<ShapedArray(float32[])>, 0.0)\n",
" call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ConcreteArray(2.0)>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)\n",
" |<- multiply_add_prim = Traced<ConcreteArray(14.0)>\n",
"|<- square_add_prim = Traced<ConcreteArray(14.0)>\n",
"call multiply_add_transpose(1.0, _, 2.0, _)\n",
" call multiply_add_prim(1.0, 2.0, 0.0)\n",
" call multiply_add_impl(1.0, 2.0, 0.0)\n",
" |<- multiply_add_impl = 2.0\n",
" |<- multiply_add_prim = 2.0\n",
"|<- multiply_add_transpose = (2.0, None, 1.0)\n",
"call multiply_add_transpose(1.0, 2.0, _, 0.0)\n",
" call multiply_add_prim(2.0, 1.0, 0.0)\n",
" call multiply_add_impl(2.0, 1.0, 0.0)\n",
" |<- multiply_add_impl = 2.0\n",
" |<- multiply_add_prim = 2.0\n",
"|<- multiply_add_transpose = (None, 2.0, 1.0)\n"
]
}
],
"source": [
"assert api.grad(square_add_prim)(2., 10.) == 4."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8M1xLCXW4fK7"
},
"source": [
"Notice the two calls to `multiply_add_transpose`. They correspond to the two\n",
"uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the \n",
"last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EIJs6FYmPg6c"
},
"source": [
"#### JIT of reverse differentiation \n",
"\n",
"Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n",
"abstract values, while in the absence of JIT we used `ConcreteArray`."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "FZ-JGbWZPq2-",
"outputId": "e42b5222-9c3e-4853-e13a-874f6605d178"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_value_and_jvp((Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>), (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Zero))\n",
" Primal evaluation:\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" Tangent evaluation:\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ConcreteArray(0.0)>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(0.0))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
" |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- square_add_prim = Traced<ShapedArray(float32[])>\n",
"call multiply_add_transpose(1.0, _, Traced<ShapedArray(float32[])>, _)\n",
" call multiply_add_prim(1.0, Traced<ShapedArray(float32[])>, Traced<ConcreteArray(0.0)>)\n",
" call multiply_add_abstract_eval(ConcreteArray(1.0), ShapedArray(float32[]), ConcreteArray(0.0))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>, None, 1.0)\n",
"call multiply_add_transpose(1.0, Traced<ShapedArray(float32[])>, _, Traced<ConcreteArray(0.0)>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, 1.0, Traced<ConcreteArray(0.0)>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(1.0), ConcreteArray(0.0))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>, 1.0)\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44cfec1b20>, <XlaOp at 0x7f44cfec13e8>, <XlaOp at 0x7f44cfec1880>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44cfec1420>\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44cfec13e8>, <XlaOp at 0x7f44cfec1308>, <XlaOp at 0x7f44cfec1ae8>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44cfec1490>\n"
]
}
],
"source": [
"assert api.jit(api.grad(square_add_prim))(2., 10.) == 4."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-3lqPkdQPvl5"
},
"source": [
"### Batching\n",
"\n",
"The batching transformation takes a point-wise computation and turns it\n",
"into a computation on vectors. If we try it right now, we get a `NotImplementedError`:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "hFvBR3I9Pzh3",
"outputId": "434608bc-281f-4d3b-83bd-eaaf3b51b1cd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
"\n",
"Found expected exception:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 163, in get_primitive_batcher\n",
" return primitive_batchers[p]\n",
"KeyError: multiply_add\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"<ipython-input-22-70154d0e2ab6>\", line 3, in <module>\n",
" np.array([10., 20.]))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n",
" lambda: _flatten_axes(out_tree(), out_axes))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n",
" out_vals, out_dims = batch2(fun, in_vals, in_dims)\n",
"NotImplementedError: Batching rule for 'multiply_add' not implemented\n"
]
}
],
"source": [
"# The arguments are two vectors instead of two scalars\n",
"with expectNotImplementedError():\n",
" api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),\n",
" np.array([10., 20.]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gILasMiP6elR"
},
"source": [
"We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "KQfeqRIrP7zg"
},
"outputs": [],
"source": [
"from jax.interpreters import batching\n",
"\n",
"\n",
"@trace(\"multiply_add_batch\")\n",
"def multiply_add_batch(vector_arg_values, batch_axes):\n",
" \"\"\"Computes the batched version of the primitive.\n",
" \n",
" This must be a JAX-traceable function.\n",
" \n",
" Since the multiply_add primitive already operates pointwise on arbitrary\n",
" dimension tensors, to batch it we can use the primitive itself. This works as\n",
" long as all the inputs have the same dimensions and are batched along the\n",
" same axes. The result is batched along the axis that the inputs are batched.\n",
" \n",
" Args:\n",
" vector_arg_values: a tuple of three arguments, each being a tensor of matching\n",
" shape.\n",
" batch_axes: the axes that are being batched. See vmap documentation.\n",
" Returns:\n",
" a tuple of the result, and the result axis that was batched. \n",
" \"\"\"\n",
" assert batch_axes[0] == batch_axes[1]\n",
" assert batch_axes[0] == batch_axes[2]\n",
" _trace(\"Using multiply_add to compute the batch:\")\n",
" res = multiply_add_prim(*vector_arg_values)\n",
" return res, batch_axes[0]\n",
"\n",
"\n",
"batching.primitive_batchers[multiply_add_p] = multiply_add_batch"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "VwxNk869P_YG",
"outputId": "9d22c921-5803-4d33-9e88-b6e439ba9738"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))\n",
" Using multiply_add to compute the batch:\n",
" call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])\n",
" call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])\n",
" |<- multiply_add_impl = [14. 29.]\n",
" |<- multiply_add_prim = [14. 29.]\n",
" |<- multiply_add_batch = ([14. 29.], 0)\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- square_add_prim = Traced<ShapedArray(float32[])>\n"
]
}
],
"source": [
"assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n",
" np.array([2., 3.]),\n",
" np.array([10., 20.])),\n",
" [14., 29.])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NmqLlV1TQDCC"
},
"source": [
"#### JIT of batching"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "xqEdXVUgQCTt",
"outputId": "9c22fd9c-919c-491d-bbeb-32c241b808fa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)\n",
" call multiply_add_batch((Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>), (0, 0, 0))\n",
" Using multiply_add to compute the batch:\n",
" call multiply_add_prim(Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>)\n",
" call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))\n",
" |<- multiply_add_abstract_eval = ShapedArray(float32[2])\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[2])>\n",
" |<- multiply_add_batch = (Traced<ShapedArray(float32[2])>, 0)\n",
" |<- multiply_add_prim = Traced<ShapedArray(float32[])>\n",
"|<- square_add_prim = Traced<ShapedArray(float32[])>\n",
"call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7f44cfecb9d0>, <XlaOp at 0x7f44cfecb9d0>, <XlaOp at 0x7f44cfecba40>)\n",
"|<- multiply_add_xla_translation = <XlaOp at 0x7f44cfec1340>\n"
]
}
],
"source": [
"assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n",
" (np.array([2., 3.]),\n",
" np.array([10., 20.])),\n",
" [14., 29.])"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "How JAX primitives work.ipynb",
"provenance": [],
"toc_visible": true
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}