{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hjM_sV_AepYf"
},
"source": [
"# 🔪 JAX - The Sharp Bits 🔪"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4k5PVzEo2uJO"
},
"source": [
"*levskaya@ mattjj@*\n",
"\n",
"When walking about the countryside of [Italy](https://iaml.it/blog/jax-intro), the people will not hesitate to tell you that __JAX__ has _\"un'anima di pura programmazione funzionale\"_ (\"a soul of pure functional programming\").\n",
"\n",
"__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. As such, it needs to control the _unwanted proliferation_ of __side-effects__ in its programs so that analysis and transformation of its computations remain tractable!\n",
"\n",
"This requires us to write code in a _functional_ style with _explicit_ descriptions of how the state of a program changes, which results in __several important differences__ from how you might be used to programming in NumPy, TensorFlow or PyTorch.\n",
"\n",
"Herein we try to cover the most frequent points of trouble that users encounter when starting out in __JAX__."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GoK_PCxPeYcy"
},
"outputs": [],
"source": [
"import numpy as onp\n",
"from jax import grad, jit\n",
"from jax import lax\n",
"from jax import random\n",
"import jax\n",
"import jax.numpy as np\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import rcParams\n",
"rcParams['image.interpolation'] = 'nearest'\n",
"rcParams['image.cmap'] = 'viridis'\n",
"rcParams['axes.grid'] = False"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "oBdKtkVW8Lha"
},
"source": [
"## 🔪 In-Place Updates"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "JffAqnEW4JEb"
},
"source": [
"In NumPy, you're used to doing this:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
},
"colab_type": "code",
"id": "om4xV7_84N9j",
"outputId": "25ed90e1-74f9-420c-ba06-21e5d6a3b58e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"original array:\n",
"[[0. 0. 0.]\n",
" [0. 0. 0.]\n",
" [0. 0. 0.]]\n",
"updated array:\n",
"[[0. 0. 0.]\n",
" [1. 1. 1.]\n",
" [0. 0. 0.]]\n"
]
}
],
"source": [
"numpy_array = onp.zeros((3,3), dtype=onp.float32)\n",
"print(\"original array:\")\n",
"print(numpy_array)\n",
"\n",
"# In place, mutating update\n",
"numpy_array[1, :] = 1.0\n",
"print(\"updated array:\")\n",
"print(numpy_array)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "go3L4x3w4-9p"
},
"source": [
"If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 198
},
"colab_type": "code",
"id": "2AxeCufq4wAp",
"outputId": "7013374b-041f-4270-db19-cfb4ab992f52",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-189-a717a200f584>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# In place update of JAX's array will yield an error!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax_array\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m: '_FilledConstant' object does not support item assignment"
]
}
],
"source": [
"jax_array = np.zeros((3,3), dtype=np.float32)\n",
"\n",
"# In place update of JAX's array will yield an error!\n",
"jax_array[1, :] = 1.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7mo76sS25Wco"
},
"source": [
"__What gives?!__ \n",
"\n",
"Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. \n",
"\n",
"Instead, JAX offers the _functional_ update functions: [__index_update__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), [__index_add__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add), [__index_min__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_min.html#jax.ops.index_min), [__index_max__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_max.html#jax.ops.index_max), and the [__index__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html#jax.ops.index) helper.\n",
"\n",
"⚠️ Inside `jit`'d code and `lax.while_loop` or `lax.fori_loop`, the __size__ of a slice can't depend on argument _values_, only on argument _shapes_; slice start indices have no such restriction. See the __Control Flow__ section below for more information on this limitation."
]
},
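{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the slicing rule above (the function name `window3` is hypothetical), `lax.dynamic_slice` takes start indices that may be traced values, while its slice sizes must be static Python integers:\n",
"\n",
"```python\n",
"import numpy as onp\n",
"from jax import jit, lax\n",
"import jax.numpy as np\n",
"\n",
"@jit\n",
"def window3(x, start):\n",
"    # `start` is traced, so it can change between calls without recompiling,\n",
"    # but the slice size (3,) is a fixed Python tuple known at trace time.\n",
"    return lax.dynamic_slice(x, (start,), (3,))\n",
"\n",
"print(onp.asarray(window3(np.arange(10.), 4)))  # [4. 5. 6.]\n",
"```\n",
"\n",
"Making the size depend on a traced argument instead, e.g. `(n,)` for a traced `n`, is not allowed, because the output shape must be known when the function is traced."
]
},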
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "m5lg1RYq5D9p"
},
"outputs": [],
"source": [
"from jax.ops import index, index_add, index_update"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "X2Xjjvd-l8NL"
},
"source": [
"### index_update"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eM6MyndXL2NY"
},
"source": [
"If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 221
},
"colab_type": "code",
"id": "ygUJT49b7BBk",
"outputId": "c1dc7528-4a4a-4ee6-c9a2-c7e39f95ccb1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"original array:\n",
"[[0. 0. 0.]\n",
" [0. 0. 0.]\n",
" [0. 0. 0.]]\n",
"old array unchanged:\n",
"[[0. 0. 0.]\n",
" [0. 0. 0.]\n",
" [0. 0. 0.]]\n",
"new array:\n",
"[[0. 0. 0.]\n",
" [1. 1. 1.]\n",
" [0. 0. 0.]]\n"
]
}
],
"source": [
"jax_array = np.zeros((3, 3))\n",
"print(\"original array:\")\n",
"print(jax_array)\n",
"\n",
"new_jax_array = index_update(jax_array, index[1, :], 1.)\n",
"\n",
"print(\"old array unchanged:\")\n",
"print(jax_array)\n",
"\n",
"print(\"new array:\")\n",
"print(new_jax_array)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7to-sF8EmC_y"
},
"source": [
"### index_add"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iI5cLY1xMBLs"
},
"source": [
"If the __input values__ of __index_add__ aren't reused, __jit__-compiled code will perform these operations _in-place_."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 221
},
"colab_type": "code",
"id": "tsw2svao8FUp",
"outputId": "2492b20d-0b8e-4f61-816d-00b8a08ce29f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"original array:\n",
"[[1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]]\n",
"new array post-addition:\n",
"[[1. 1. 1. 8. 8. 8.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 8. 8. 8.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 8. 8. 8.]]\n"
]
}
],
"source": [
"print(\"original array:\")\n",
"jax_array = np.ones((5, 6))\n",
"print(jax_array)\n",
"\n",
"new_jax_array = index_add(jax_array, index[::2, 3:], 7.)\n",
"print(\"new array post-addition:\")\n",
"print(new_jax_array)"
]
},
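{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the semantics of these two functions concrete, here is a minimal pure-NumPy model of them. It is a sketch of the _behaviour_ only, not of how JAX implements them, and the helper names `index_update_reference` and `index_add_reference` are hypothetical. One detail the model tries to capture: JAX's documentation for `index_add` states that when the same position appears more than once in the index, the updates are _summed_, whereas NumPy's in-place `x[idx] += y` keeps only the last update. The sketch uses `onp.add.at` to get the summing behaviour.\n",
"\n",
"```python\n",
"import numpy as onp\n",
"\n",
"def index_update_reference(x, idx, y):\n",
"    # Model of index_update: copy the input, then assign into the copy.\n",
"    out = onp.array(x, copy=True)\n",
"    out[idx] = y\n",
"    return out\n",
"\n",
"def index_add_reference(x, idx, y):\n",
"    # Model of index_add: copy the input, then accumulate into the copy.\n",
"    # onp.add.at sums repeated indices instead of keeping only the last write.\n",
"    out = onp.array(x, copy=True)\n",
"    onp.add.at(out, idx, y)\n",
"    return out\n",
"\n",
"x = onp.ones((5, 6))\n",
"y = index_add_reference(x, onp.s_[::2, 3:], 7.)  # same update as the cell above\n",
"print(x[0, 3], y[0, 3])  # 1.0 8.0: the input is left unchanged\n",
"\n",
"z = index_add_reference(onp.zeros(3), onp.array([0, 0, 2]), 1.)\n",
"print(z)  # [2. 0. 1.]: both updates to position 0 were summed\n",
"```"
]
},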
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔪 Out-of-Bounds Indexing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In Numpy, you are used to errors being thrown when you index an array outside of its bounds, like this:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "IndexError",
"evalue": "index 11 is out of bounds for axis 0 with size 10",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-2-eac95ae2edf8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m11\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m: index 11 is out of bounds for axis 0 with size 10"
]
}
],
"source": [
"onp.arange(10)[11]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, raising an error from code running on an accelerator such as a GPU or TPU can be difficult. Therefore, JAX does not raise an error; instead, it clamps the index to the bounds of the array, so here it returns the last value in the array."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(9, dtype=int32)"
]
},
"execution_count": 0,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.arange(10)[11]"
]
},
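{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy model of this behaviour, assuming a 1-D array and a non-negative integer index, is to clamp the index to the last valid position before reading. The helper `clamped_getitem` is hypothetical and only mimics the result; it is not how JAX implements indexing. Note the consequence: _every_ too-large index silently returns the last element, so out-of-bounds bugs do not announce themselves.\n",
"\n",
"```python\n",
"import numpy as onp\n",
"\n",
"def clamped_getitem(x, i):\n",
"    # Clamp a non-negative index into [0, len(x) - 1] instead of raising.\n",
"    i = min(i, x.shape[0] - 1)\n",
"    return x[i]\n",
"\n",
"x = onp.arange(10)\n",
"print(clamped_getitem(x, 11))   # 9, like np.arange(10)[11] above\n",
"print(clamped_getitem(x, 100))  # also 9\n",
"print(clamped_getitem(x, 3))    # 3: in-bounds indices are unaffected\n",
"```"
]
},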
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MUycRNh6e50W"
},
"source": [
"## 🔪 Random Numbers"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "O8vvaVt3MRG2"
},
"source": [
"> _If all scientific papers whose results are in doubt because of bad \n",
"> `rand()`s were to disappear from library shelves, there would be a \n",
"> gap on each shelf about as big as your fist._ - Numerical Recipes"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qikt9pPW9L5K"
},
"source": [
"### RNGs and State\n",
"You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "rr9FeP41fynt",
"outputId": "180b7c87-7050-4123-dc42-2356da6f14a2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.7117779558041075\n",
"0.014396253746679077\n",
"0.7717174868106601\n"
]
}
],
"source": [
"print(onp.random.random())\n",
"print(onp.random.random())\n",
"print(onp.random.random())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ORMVVGZJgSVi"
},
"source": [
"Under the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "7Pyp2ajzfPO2"
},
"outputs": [],
"source": [
"onp.random.seed(0)\n",
"rng_state = onp.random.get_state()\n",
"#print(rng_state)\n",
"# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
"# 2481403966, 4042607538, 337614300, ... 614 more numbers..., \n",
"# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "aJIxHVXCiM6m"
},
"source": [
"This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GAHaDCYafpAF"
},
"outputs": [],
"source": [
"_ = onp.random.uniform()\n",
"rng_state = onp.random.get_state()\n",
"#print(rng_state) \n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
"\n",
"# Let's exhaust the entropy in this PRNG statevector\n",
"for i in range(311):\n",
" _ = onp.random.uniform()\n",
"rng_state = onp.random.get_state()\n",
"#print(rng_state) \n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
"\n",
"# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
"_ = onp.random.uniform()\n",
"rng_state = onp.random.get_state()\n",
"# print(rng_state) \n",
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484, \n",
"# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "N_mWnleNogps"
},
"source": [
"The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n",
"\n",
"The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems: its large 2.5 KB state (624 × 4 bytes) leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928), it [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) some tests in the modern BigCrush suite, and it is generally slow."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Uvq7nV-j4vKK"
},
"source": [
"### JAX PRNG"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "COjzGBpO4tzL"
},
"source": [
"\n",
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/master/design_notes/prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by two unsigned-int32s that we call a __key__:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "yPHE7KTWgAWs",
"outputId": "6c2db189-d971-4d60-eb6b-c7ee3a4704b7"
},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0], dtype=uint32)"
]
},
"execution_count": 196,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from jax import random\n",
"key = random.PRNGKey(0)\n",
"key"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "XjYyWYNfq0hW"
},
"source": [
"JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! \n",
"\n",
"Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __life-giving chaos__:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "7zUdQMynoE5e",
"outputId": "9e1e1f08-19c9-4d22-c78f-4d3e113e185d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.20584233]\n",
"[0 0]\n",
"[-0.20584233]\n",
"[0 0]\n"
]
}
],
"source": [
"print(random.normal(key, shape=(1,)))\n",
"print(key)\n",
"# No no no!\n",
"print(random.normal(key, shape=(1,)))\n",
"print(key)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hQN9van8rJgd"
},
"source": [
"Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "ASj0_rSzqgGh",
"outputId": "ea3fae99-6642-4016-b0c0-938214384fe7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old key [0 0]\n",
" \\---SPLIT --> new key [4146024105 967050713]\n",
" \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n"
]
}
],
"source": [
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(\" \\---SPLIT --> new key \", key)\n",
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "tqtFVE4MthO3"
},
"source": [
"We propagate the __key__ and make new __subkeys__ whenever we need a new random number:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "jbC34XLor2Ek",
"outputId": "436713d1-06a3-408e-fbaa-1fedeea73c73"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old key [4146024105 967050713]\n",
" \\---SPLIT --> new key [2384771982 3928867769]\n",
" \\--> new subkey [1278412471 2182328957] --> normal [-0.5866507]\n"
]
}
],
"source": [
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(\" \\---SPLIT --> new key \", key)\n",
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0KLYUluz3lN3"
},
"source": [
"We can generate more than one __subkey__ at a time:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "lEi08PJ4tfkX",
"outputId": "7599b43d-930e-4c20-d549-b7694281a59a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.37533447]\n",
"[0.9864503]\n",
"[0.1455319]\n"
]
}
],
"source": [
"key, *subkeys = random.split(key, 4)\n",
"for subkey in subkeys:\n",
" print(random.normal(subkey, shape=(1,)))"
]
},
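{
"cell_type": "markdown",
"metadata": {},
"source": [
"In larger programs, a common pattern is to pass a key _into_ every function that needs randomness and hand the updated key _back_ to the caller. Here is a minimal sketch of that pattern; the helper `draw_normals` is hypothetical. Because all state is explicit, starting again from the same key replays exactly the same numbers, while continuing from the returned key yields fresh ones.\n",
"\n",
"```python\n",
"import jax.numpy as np\n",
"from jax import random\n",
"\n",
"def draw_normals(key, num):\n",
"    # Split before every draw so no subkey is ever used twice,\n",
"    # and return the final key so the caller can keep going.\n",
"    samples = []\n",
"    for _ in range(num):\n",
"        key, subkey = random.split(key)\n",
"        samples.append(random.normal(subkey, shape=()))\n",
"    return key, np.stack(samples)\n",
"\n",
"key, first = draw_normals(random.PRNGKey(0), 3)\n",
"key, second = draw_normals(key, 3)               # continues from the returned key\n",
"_, replay = draw_normals(random.PRNGKey(0), 3)   # same starting key as `first`\n",
"print(first)\n",
"print(second)  # different numbers\n",
"print(replay)  # identical to `first`\n",
"```"
]
},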
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rg4CpMZ8c3ri"
},
"source": [
"## 🔪 Control Flow"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "izLTvT24dAq0"
},
"source": [
"### ✔ python control flow + autodiff ✔\n",
"\n",
"If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "aAx0T3F8lLtu",
"outputId": "1f75bb41-2d50-451e-c05d-cb946b580d8d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12.0\n",
"-4.0\n"
]
}
],
"source": [
"def f(x):\n",
" if x < 3:\n",
" return 3. * x ** 2\n",
" else:\n",
" return -4 * x\n",
"\n",
"print(grad(f)(2.)) # ok!\n",
"print(grad(f)(4.)) # ok!"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hIfPT7WMmZ2H"
},
"source": [
"### python control flow + JIT\n",
"\n",
"Using control flow with `jit` is more complicated, and by default it has more constraints.\n",
"\n",
"This works:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "OZ_BJX0CplNC",
"outputId": "d75b0e66-273d-461a-814d-a95c40d41ef4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"24\n"
]
}
],
"source": [
"@jit\n",
"def f(x):\n",
" for i in range(3):\n",
" x = 2 * x\n",
" return x\n",
"\n",
"print(f(3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "22RzeJ4QqAuX"
},
"source": [
"So does this:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "pinVnmRWp6w6",
"outputId": "f7829934-8cdd-4bba-b540-d9df38c71e95"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6.0\n"
]
}
],
"source": [
"@jit\n",
"def g(x):\n",
" y = 0.\n",
" for i in range(x.shape[0]):\n",
" y = y + x[i]\n",
" return y\n",
"\n",
"print(g(np.array([1., 2., 3.])))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "TStltU2dqf8A"
},
"source": [
"But this doesn't, at least by default:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
},
"colab_type": "code",
"id": "9z38AIKclRNM",
"outputId": "f911fb55-f489-4300-f9b1-9142d252f3f9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ERROR: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.\n"
]
}
],
"source": [
"@jit\n",
"def f(x):\n",
" if x < 3:\n",
" return 3. * x ** 2\n",
" else:\n",
" return -4 * x\n",
"\n",
"# This will fail!\n",
"try:\n",
" f(2)\n",
"except Exception as e:\n",
" print(\"ERROR:\", e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "pIbr4TVPqtDN"
},
"source": [
"__What gives!?__\n",
"\n",
"When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n",
"\n",
"For example, if we evaluate an `@jit` function on the array `np.array([1., 2., 3.], np.float32)`, we might want to compile code that we can reuse to evaluate the function on `np.array([4., 5., 6.], np.float32)` to save on compile time.\n",
"\n",
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/master/jax/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), np.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"\n",
"But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), np.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), np.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
"\n",
"The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "-Tzp0H7Bt1Sn",
"outputId": "1435a6a3-2b1c-4acd-be81-c1361021f3c4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12.0\n"
]
}
],
"source": [
"def f(x):\n",
" if x < 3:\n",
" return 3. * x ** 2\n",
" else:\n",
" return -4 * x\n",
"\n",
"f = jit(f, static_argnums=(0,))\n",
"\n",
"print(f(2.))"
]
},
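{
"cell_type": "markdown",
"metadata": {},
"source": [
"The caching described above can be observed directly. In the sketch below (the names `double` and `traced_shapes` are hypothetical), a Python side effect inside the `jit`'d function runs only while JAX traces it, so the list records one entry per trace rather than one per call. Calls with the same shape and dtype reuse the compiled code; a new shape triggers a new trace.\n",
"\n",
"```python\n",
"from jax import jit\n",
"import jax.numpy as np\n",
"\n",
"traced_shapes = []  # appended to only while JAX traces `double`\n",
"\n",
"@jit\n",
"def double(x):\n",
"    traced_shapes.append(x.shape)  # Python side effect: runs at trace time\n",
"    return 2 * x\n",
"\n",
"double(np.array([1., 2., 3.]))\n",
"double(np.array([4., 5., 6.]))  # same ShapedArray((3,), float32): cached\n",
"double(np.array([1., 2.]))      # new shape: traced and compiled again\n",
"print(traced_shapes)  # [(3,), (2,)]\n",
"```"
]
},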
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MHm1hIQAvBVs"
},
"source": [
"Here's another example, this time involving a loop:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "iwY86_JKvD6b",
"outputId": "469a4aeb-2dbd-4f03-9aef-9fd646a717d7"
},
"outputs": [
{
"data": {
"text/plain": [
"array(5., dtype=float32)"
]
},
"execution_count": 206,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def f(x, n):\n",
" y = 0.\n",
" for i in range(n):\n",
" y = y + x[i]\n",
" return y\n",
"\n",
"f = jit(f, static_argnums=(1,))\n",
"\n",
"f(np.array([2., 3., 4.]), 2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "nSPTOX8DvOeO"
},
"source": [
"In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation."
]
},
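{
"cell_type": "markdown",
"metadata": {},
"source": [
"If unrolling a separate version for every distinct `n` is undesirable, the loop bound can stay a traced value by using `lax.fori_loop`, one of the structured control flow primitives covered later in this notebook. A minimal sketch (the function name `sum_first_n` is hypothetical), computing the same sum as the example above without `static_argnums`:\n",
"\n",
"```python\n",
"from jax import jit, lax\n",
"import jax.numpy as np\n",
"\n",
"@jit\n",
"def sum_first_n(x, n):\n",
"    # `n` is traced here, so nothing is unrolled at trace time;\n",
"    # fori_loop accepts a traced upper bound.\n",
"    def body_fun(i, total):\n",
"        return total + x[i]\n",
"    return lax.fori_loop(0, n, body_fun, np.zeros((), x.dtype))\n",
"\n",
"x = np.array([2., 3., 4.])\n",
"print(sum_first_n(x, 2))  # 5.0\n",
"print(sum_first_n(x, 3))  # 9.0, reusing the same compiled code\n",
"```\n",
"\n",
"The tradeoff is that a loop with a traced trip count is only forward-mode differentiable, as the summary table of this section notes for `lax.fori_loop`."
]
},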
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wWdg8LTYwCW3"
},
"source": [
"⚠️ **functions with argument-__value__ dependent shapes**\n",
"\n",
"These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output shape depends on the _value_ of the input argument `length`."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "Tqe9uLmUI_Gv",
"outputId": "dbb43bac-8141-40a3-c760-95656181b598"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4. 4. 4. 4. 4.]\n",
"error! `full` requires shapes to be concrete. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.\n",
"[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n",
"[4. 4. 4. 4. 4.]\n"
]
}
],
"source": [
"def example_fun(length, val):\n",
" return np.ones((length,)) * val\n",
"# un-jit'd works fine\n",
"print(example_fun(5, 4))\n",
"\n",
"bad_example_jit = jit(example_fun)\n",
"# this will fail:\n",
"try:\n",
" print(bad_example_jit(10, 4))\n",
"except Exception as e:\n",
" print(\"error!\", e)\n",
"# static_argnums tells JAX to recompile on changes at these argument positions:\n",
"good_example_jit = jit(example_fun, static_argnums=(0,))\n",
"# first compile\n",
"print(good_example_jit(10, 4))\n",
"# recompiles\n",
"print(good_example_jit(5, 4))"
]
},
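{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how often `static_argnums` forces recompilation, the sketch below records every trace of a copy of `example_fun` (the names `counted_example_fun` and `traced_lengths` are hypothetical). The list grows only when a new value of the static argument `length` appears; changing the traced argument `val` reuses the cached compilation.\n",
"\n",
"```python\n",
"from jax import jit\n",
"import jax.numpy as np\n",
"\n",
"traced_lengths = []  # one entry per trace\n",
"\n",
"def counted_example_fun(length, val):\n",
"    traced_lengths.append(length)  # runs only while JAX traces the function\n",
"    return np.ones((length,)) * val\n",
"\n",
"counted_jit = jit(counted_example_fun, static_argnums=(0,))\n",
"counted_jit(5, 4.)\n",
"counted_jit(5, 7.)   # same static value: cached\n",
"counted_jit(10, 4.)  # new static value: traced and compiled again\n",
"counted_jit(5, 1.)   # 5 was seen before: still cached\n",
"print(traced_lengths)  # [5, 10]\n",
"```"
]
},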
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MStx_r2oKxpp"
},
"source": [
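"`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot, since every new value of `length` triggers a fresh trace and compilation! \n",
"\n",
"Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions: the `print` calls run while the function is being traced, so they show abstract tracer values rather than concrete numbers:"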
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "m2ABpRd8K094",
"outputId": "06fe7d4e-2c59-4499-c04e-94166916be74"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>\n",
"Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>\n"
]
},
{
"data": {
"text/plain": [
"array(4, dtype=int32)"
]
},
"execution_count": 12,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"@jit\n",
"def f(x):\n",
" print(x)\n",
" y = 2 * x\n",
" print(y)\n",
" return y\n",
"f(2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "uCDcWG4MnVn-"
},
"source": [
"### Structured control flow primitives\n",
"\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids unrolling large loops. Then you can use these 4 structured control flow primitives:\n",
"\n",
" - `lax.cond` _differentiable_\n",
" - `lax.while_loop` __fwd-mode-differentiable__\n",
" - `lax.fori_loop` __fwd-mode-differentiable__\n",
" - `lax.scan` _differentiable_\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Sd9xrLMXeK3A"
},
"source": [
"#### cond\n",
"python equivalent:\n",
"\n",
"```\n",
"def cond(pred, true_operand, true_fun, false_operand, false_fun):\n",
" if pred:\n",
" return true_fun(true_operand)\n",
" else:\n",
" return false_fun(false_operand)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "SGxz9JOWeiyH",
"outputId": "b91c6e01-c3a7-41a0-b4d2-f815f273c8a7"
},
"outputs": [
{
"data": {
"text/plain": [
"array([-1.], dtype=float32)"
]
},
"execution_count": 207,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from jax import lax\n",
"\n",
"operand = np.array([0.])\n",
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"# --> array([1.], dtype=float32)\n",
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"# --> array([-1.], dtype=float32)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xkOFAw24eOMg"
},
"source": [
"#### while_loop\n",
"\n",
"python equivalent:\n",
"```\n",
"def while_loop(cond_fun, body_fun, init_val):\n",
" val = init_val\n",
" while cond_fun(val):\n",
" val = body_fun(val)\n",
" return val\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "jM-D39a-c436",
"outputId": "496ba1d8-e1d9-4432-d44b-c1104e1e966d"
},
"outputs": [
{
"data": {
"text/plain": [
"array(10, dtype=int32)"
]
},
"execution_count": 208,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"init_val = 0\n",
"cond_fun = lambda x: x<10\n",
"body_fun = lambda x: x+1\n",
"lax.while_loop(cond_fun, body_fun, init_val)\n",
"# --> array(10, dtype=int32)"
]
},
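{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike a Python `while` under `jit`, the stopping condition of `lax.while_loop` may depend on traced _values_. A minimal sketch (the function name `halvings_until_below` is hypothetical) that counts how many times `x` must be halved before it drops below `threshold`, where both arguments are traced:\n",
"\n",
"```python\n",
"from jax import jit, lax\n",
"\n",
"@jit\n",
"def halvings_until_below(x, threshold):\n",
"    def cond_fun(state):\n",
"        value, count = state\n",
"        return value >= threshold  # depends on traced values\n",
"\n",
"    def body_fun(state):\n",
"        value, count = state\n",
"        return value / 2., count + 1\n",
"\n",
"    _, count = lax.while_loop(cond_fun, body_fun, (x, 0))\n",
"    return count\n",
"\n",
"print(halvings_until_below(10., 1.))  # 4  (10 -> 5 -> 2.5 -> 1.25 -> 0.625)\n",
"print(halvings_until_below(0.5, 1.))  # 0  (already below the threshold)\n",
"```"
]
},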
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "apo3n3HAeQY_"
},
"source": [
"#### fori_loop\n",
"python equivalent:\n",
"```\n",
"def fori_loop(start, stop, body_fun, init_val):\n",
" val = init_val\n",
" for i in range(start, stop):\n",
" val = body_fun(i, val)\n",
" return val\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "dt3tUpOmeR8u",
"outputId": "3155b3ce-589c-437c-a456-de81b3db0a64"
},
"outputs": [
{
"data": {
"text/plain": [
"array(45, dtype=int32)"
]
},
"execution_count": 209,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"init_val = 0\n",
"start = 0\n",
"stop = 10\n",
"body_fun = lambda i,x: x+i\n",
"lax.fori_loop(start, stop, body_fun, init_val)\n",
"# --> array(45, dtype=int32)"
]
},
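{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### scan\n",
"The list above also names `lax.scan`, which carries a value through a loop over the leading axis of an array and stacks a per-step output. Below is a simplified python equivalent together with a usage sketch; the helper names `scan_reference` and `running_total` are hypothetical, and the reference ignores details of the real API.\n",
"\n",
"```python\n",
"import numpy as onp\n",
"import jax.numpy as np\n",
"from jax import lax\n",
"\n",
"def scan_reference(f, init, xs):\n",
"    # Simplified python equivalent: thread `carry` through the loop\n",
"    # and stack the per-step outputs `y` along a new leading axis.\n",
"    carry = init\n",
"    ys = []\n",
"    for x in xs:\n",
"        carry, y = f(carry, x)\n",
"        ys.append(y)\n",
"    return carry, onp.stack(ys)\n",
"\n",
"def running_total(carry, x):\n",
"    total = carry + x\n",
"    return total, total  # (new carry, per-step output)\n",
"\n",
"xs = np.array([1., 2., 3., 4.])\n",
"final, totals = lax.scan(running_total, np.zeros(()), xs)\n",
"ref_final, ref_totals = scan_reference(running_total, np.zeros(()), xs)\n",
"print(final)   # 10.0\n",
"print(totals)  # running sums 1, 3, 6, 10\n",
"```"
]
},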
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SipXS5qiqk8e"
},
"source": [
"#### Summary\n",
"\n",
"$$\n",
"\\begin{array} {r|rr} \n",
"\\hline \\\n",
"\\textrm{construct} \n",
"& \\textrm{jit} \n",
"& \\textrm{grad} \\\\\n",
"\\hline \\\n",
"\\textrm{if} & ❌ & ✔ \\\\\n",
"\\textrm{for} & ✔* & ✔\\\\\n",
"\\textrm{while} & ✔* & ✔\\\\\n",
"\\textrm{lax.cond} & ✔ & ✔\\\\\n",
"\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n",
"\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n",
"\\textrm{lax.scan} & ✔ & ✔\\\\\n",
"\\hline\n",
"\\end{array}\n",
"$$\n",
"<center>$\\ast$ = only when the loop condition does not depend on argument __values__; the loop is then unrolled during tracing</center>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "bxuUjFVG-v1h"
},
"source": [
"## 🔪 Convolutions"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0pcn2LeS-03b"
},
"source": [
"JAX and XLA offer the very general N-dimensional __conv_general_dilated__ function, but it's not very obvious how to use it. We'll give some examples of the common use cases. There are also the convenience functions `lax.conv` and `lax.conv_with_general_padding` for the most common kinds of convolutions.\n",
"\n",
"For a survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285) is highly recommended reading!\n",
"\n",
"Let's define a simple diagonal edge kernel:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 286
},
"colab_type": "code",
"id": "Yud1Y3ss-x1K",
"outputId": "1674482b-501a-43eb-91c6-0bef42a73d6d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Edge Conv kernel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ8AAAD8CAYAAABpXiE9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADe5JREFUeJzt3X+snmV9x/H3Zy1gJkwqJdKUKj+j\nc24GPEGUxTRDEySGLpEl8IeC0XQ4yZRoMtQEE5Nl6h8uYxpJA0RYDDaCgeNSQ2DAcFmKVFIohSCF\nuLS1EyyuyHSysu/+ODfm8XB+9Xru8zzP0fcrefJc931f576+vdp8ev9sU1VI0pH6vXEXIGllMjwk\nNTE8JDUxPCQ1MTwkNTE8JDUZKjySvDbJXUme7L7XzNPvpSQ7u8/0MGNKmgwZ5jmPJF8CnquqLyS5\nGlhTVX8zR78XqurYIeqUNGGGDY8ngI1VdSDJOuC+qnrjHP0MD+m3zLDh8V9VdXzXDvCzl5dn9TsM\n7AQOA1+oqtvn2d9mYDPAq38/b3vTGUc31ybt+tmJ4y5h4r24d99Pq6ppolYv1iHJ3cBJc2z67OBC\nVVWS+ZLoDVW1P8lpwD1JdlXVU7M7VdUWYAvA1FtfVd+/c8OivwBpPqdvvWLcJUy8H33iU//R+rOL\nhkdVvXu+bUl+kmTdwGnLM/PsY3/3/XSS+4CzgFeEh6SVY9hbtdPAZV37MuCO2R2SrElyTNdeC5wH\nPDbkuJLGbNjw+ALwniRPAu/ulkkyleT6rs8fAjuSPAzcy8w1D8NDWuEWPW1ZSFUdBM6fY/0O4CNd\n+9+BPx5mHEmTxydMJTUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwk\nNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1\nMTwkNeklPJJckOSJJHuSXD3H9mOSbO22P5DklD7GlTQ+Q4dHklXAV4H3Am8GLk3y5lndPgz8rKrO\nAP4e+OKw40oarz6OPM4B9lTV01X1IvBNYNOsPpuAm7r2rcD5SdLD2JLGpI/wWA/sHVje162bs09V\nHQYOASf0MLakMZmoC6ZJNifZkWTHswdfGnc5khbQR3jsBzYMLJ/crZuzT5LVwGuAg7N3VFVbqmqq\nqqZOPGFVD6VJWi59hMeDwJlJTk1yNHAJMD2rzzRwWde+GLinqqqHsSWNyephd1BVh5NcCdwJrAJu\nrKrdST4P7KiqaeAG4J+S7AGeYyZgJK1gQ4cHQFVtA7bNWnfNQPt/gL/oYyxJk2GiLphKWjkMD0lN\nDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0M\nD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU16CY8kFyR5Isme\nJFfPsf3yJM8m2dl9PtLHuJLGZ/WwO0iyCvgq8B5gH/BgkumqemxW161VdeWw40maDH0ceZwD7Kmq\np6vqReCbwKYe9itpgg195AGsB/YOLO8D3j5Hv/cneRfwQ+Cqqto7u0OSzcBmgNev76O0316nb71i\n3CVMvDOu2j7uEibej4b42VFdMP0OcEpV/QlwF3DTXJ2qaktVTVXV1IknrBpRaZJa9BEe+4ENA8sn\nd+t+raoOVtWvusXrgbf1MK6kMeojPB4EzkxyapKjgUuA6cEOSdYNLF4EPN7DuJLGaOgLC1V1OMmV\nwJ3AKuDGqtqd5PPAjqqaBv46yUXAYeA54PJhx5U0Xr1clayqbcC2WeuuGWh/Gvh0H2NJmgw+YSqp\nieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEh
qYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ\n4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIalJL+GR5MYkzyR5\ndJ7tSXJtkj1JHklydh/jShqfvo48vg5csMD29wJndp/NwNd6GlfSmPQSHlV1P/DcAl02ATfXjO3A\n8UnW9TG2pPEY1TWP9cDegeV93brfkGRzkh1Jdjx78KURlSapxURdMK2qLVU1VVVTJ56watzlSFrA\nqMJjP7BhYPnkbp2kFWpU4TENfLC763IucKiqDoxobEnLYHUfO0lyC7ARWJtkH/A54CiAqroO2AZc\nCOwBfgF8qI9xJY1PL+FRVZcusr2Aj/UxlqTJMFEXTCWtHIaHpCaGh6QmhoekJoaHpCaGh6Qmhoek\nJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6Qm\nhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmvYRHkhuTPJPk0Xm2b0xyKMnO7nNNH+NKGp9e/qNr4OvA\nV4CbF+jzvap6X0/jSRqzXo48qup+4Lk+9iVpZejryGMp3pHkYeDHwKeqavfsDkk2A5sBVq1Zw+lb\nrxhheSvLGVdtH3cJ+h03qgumDwFvqKq3Av8I3D5Xp6raUlVTVTW16thXj6g0SS1GEh5V9XxVvdC1\ntwFHJVk7irElLY+RhEeSk5Kka5/TjXtwFGNLWh69XPNIcguwEVibZB/wOeAogKq6DrgY+GiSw8Av\ngUuqqvoYW9J49BIeVXXpItu/wsytXEm/JXzCVFITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NS\nE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1IT\nw0NSE8NDUhPDQ1ITw0NSE8NDUpOhwyPJhiT3Jnksye4kH5+jT5Jcm2RPkkeSnD3suJLGq4//6Pow\n8MmqeijJccAPktxVVY8N9HkvcGb3eTvwte5b0go19JFHVR2oqoe69s+Bx4H1s7ptAm6uGduB45Os\nG3ZsSePT6zWPJKcAZwEPzNq0Htg7sLyPVwaMpBWkt/BIcixwG/CJqnq+cR+bk+xIsuOlF/67r9Ik\nLYNewiPJUcwExzeq6ttzdNkPbBhYPrlb9xuqaktVTVXV1KpjX91HaZKWSR93WwLcADxeVV+ep9s0\n8MHursu5wKGqOjDs2JLGp4+7LecBHwB2JdnZrfsM8HqAqroO2AZcCOwBfgF8qIdxJY3R0OFRVf8G\nZJE+BXxs2LEkTQ6fMJXUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTw\nkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ\n1MTwkNTE8JDUZOjwSLIhyb1JHkuyO8nH5+izMcmhJDu7zzXDjitpvFb3sI/DwCer6qEkxwE/SHJX\nVT02q9/3qup9PYwnaQIMfeRRVQeq6qGu/XPgcWD9sPuVNNlSVf3tLDkFuB94S1U9P7B+I3AbsA/4\nMfCpqto9x89vBjZ3i28BHu2tuH6sBX467iIGWM/CJq0emLya3lhVx7X8YG/hkeRY4F+Bv62qb8/a\n9gfA/1XVC0kuBP6hqs5cZH87qmqql+J6Mmk1Wc/CJq0emLyahqmnl7stSY5i5sjiG7ODA6Cqnq+q\nF7r2NuCoJGv7GFvSePRxtyXADcDjVfXlefqc1PUjyTnduAeHHVvS+PRxt+U84APAriQ7u3WfAV4P\nUFXXARcDH01yGPglcEktfr60pYfa+jZpNVnPwiatHpi8mprr6fWCqaTfHT5hKqmJ4SGpycSER5LX\nJrkryZPd
95p5+r008Jj79DLUcUGSJ5LsSXL1HNuPSbK12/5A92zLslpCTZcneXZgXj6yjLXcmOSZ\nJHM+g5MZ13a1PpLk7OW
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# 2D kernel - HWIO layout\n",
"kernel = onp.zeros((3, 3, 3, 3), dtype=onp.float32)\n",
"kernel += onp.array([[1, 1, 0],\n",
" [1, 0,-1],\n",
" [0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]\n",
"\n",
"print(\"Edge Conv kernel:\")\n",
"plt.imshow(kernel[:, :, 0, 0]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dITPaPdh_cMI"
},
"source": [
"And we'll make a simple synthetic image:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 286
},
"colab_type": "code",
"id": "cpbGsIGa_Qyx",
"outputId": "44f0c042-3c74-4f39-9ed2-cd651cbc13fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Image:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAAD8CAYAAABzYsGzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADO1JREFUeJzt3V2MXOV9x/Hvr6ZwkSIBhVoInNog\nJxJE1ZYgEqkBkbZJAFU19ILaiho3QTVIWGqlShWkUoPam6oNRYqaEBnVwkgNL2pFsCIScK0q3IQG\nO7F4CwRDjPDW2AUqSJsoic2/F/NsM3F2s7M7c3Zmh+9HOppznjkz53k865/Oy8z5p6qQpF8adwck\nTQbDQBJgGEhqDANJgGEgqTEMJAEdhkGSq5I8n+Rgklu62o6k0UgX3zNIsgb4LvAR4DDwBLClqp4d\n+cYkjURXewaXAQer6qWq+jFwH7Cpo21JGoFTOnrf84BX+pYPAx9YaOUkfg1S6s5rVXXOYit1FQaL\nSrIN2Dau7UvvIC8PslJXYTALrOtbPr+1/b+q2gHsAPcMpEnQ1TmDJ4CNSTYkORXYDOzuaFuSRqCT\nPYOqOp5kO/AIsAbYWVXPdLEtSaPRyaXFJXfCwwSpS/ur6tLFVvIbiJIAw0BSYxhIAgwDSY1hIAkw\nDCQ1hoEkwDCQ1BgGkgDDQFJjGEgCDANJjWEgCTAMJDWGgSTAMJDUGAaSAMNAUrPsMEiyLsm/J3k2\nyTNJ/rS135ZkNsmBNl0zuu5K6sowN0Q9Dvx5VX0ryenA/iR72nN3VNVnh++epJWy7DCoqiPAkTb/\n/STfoVdJSdIqNJJzBknWA78J/Edr2p7kySQ7k5w5im1I6tbQYZDkV4B/Bf6sqt4C7gQuBGbo7Tnc\nvsDrtiXZl2TfsH2QNLyh6iYk+WXgK8AjVfUP8zy/HvhKVb1vkfexboLUnW7rJiQJ8E/Ad/qDIMm5\nfatdBzy93G1IWjnDXE34LeCPgKeSHGhtnwa2JJkBCjgE3DhUDyWtCMurSdNvoMOErkqyT4SlJEw6\n64W0Ovh1ZEmAYSCpMQwkAYaBpMYwkAQYBpIaw0ASYBhIagwDSYBhIKmZ6q8j+xVjaXDuGUgCDANJ\njWEgCTAMJDWGgSTAMJDUGAaSgBF8zyDJIeD7wAngeFVdmuQs4H5gPb2bol5fVf897LYkdWdUewYf\nrqqZvpsu3gLsraqNwN62LGmCdXWYsAnY1eZ3Add2tB1JIzKKMCjg0ST7k2xrbWtbYVaAV4G1J7/I\n8mrSZBnFbxM+VFWzSX4N2JPkuf4nq6rmq4tQVTuAHWDdBGkSDL1nUFWz7fEY8CBwGXB0rsxaezw2\n7HYkdWuoMEjyriSnz80DH6VXW3E3sLWtthV4aJjtSOresIcJa4EHezVYOQX4UlV9LckTwANJbgBe\nBq4fcjuSOmatRWn6dVuSXdJ0MQwkAYaBpMYwkAQYBpIaw0ASYBhIagwDSYBhIKkxDCQBhoGkxjCQ\nBBgGkhrDQBJgGEhqDANJgGEgqTEMJAFD3AMxyXvplVCbcwHwV8AZwJ8A/9XaP11VDy+7h5JWxEju\ngZhkDTALfAD4JPA/VfXZJbzeeyBK3VnReyD+DvBiVb08oveTtMJGFQabgXv7lrcneTLJziRnzvcC\ny6tJk2Xow4QkpwL/CVxcVUeTrAVeo1eD8W+Ac6vqU4u8h4cJUndW7DDhauBbVXUUoKqOVtWJqnob\nuIteuTVJE24UYbCFvkOEuRqLzXX0yq1JmnBDlVdr9RU/AtzY1/x3SWboHSYcOuk5SRPK8mrS9LO8\nmqTBGQaSAMNAUmMYSAIMA0mNYSAJMAwkNYaBJMAwkNQYBpKAIX+boFVgKV/0Tme90CrgnoEkwDCQ\n1BgGkgDDQFJjGEgCDANJjWEgCRgwDFr9g2NJ
nu5rOyvJniQvtMczW3uSfC7JwVY74ZKuOi9pdAbd\nM7gbuOqktluAvVW1EdjblqF36/SNbdoG3Dl8NyV1baAwqKrHgDdOat4E7Grzu4Br+9rvqZ7HgTNO\nun26pAk0zDmDtVV1pM2/Cqxt8+cBr/Std7i1aRyyhEnvaCP5bUJV1VJvd55kG73DCEkTYJg9g6Nz\nu//t8VhrnwXW9a13fmv7GVW1o6ouHeR+7pK6N0wY7Aa2tvmtwEN97Z9oVxU+CLzZdzghaVJV1aIT\nvVqKR4Cf0DsHcAPwq/SuIrwA/BtwVls3wOeBF4GngEsHeP9ycnLqbNo3yP9zy6tJ08/yapIGZxhI\nAgwDSY1hIAkwDCQ1hoEkwDCQ1BgGkgDDQFJjGEgCDANJjWEgCTAMJDWGgSTAMJDUGAaSAMNAUmMY\nSAIGCIMFSqv9fZLnWvm0B5Oc0drXJ/lhkgNt+mKXnZc0OoPsGdzNz5dW2wO8r6p+A/gucGvfcy9W\n1UybbhpNNyV1bdEwmK+0WlU9WlXH2+Lj9GojSFrFRnHO4FPAV/uWNyT5dpKvJ7l8BO8vaQUMVV4t\nyV8Cx4F/bk1HgHdX1etJ3g98OcnFVfXWPK+1vJo0QZa9Z5Dkj4HfAz5ec5VQqn5UVa+3+f30Cqm8\nZ77XW15NmizLCoMkVwF/Afx+Vf2gr/2cJGva/AXARuClUXRUUrcWPUxIci9wJXB2ksPAZ+hdPTgN\n2JME4PF25eAK4K+T/AR4G7ipqt6Y940lTRTLq0nTz/JqkgZnGEgCDANJjWEgCTAMJDWGgSTAMJDU\nGAaSAMNAUmMYSAIMA0mNYSAJMAwkNYaBJMAwkNQYBpIAw0BSYxhIApZfXu22JLN9ZdSu6Xvu1iQH\nkzyf5GNddVzSaC23vBrAHX1l1B4GSHIRsBm4uL3mC3N3S5Y02ZZVXu0X2ATc1+onfA84CFw2RP8k\nrZBhzhlsb1WYdyY5s7WdB7zSt87h1iaNWS1hemdabhjcCVwIzNArqXb7Ut8gybYk+5LsW2YfJI3Q\nssKgqo5W1Ymqehu4i58eCswC6/pWPb+1zfcelleTJshyy6ud27d4HTB3pWE3sDnJaUk20Cuv9s3h\nuihpJSy3vNqVSWboHWAdAm4EqKpnkjwAPEuvOvPNVXWim65LGiXLq+kdYil/YumsF2NieTVJgzMM\nJAGGgaTGMJAEGAaSmkUvLUrTYequEIycewaSAMNAUmMYSAIMA0mNYSAJMAwkNYaBJMAwkNQYBpIA\nw0BSYxhIAgwDSY1hIAlYfq3F+/vqLB5KcqC1r0/yw77nvthl5yWNziA/Yb4b+EfgnrmGqvrDufkk\ntwNv9q3/YlXNjKqDklbGomFQVY8lWT/fc0kCXA/89mi7JWmlDXvO4HLgaFW90Ne2Icm3k3w9yeUL\nvdDyatJkGfZOR1uAe/uWjwDvrqrXk7wf+HKSi6vqrZNfWFU7gB1g3QRpEix7zyDJKcAfAPfPtbVS\n7K+3+f3Ai8B7hu2kpO4Nc5jwu8BzVXV4riHJOUnWtPkL6NVafGm4LkpaCYNcWrwX+Abw3iSHk9zQ\nntrMzx4iAFwBPNkuNf4LcFNVvTHKDkvqhrUWpelnrUVJgzMMJAGGgaTGMJAEGAaSGsNAEmAYSGoM\nA0mAYSCpMQwkAYaBpMYwkAQYBpIaw0ASYBhIagwDSYBhIKkxDCQBhoGkxjCQBBgGkpphKyqNymvA\n/7bHaXM20zkumN6xTdu4fn2QlSbiVukASfYNcjvn1WZaxwXTO7ZpHddiPEyQBBgGkppJCoMd4+5A\nR6Z1XDC9Y5vWcf1CE3POQNJ4TdKegaQxGnsYJLkqyfNJDia5Zdz9GVaSQ0meSnIgyb7WdlaSPUle\naI9njrufi0myM8mxJE/3tc07jvR8rn2GTya5ZHw9X9wCY7styWz73A4kuabvuVvb2J5P8rHx9Lp7\nYw2DJGuA
zwNXAxcBW5JcNM4+jciHq2qm7/LULcDeqtoI7G3Lk+5u4KqT2hYax9XAxjZtA+5coT4u\n1938/NgA7mif20xVPQz
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# NHWC layout\n",
"img = onp.zeros((1, 200, 198, 3), dtype=onp.float32)\n",
"for k in range(3):\n",
" x = 30 + 60*k\n",
" y = 20 + 60*k\n",
" img[0, x:x+10, y:y+10, k] = 1.0\n",
"\n",
"print(\"Original Image:\")\n",
"plt.imshow(img[0]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "_m90y74OWorG"
},
"source": [
"### lax.conv and lax.conv_with_general_padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Pv9_QPDnWssM"
},
"source": [
"These are the simple convenience functions for convolutions.\n",
"\n",
"⚠️ The convenience helpers `lax.conv` and `lax.conv_with_general_padding` assume __NCHW__ images and __OIHW__ kernels."
]
},
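{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Since our image is stored as __NHWC__ and our kernel as __HWIO__, the `lax.conv` example below permutes their axes with `np.transpose` before calling the helper. Each permutation lists, for every axis of the target layout, which axis of the source layout it is taken from:\n",
"\n",
"$$\n",
"\\begin{array}{lcll}\n",
"\\textrm{NHWC} & \\to & \\textrm{NCHW}: & (0, 3, 1, 2) \\\\\n",
"\\textrm{HWIO} & \\to & \\textrm{OIHW}: & (3, 2, 0, 1)\n",
"\\end{array}\n",
"$$\n",
"\n",
"For example, the __C__ axis in position 1 of NCHW is axis 3 of the NHWC image."
]
},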
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "kppxbxpZW0nb",
"outputId": "2c872f2b-b71a-4821-d870-0b3a4f1eeee9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 3, 200, 198)\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXdJREFUeJzt3X+spndZ5/HPtTPaxKmbtjvdpkKh\nhZQGNd2CYyVRCLtIbYmh4h/Qxigq2UICxGZNFDRZiImJq1a7ullMCQ2Q1AIuVhvTKl3WlWxilSk2\nY6EMTLEN0x3aDgWxxbB2uPaPc81wZjjTDnPOc55hzuuVnJz7+T4/7u/cuc/0Pff3eU6ruwMAQPKv\nlj0BAIBThTACABjCCABgCCMAgCGMAACGMAIAGMIIAGAsLIyq6sqq2ltV+6rqbYvaDwDARqlF/ILH\nqtqW5DNJXplkf5KPJ7m2uz+14TsDANgg2xf0upcn2dfdn0uSqvpAkquTrBlG287c0dvPOWdBUwEA\ntrKnHn88h554sk7ksYsKo2cl+fyq2/uT/NBxJ3HOOfmeX7x+QVMBALay/3vDjSf82KW9+bqqrquq\n3VW1+9ATTy5rGgAARywqjB5OcsGq28+esSO6+6bu3tXdu7aduWNB0wAAOHGLCqOPJ7m4qi6qqu9M\nck2S2xe0LwCADbGQ9xh191NV9ZYkf5FkW5Kbu/uTi9gXAMBGWdSbr9PddyS5Y1GvDwCw0fzmawCA\nIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGCcdRlV1QVX9ZVV9qqo+WVW/MOPvrKqHq+re+XrVxk0XAGBxtq/j\nuU8l+cXu/kRVfXeSe6rqrrnvd7v7t9c/PQCAzXPSYdTdB5IcmO1/qqr7kzxroyYGALDZNuQ9RlV1\nYZIXJfmbGXpLVe2pqpur6uyN2AcAwKKtO4yq6swkH05yfXd/Jcm7kjw/yWVZuaJ0w3Ged11V7a6q\n3YeeeHK90wAAWLd1hVFVfUdWouiW7v7jJOnuR7r7UHd/Pcm7k1y+1nO7+6bu3tXdu7aduWM90wAA\n2BDr+VRaJXlPkvu7+3dWjZ+/6mGvSXLfyU8PAGDzrOdTaT+c5KeT/H1V3Ttjv5Lk2qq6LEkneTDJ\nG9c1QwCATbKeT6X9nyS1xl13nPx0AACWx2++BgAYwggAYAgjAIAhjAAAxno+lcZxnHX/ynvSd+5Z\n/C+uPHjpN34H1Jdf2AvfHwCczlwxAgAYwggAYFhKW4AjS2h37/nG4EsuXci+Vi+fnXvJwSTJY3t3\nLmRfAHC6c8UIAGC4YrRIq64S7Xvddy1kF9dfcec3jd2496qF7AsATneuGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAIzty57A6ejgpTuSJF9+YR8Zu/6KOxeyr7ee/dCR7d//0nMXsg8A2Cpc\nMQIAGMIIAGBYSluAw0to515ycOH7Wr18dstDP7jw/QHA6cwVIwCAIYwAAIaltAV6bO/OI9s37r1q\niTMBAE6EK0YAAEMYAQAMYQQAMIQRAMAQRgAA
QxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABj+3pfoKoeTPJPSQ4leaq7d1XV\nOUk+mOTCJA8meW13f2m9+wIAWKSNumL077v7su7eNbffluSj3X1xko/ObQCAU9qiltKuTvK+2X5f\nkp9Y0H4AADbMRoRRJ/lIVd1TVdfN2HndfWC2v5DkvGOfVFXXVdXuqtp96IknN2AaAADrs+73GCX5\nke5+uKr+bZK7qurTq+/s7q6qPvZJ3X1TkpuS5IznXPBN9wMAbLZ1XzHq7ofn+6NJbktyeZJHqur8\nJJnvj653PwAAi7auMKqqHVX13Ye3k1yR5L4ktyd5/Tzs9Un+dD37AQDYDOtdSjsvyW1Vdfi1/rC7\n/7yqPp7kQ1X1hiQPJXntOvcDALBw6wqj7v5ckn+3xvgXk7xiPa8NALDZ/OZrAIAhjAAAhjACABjC\nCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACG\nMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMII\nAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCA\nIYwAAIYwAgAY20/2iVV1SZIPrhp6XpL/nOSsJP8xyWMz/ivdfcdJzxAAYJOcdBh1994klyVJVW1L\n8nCS25L8XJLf7e7f3pAZAgBsko1aSntFkge6+6ENej0AgE23UWF0TZJbV91+S1Xtqaqbq+rstZ5Q\nVddV1e6q2n3oiSc3aBoAACdv3WFUVd+Z5NVJ/miG3pXk+VlZZjuQ5Ia1ntfdN3X3ru7ete3MHeud\nBgDAum3EFaOrknyiux9Jku5+pLsPdffXk7w7yeUbsA8AgIXbiDC6NquW0arq/FX3vSbJfRuwDwCA\nhTvpT6UlSVXtSPLKJG9cNfybVXVZkk7y4DH3AQCcstYVRt39ZJJ/c8zYT69rRgAAS+I3XwMADGEE\nADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMDYvuwJAItz1v11ZHvnnicXuq+Dl+44sv3lF/ZC9wWwKK4YAQAMYQQAMCylwWns\nqOWzu/esfH/JpQvZ1+rls3MvOXhk+7G9OxeyP4BFcMUIAGAIIwCAYSkNtopZQtv3uu9ayMtff8Wd\na47fuPeqhewPYBFcMQIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAxvZlTwBYnIOX7jiy/eUXdpLk\n+ivuXMi+3nr2Q0e2f/9Lz13IPgAW7YSuGFXVzVX1aFXdt2rsnKq6q6o+O9/PnvGqqt+rqn1Vtaeq\nXryoyQMAbKQTXUp7b5Irjxl7W5KPdvfFST46t5PkqiQXz9d1Sd61/mkCACzeCS2ldffHqurCY4av\nTvLy2X5fkv+d5Jdn/P3d3Unurqqzqur87j6wERMGTtzh5bMkOfeSgwvd1+rls1se+sGF7gtgUdbz\n5uvzVsXOF5KcN9vPSvL5VY/bP2NHqarrqmp3Ve0+9MST65gGAMDG2JBPpc3VoX7GBx79nJu6e1d3\n79p25o5n
fgIAwIKt51NpjxxeIquq85M8OuMPJ7lg1eOePWPAEj22d2eS5Ma9Vy15JgCnrvVcMbo9\nyetn+/VJ/nTV+M/Mp9N
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv(np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
" (1, 1), # window strides\n",
" 'SAME') # padding mode\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,0,:,:]);"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "aonr1tWvYCW9",
"outputId": "63727dd7-1758-4aa0-f93f-557758a160a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 3, 202, 200)\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGr1JREFUeJzt3X+s5XV95/HXe2dakg7dAAtLKKKg\nQaJtWLRTJGk17lpZMI3U/qGQxtLWLJqoKdluWrXJapo0cdvauu1mbTASMaGoXUolDbSybrdmk9I6\nWIL8cOpgIUIQGH9UwcYWfO8f9z0zZ3DGGeaec+905vFIbu73fM6P72e+fO/wnO/3fM+t7g4AAMm/\n2uwJAAAcLYQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABjZWFUVZdU1c6q2lVV71jVegAAlqVW8QGP\nVbUlyd8leXWSh5J8JskV3X3v0lcGALAkW1f0uhcm2dXdX0ySqvpoksuSHDCMtpy4rbeecsqKpgIA\nHO/+6UsP7e7u0w71uFWF0ZlJvrRw+6EkLzvoJE45JT/0y1evaCoAwPHugav/y4OH87hNe/N1VV1V\nVTuqasfTTzy5WdMAANhrVWH0cJKzFm4/Z8b26u5runt7d2/fcuK2FU0DAODwrSqMPpPk3Ko6p6q+\nP8nlSW5e0boAAJZiJe8x6u6nquptSf48yZYk13b3PatYFwDAsqzqzdfp7luS3LKq1wcAWDaffA0A\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgB\nAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAw\nhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBG\nAABDGAEADGEEADCEEQDAOOIwqqqzquovqureqrqnqn5pxt9TVQ9X1Z3z9ZrlTRcAYHW2ruO5TyX5\n5e7+bFX9YJI7quq2ue93u/u31z89AICNc8Rh1N2PJHlklr9ZVfclOXNZEwMA2GhLeY9RVZ2d5CVJ\n/nqG3lZVd1XVtVV18jLWAQCwausOo6o6McmNSa7u7m8k+UCSFyS5IGtHlN53kOddVVU7qmrH0088\nud5pAACs27rCqKq+L2tRdH13/3GSdPej3f10d38nyQeTXHig53b3Nd29vbu3bzlx23qmAQCwFOu5\nKq2SfCjJfd39OwvjZyw87HVJ7j7y6QEAbJz1XJX240nemORzVXXnjL0ryRVVdUGSTvJAkjeva4YA\nABtkPVel/b8kdYC7bjny6QAAbB6ffA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABjPR/wyCGc\ndN++j3k69a7V/j643efv+7UqX39Rr3RdAHCscsQIAGA4YrRC+x0luv2ute8Xnb+SdS0eJTrtvN1J\nksd3nrqSdQHAscoRIwCAIYwAAIZTaRtlTqHtesMPrOTlr7741u8ae//OS1eyLgA4VjliBAAwhBEA\nwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABD\nGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMLZu9gSOZbvP37Z3+esv6iTJ1RffupJ1vf3kB/cu//7X\nnreSdQDAsc4RIwCA4YjRCu05SpQkp523e6XrWjxKdP2DP7bSdQHAscoRIwCAIYwAAIZTaRvk8Z2n\nJknev/PSTZ4JAHAwjhgBAAxhBAAwhBEAwBBG
AABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA2LreF6iqB5J8M8nTSZ7q7u1VdUqS\njyU5O8kDSV7f3V9b77oAAFZpWUeM/n13X9Dd2+f2O5J8qrvPTfKpuQ0AcFRb1am0y5JcN8vXJfnp\nFa0HAGBplhFGneSTVXVHVV01Y6d39yOz/OUkpz/zSVV1VVXtqKodTz/x5BKmAQCwPut+j1GSn+ju\nh6vq3ya5rao+v3hnd3dV9TOf1N3XJLkmSU547lnfdT8AwEZb9xGj7n54vj+W5KYkFyZ5tKrOSJL5\n/th61wMAsGrrCqOq2lZVP7hnOcnFSe5OcnOSK+dhVyb5xHrWAwCwEdZ7Ku30JDdV1Z7X+sPu/rOq\n+kySj1fVm5I8mOT161wPAMDKrSuMuvuLSf7dAca/kuRV63ltAICN5pOvAQCGMAIAGMIIAGAIIwCA\nIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABhbj/SJVXVeko8tDD0/yX9NclKS/5Tk8Rl/V3ffcsQzBADYIEccRt29M8kFSVJVW5I8nOSmJL+Q\n5He7+7eXMkMAgA2yrFNpr0pyf3c/uKTXAwDYcMsKo8uT3LBw+21VdVdVXVtVJx/oCVV1VVXtqKod\nTz/x5JKmAQBw5NYdRlX1/Ulem+SPZugDSV6QtdNsjyR534Ge193XdPf27t6+5cRt650GAMC6LeOI\n0aVJPtvdjyZJdz/a3U9393eSfDDJhUtYBwDAyi0jjK7Iwmm0qjpj4b7XJbl7CesAAFi5I74qLUmq\naluSVyd588Lwb1bVBUk6yQPPuA8A4Ki1rjDq7ieT/JtnjL1xXTMCANgkPvkaAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgbN3sCQCrddJ9lSQ59a4nV76u3edvS5J8/UW98nUBrIIjRgAAQxgBAAyn0uAY\nt/cU2u137Ru86PyVrGvPKbTTztu9d+zxnaeuZF0Aq+CIEQDAEEYAAMOpNDheLJw+2/WGH1jJKq6+\n+NbvGnv/zktXsi6AVXDECABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAcVhhVFXXVtVjVXX3wtgpVXVbVX1hvp88\n41VVv1dVu6rqrqp66aomDwCwTId7xOjDSS55xtg7knyqu89N8qm5nSSXJjl3vq5K8oH1TxMAYPW2\nHs6DuvvTVXX2M4YvS/LKWb4uyf9N8qsz/pHu7iS3V9VJVXVGdz+yjAkDz87u87clSb7+ot47dvXF\nt65kXW8/+cEkye9/7XkreX2AVVvPe4xOX4idLyc5fZbPTPKlhcc9NGMAAEe1pbz5eo4O9SEfuKCq\nrqqqHVW14+knnlzGNAAA1uWwTqUdxKN7TpFV1RlJHpvxh5OctfC458zYfrr7miTXJMkJzz3rWUUV\ncPj2nEI7
7bzdK1/XnlNo1z/4YytfF8AqrOeI0c1JrpzlK5N8YmH85+bqtIuS/IP3FwEA/xIc1hGj\nqroha2+0PrWqHkry7iT
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_with_general_padding(\n",
" np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
" (1, 1), # window strides\n",
" ((2,2),(2,2)), # general padding 2x2\n",
" (1,1), # lhs/image dilation\n",
" (1,1)) # rhs/kernel dilation\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,0,:,:]);"
]
},
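{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The output is 2 pixels larger than the input in each spatial dimension: the padding adds $2 + 2 = 4$, while a $3 \\times 3$ window removes $k - 1 = 2$. Assuming no dilation, the usual convolution-arithmetic relation for explicit padding $(p_\\textrm{lo}, p_\\textrm{hi})$, input size $n$, kernel size $k$ and stride $s$ is\n",
"\n",
"$$\n",
"n_\\textrm{out} = \\left\\lfloor \\frac{n + p_\\textrm{lo} + p_\\textrm{hi} - k}{s} \\right\\rfloor + 1,\n",
"$$\n",
"\n",
"which gives $\\lfloor (200 + 4 - 3)/1 \\rfloor + 1 = 202$ rows and $\\lfloor (198 + 4 - 3)/1 \\rfloor + 1 = 200$ columns, matching the printed shape `(1, 3, 202, 200)`."
]
},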
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lyOwGRez_ycJ"
},
"source": [
"### Dimension Numbers define dimensional layout for conv_general_dilated\n",
"\n",
"The important argument is the 3-tuple of axis layout arguments:\n",
"(Input Layout, Kernel Layout, Output Layout)\n",
" - __N__ - batch dimension\n",
" - __H__ - spatial height\n",
" - __W__ - spatial width\n",
" - __C__ - channel dimension\n",
" - __I__ - kernel _input_ channel dimension\n",
" - __O__ - kernel _output_ channel dimension\n",
"\n",
"⚠️ To demonstrate the flexibility of dimension numbers, we choose an __NHWC__ image and __HWIO__ kernel convention for `lax.conv_general_dilated` below."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "oXKebfCb_i2B",
"outputId": "0b80fca6-0eb7-4baf-d824-458c3739d052"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n"
]
}
],
"source": [
"dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n",
" kernel.shape, # only ndim matters, not shape \n",
" ('NHWC', 'HWIO', 'NHWC')) # the important bit\n",
"print(dn)"
]
},
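{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To read the printed result: `lhs_spec` and `out_spec` list axis indices in the order (batch, feature, spatial dimensions), while `rhs_spec` lists them in the order (output feature, input feature, spatial dimensions). For our layouts this gives\n",
"\n",
"$$\n",
"\\begin{array}{l|l|l}\n",
"\\textrm{layout} & \\textrm{canonical order} & \\textrm{spec} \\\\\n",
"\\hline\n",
"\\textrm{NHWC} & (N, C, H, W) & (0, 3, 1, 2) \\\\\n",
"\\textrm{HWIO} & (O, I, H, W) & (3, 2, 0, 1)\n",
"\\end{array}\n",
"$$\n",
"\n",
"which are the same permutations passed to `np.transpose` in the `lax.conv` example above."
]
},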
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "elZys_HzFVG6"
},
"source": [
"#### SAME padding, no stride, no dilation"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "rgb2T15aFVG6",
"outputId": "93fed3a7-69d2-4046-de2f-487ff34b5ee2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 200, 198, 3)\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXdJREFUeJzt3X+spndZ5/HPtTPaxKmbtjvdpkKh\nhZQGNd2CYyVRCLtIbYmh4h/Qxigq2UICxGZNFDRZiImJq1a7ullMCQ2Q1AIuVhvTKl3WlWxilSk2\nY6EMTLEN0x3aDgWxxbB2uPaPc81wZjjTDnPOc55hzuuVnJz7+T4/7u/cuc/0Pff3eU6ruwMAQPKv\nlj0BAIBThTACABjCCABgCCMAgCGMAACGMAIAGMIIAGAsLIyq6sqq2ltV+6rqbYvaDwDARqlF/ILH\nqtqW5DNJXplkf5KPJ7m2uz+14TsDANgg2xf0upcn2dfdn0uSqvpAkquTrBlG287c0dvPOWdBUwEA\ntrKnHn88h554sk7ksYsKo2cl+fyq2/uT/NBxJ3HOOfmeX7x+QVMBALay/3vDjSf82KW9+bqqrquq\n3VW1+9ATTy5rGgAARywqjB5OcsGq28+esSO6+6bu3tXdu7aduWNB0wAAOHGLCqOPJ7m4qi6qqu9M\nck2S2xe0LwCADbGQ9xh191NV9ZYkf5FkW5Kbu/uTi9gXAMBGWdSbr9PddyS5Y1GvDwCw0fzmawCA\nIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGCcdRlV1QVX9ZVV9qqo+WVW/MOPvrKqHq+re+XrVxk0XAGBxtq/j\nuU8l+cXu/kRVfXeSe6rqrrnvd7v7t9c/PQCAzXPSYdTdB5IcmO1/qqr7kzxroyYGALDZNuQ9RlV1\nYZIXJfmbGXpLVe2pqpur6uyN2AcAwKKtO4yq6swkH05yfXd/Jcm7kjw/yWVZuaJ0w3Ged11V7a6q\n3YeeeHK90wAAWLd1hVFVfUdWouiW7v7jJOnuR7r7UHd/Pcm7k1y+1nO7+6bu3tXdu7aduWM90wAA\n2BDr+VRaJXlPkvu7+3dWjZ+/6mGvSXLfyU8PAGDzrOdTaT+c5KeT/H1V3Ttjv5Lk2qq6LEkneTDJ\nG9c1QwCATbKeT6X9nyS1xl13nPx0AACWx2++BgAYwggAYAgjAIAhjAAAxno+lcZxnHX/ynvSd+5Z\n/C+uPHjpN34H1Jdf2AvfHwCczlwxAgAYwggAYFhKW4AjS2h37/nG4EsuXci+Vi+fnXvJwSTJY3t3\nLmRfAHC6c8UIAGC4YrRIq64S7Xvddy1kF9dfcec3jd2496qF7AsATneuGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAIzty57A6ejgpTuSJF9+YR8Zu/6KOxeyr7ee/dCR7d//0nMXsg8A2Cpc\nMQIAGMIIAGBYSluAw0to515ycOH7Wr18dstDP7jw/QHA6cwVIwCAIYwAAIaltAV6bO/OI9s37r1q\niTMBAE6EK0YAAEMYAQAMYQQAMIQRAMAQRgAA
QxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABj+3pfoKoeTPJPSQ4leaq7d1XV\nOUk+mOTCJA8meW13f2m9+wIAWKSNumL077v7su7eNbffluSj3X1xko/ObQCAU9qiltKuTvK+2X5f\nkp9Y0H4AADbMRoRRJ/lIVd1TVdfN2HndfWC2v5DkvGOfVFXXVdXuqtp96IknN2AaAADrs+73GCX5\nke5+uKr+bZK7qurTq+/s7q6qPvZJ3X1TkpuS5IznXPBN9wMAbLZ1XzHq7ofn+6NJbktyeZJHqur8\nJJnvj653PwAAi7auMKqqHVX13Ye3k1yR5L4ktyd5/Tzs9Un+dD37AQDYDOtdSjsvyW1Vdfi1/rC7\n/7yqPp7kQ1X1hiQPJXntOvcDALBw6wqj7v5ckn+3xvgXk7xiPa8NALDZ/OZrAIAhjAAAhjACABjC\nCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACG\nMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMII\nAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCA\nIYwAAIYwAgAY20/2iVV1SZIPrhp6XpL/nOSsJP8xyWMz/ivdfcdJzxAAYJOcdBh1994klyVJVW1L\n8nCS25L8XJLf7e7f3pAZAgBsko1aSntFkge6+6ENej0AgE23UWF0TZJbV91+S1Xtqaqbq+rstZ5Q\nVddV1e6q2n3oiSc3aBoAACdv3WFUVd+Z5NVJ/miG3pXk+VlZZjuQ5Ia1ntfdN3X3ru7ete3MHeud\nBgDAum3EFaOrknyiux9Jku5+pLsPdffXk7w7yeUbsA8AgIXbiDC6NquW0arq/FX3vSbJfRuwDwCA\nhTvpT6UlSVXtSPLKJG9cNfybVXVZkk7y4DH3AQCcstYVRt39ZJJ/c8zYT69rRgAAS+I3XwMADGEE\nADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMDYvuwJAItz1v11ZHvnnicXuq+Dl+44sv3lF/ZC9wWwKK4YAQAMYQQAMCylwWns\nqOWzu/esfH/JpQvZ1+rls3MvOXhk+7G9OxeyP4BFcMUIAGAIIwCAYSkNtopZQtv3uu9ayMtff8Wd\na47fuPeqhewPYBFcMQIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAxvZlTwBYnIOX7jiy/eUXdpLk\n+ivuXMi+3nr2Q0e2f/9Lz13IPgAW7YSuGFXVzVX1aFXdt2rsnKq6q6o+O9/PnvGqqt+rqn1Vtaeq\nXryoyQMAbKQTXUp7b5Irjxl7W5KPdvfFST46t5PkqiQXz9d1Sd61/mkCACzeCS2ldffHqurCY4av\nTvLy2X5fkv+d5Jdn/P3d3Unurqqzqur87j6wERMGTtzh5bMkOfeSgwvd1+rls1se+sGF7gtgUdbz\n5uvzVsXOF5KcN9vPSvL5VY/bP2NHqarrqmp3Ve0+9MST65gGAMDG2JBPpc3VoX7GBx79nJu6e1d3\n79p25o5n
fgIAwIKt51NpjxxeIquq85M8OuMPJ7lg1eOePWPAEj22d2eS5Ma9Vy15JgCnrvVcMbo9\nyetn+/VJ/nTV+M/Mp9N
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
" 'SAME', # padding mode\n",
" (1,1), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "E4i3TI5JFVG9"
},
"source": [
"#### VALID padding, no stride, no dilation"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "1HQwudKVFVG-",
"outputId": "9edf1704-b920-4666-8a49-2db7dd622fd1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 198, 196, 3) DIFFERENT from above!\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXFJREFUeJzt3X+spndZ5/HPtTNK4tRNpzvdpkKh\nhZQGNd2qYyVRCLtIbYmh4h/Qxigq2UICxGZNFDRZiImJq1a7ullMCQ2Q1AKK1WbTKl3WlWxilSk2\ntVBGptiG6da2Q0FsMawdrv1jrhme6ZxhxjnnOc845/VKTs79fJ8f93fu3Gf6nvvHaXV3AABI/tWq\nJwAAcLoQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEAjKWFUVVdWVV7q2pfVb19WesBANgotYxf8FhV\n25L8TZJXJdmf5BNJru3uT2/4ygAANsj2JX3u5Un2dffnkqSqPpjk6iRrhtG2s3b09nPOWdJUAICt\n7Jknn8zBp56uk3ntssLouUk+v/B4f5LvO+4kzjkn3/az1y9pKgDAVvZ/b7jxpF+7souvq+q6qtpT\nVXsOPvX0qqYBAHDEssLokSQXLDx+3owd0d03dffu7t697awdS5oGAMDJW1YYfSLJxVV1UVV9c5Jr\nkty+pHUBAGyIpVxj1N3PVNVbk/xJkm1Jbu7uTy1jXQAAG2VZF1+nu+9IcseyPh8AYKP5zdcAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgB\nAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAw\nhBEAwBBGAABDGAEADGEEADCEEQDAOOUwqqoLqupPq+rTVfWpqvqZGX9XVT1SVffO16s3broAAMuz\nfR3vfSbJz3b3J6vqW5PcU1V3zXO/2d2/vv7pAQBsnlMOo+5+NMmjs/wPVfVAkudu1MQAADbbhlxj\nVFUXJvmuJH8xQ2+tqvuq6uaq2nmc91xXVXuqas/Bp57eiGkAAKzLusOoqs5K8pEk13f3l5O8O8mL\nklyWQ0eUbljrfd19U3fv7u7d287asd5pAACs27rCqKq+KYei6Jbu/oMk6e7Huvtgd38tyXuSXL7+\naQIALN967kqrJO9N8kB3/8bC+PkLL3ttkvtPfXoAAJtnPXelfX+SH0/y11V174z9QpJrq+qyJJ3k\noSRvWtcMAQA2yXruSvs/SWqNp+449ekAAKyO33wNADDWcyqNBWc/8PWDZ7vuW/6vHzhw6dfv5PvS\nS3rp6wOArcARIwCAIYwAAIZTaRvkqNNnd9/39eWXXrqU9S2ePjv3kgNJkif27lrKugBgq3DECABg\nOGK0DAtHifa9/luWsorrr7jzmLEb9161lHUBwFbhiBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMDYvuoJnCkOXLrjyPKXXtJHlq+/4s6lrO9tOx8+svzbX3zBUtYBAFuNI0YAAEMYAQAMp9I2\nyOLps3MvObD09S2ePrvl4e9d+voAYCtwxAgAYAgjAIDhVNoSPLF315HlG/detcKZAAD/HI4YAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBG
AABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAGP7ej+gqh5K8g9JDiZ5prt3\nV9U5ST6U5MIkDyV5XXd/cb3rAgBYpo06YvTvu/uy7t49j9+e5GPdfXGSj81jAIDT2rJOpV2d5P2z\n/P4kP7Kk9QAAbJiNCKNO8tGquqeqrpux87r70Vn+uyTnbcB6AACWat3XGCX5ge5+pKr+bZK7quoz\ni092d1dVP/tNE1HXJcm2nTs3YBoAAOuz7iNG3f3IfH88yW1JLk/yWFWdnyTz/fE13ndTd+/u7t3b\nztqx3mkAAKzbusKoqnZU1bceXk5yRZL7k9ye5A3zsjck+aP1rAcAYDOs91TaeUluq6rDn/W73f3H\nVfWJJB+uqjcmeTjJ69a5HgCApVtXGHX355L8uzXGv5Dklev5bACAzeY3XwMADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABj+6m+saouSfKhhaEXJvnPSc5O8h+TPDHjv9Ddd5zyDAEANskph1F3701yWZJU\n1bYkjyS5LclPJfnN7v71DZkhAMAm2ahTaa9M8mB3P7xBnwcAsOk2KoyuSXLrwuO3VtV9VXVzVe1c\n6w1VdV1V7amqPQefenqDpgEAcOrWHUZV9c1JXpPk92bo3UlelEOn2R5NcsNa7+vum7p7d3fv3nbW\njvVOAwBg3TbiiNFVST7Z3Y8lSXc/1t0Hu/trSd6T5PINWAcAwNJtRBhdm4XTaFV1/sJzr01y/was\nAwBg6U75rrQkqaodSV6V5E0Lw79aVZcl6SQPPes5AIDT1rrCqLufTvJvnjX24+uaEQDAivjN1wAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEA\nwBBGAABj+6onACzH2Q/UkeVd9z299PUduHTHkeUvvaSXvj6AZXDECABgCCMAgOFUGpyhjjp9dvd9\nX19+6aVLWd/i6bNzLzmQJHli766lrAtgWRwxAgAYjhjBVrBwlGjf679lKau4/oo7jxm7ce9VS1kX\nwLI4YgQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADC2r3oCwHIcuHTHkeUvvaSPLF9/xZ1L\nWd/bdj58ZPm3v/iCpawDYNkcMQIAGMIIAGA4lQZnqMXTZ+decmDp61s8fXbLw9+79PUBLMNJHTGq\nqpur6vGqun9h7JyququqPjvfd854VdVvVdW+qrqvqr57WZMHANhIJ3sq7X1JrnzW2NuTfKy7L07y\nsXmcJFcluXi+rkvy7vVPEwBg+U7qVFp3f7yqLnzW8NVJXjHL70/yv5P8/Ix/oLs7yd1VdXZVnd/d\nj27EhIF/vif27jqyfOPeq1Y4E4DT23ouvj5vIXb+Lsl5s/zcJJ9feN3+GTtKVV1XVXuqas/Bp55e\nxzQAADbG
htyVNkeH+oQvPPo9N3X37u7eve2sHSd+AwDAkq0njB6rqvOTZL4/PuOPJLlg4XXPmzEA\ngNPaesLo9iRvmOU3JPm
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
" 'VALID', # padding mode\n",
" (1,1), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \"DIFFERENT from above!\")\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VYKZdqLIFVHB"
},
"source": [
"#### SAME padding, (2,2) stride, no dilation"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "mKq2-zmmFVHC",
"outputId": "73b80162-fdca-4f6a-ec06-4645d6fdc9f6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 100, 99, 3) <-- half the size of above\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAJCCAYAAAAvEKYoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFHtJREFUeJzt3V+sZWdZx/HfYw8tzKi0RZnUtsoY\niIaYKHZKEKwxU00QiJ0LgjRqiqnpjX9QNLZ6Y0ww0cSAXBhMQ9VekAKppG2M0RBaknrTdEpNKq1K\nU4TOpP8MFk1rgBMfL85WBuaUczpn7/OH5/O5mbPevU7fN1lZk2/Xfvee6u4AAEzxbXu9AACA3SR+\nAIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBglB3FT1W9uar+paoeraqblrUoAIBVqXP9hueq\nOi/Jvyb56SSnktyf5NrufviFfmft0OF+ycsvPqf5AAC+ma9+6YtZf/652uq8tR3M8fokj3b3Y0lS\nVR9Jck2SF4yfl7z84hx913t2MCUAwOY+91fv29Z5O3nb69Ikj59xfGoxBgCwb618w3NV3VBVJ6vq\n5Przz616OgCAb2on8XM6yeVnHF+2GPs63X1zdx/r7mNrhw7vYDoAgJ3bSfzcn+Q1VXW0qs5P8s4k\ndy1nWQAAq3HOG567e72qfjXJ3yc5L8lfdPdnlrYyAIAV2MmnvdLdf5vkb5e0FgCAlfMNzwDAKOIH\nABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCA\nUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF\n/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQP\nADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAA\no4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK\n+AEARhE/AMAo4gcAGEX8AACjiB8AYJS1vV7AQXPk/i+fNbZ29wMrmWv9+BWbjj915QUrmQ8AJvDk\nBwAYRfwAAKOIHwBgFPEDAIwifgCAUXza60Xa7JNdp29840rmOnHtvZuO33HbVSuZDwAm8OQHABhF\n/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQP\nADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMsrbXCzho1o9fcdbYiWvvXclc733lQ5uO35Gr\nVjIfAEzgyQ8AMIr4AQBGET8AwCjiBwAYZcv4qarLq+qeqnq4qj5TVe9ejF9cVZ+oqs8u/rxo9csF\nANiZ6u5vfkLVJUku6e5PV9V3JHkgyYkk70ryxe7+o6q6KclF3X3jN/tvveySy/vou96znJUDAJzh\nc3/1vvz3E4/XVudt+eSnu5/o7k8vfv6vJI8kuTTJNUluXZx2azaCCABgX3tRe36q6lVJXpfkviRH\nuvuJxUtPJjmy1JUBAKzAtuOnqr49yV8n+Y3u/s8zX+uN9842ff+sqm6oqpNVdXL9+ed2tFgAgJ3a\nVvxU1UuyET4f7u6PL4afWuwH+r99QU9v9rvdfXN3H+vuY2uHDi9jzQAA52w7n/aqJLckeaS733fG\nS3cluW7x83VJ7lz+8gAAlms7/7bXm5L8YpKHquofF2O/l+SPknysqq5P8vkk71jNEgEAlmfL+Onu\nf0jyQh8bu3q5ywEAWC3f8AwAjCJ+AIBRxA8A
MIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwi\nfgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIH\nABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCA\nUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF\n/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQP\nADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAA\no4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK\n+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gf\nAGAU8QMAjCJ+AIBRxA8AMIr4AQBG2Xb8VNV5VfVgVf3N4vhoVd1XVY9W1Uer6vzVLRMAYDlezJOf\ndyd55IzjP07y/u5+dZL/SHL9MhcGALAK24qfqrosyVuTfGhxXEmOJ7l9ccqtSU6sYoEAAMu03Sc/\nf5rkd5L8z+L4FUme7e71xfGpJJdu9otVdUNVnayqk+vPP7ejxQIA7NSW8VNVb0vydHc/cC4TdPfN\n3X2su4+tHTp8Lv8JAIClWdvGOW9K8rNV9ZYkL03ynUk+kOTCqlpbPP25LMnp1S0TAGA5tnzy092/\n292Xdferkrwzyd3d/fNJ7kny9sVp1yW5c2WrBABYkp18z8+NSd5TVY9mYw/QLctZEgDA6mznba//\n192fSvKpxc+PJXn98pcEALA6vuEZABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCji\nBwAYRfwAAKOIHwBgFPEDAIwifgCAUV7Uv+oO7B9H7v/ypuNrdz+wkvnWj19x1thTV16wkrkAVsmT\nHwBgFPEDAIwifgCAUcQPADCK+AEARvFpLzigXuhTXadvfONK5jtx7b1njd1x21UrmQtglTz5AQBG\nET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTx\nAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGCUtb1eAHBu1o9fsen4iWvvXcl8733l\nQ2eN3ZGrVjIXwCp58gMAjCJ+AIBRxA8AMIr4AQBGseEZDqinrrxg0/E7blvNJmSbm4FvFZ78AACj\niB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4\nAQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8A\nYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBG\nET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTx\nAwCMIn4AgFHEDwAwyrbip6ourKrbq+qfq+qRqvqxqrq4qj5RVZ9d/HnRqhcLALBT233y84Ekf9fd\nP5jkh5M8kuSmJJ/s7tck+eTiGABgX9syfqrq5Ul+IsktSdLdX+nuZ5Nck+TWxWm3JjmxqkUCACzL\ndp78HE3yTJK/rKoHq+pDVXU4yZHufmJxzpNJjmz2y1V1Q1WdrKqT688/t5xVAwCco+3Ez1qSH03y\nwe5+XZLn
8g1vcXV3J+nNfrm7b+7uY919bO3Q4Z2uFwBgR7YTP6eSnOru+xbHt2cjhp6qqkuSZPHn\n06tZIgDA8mwZP939ZJL
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (2,2), # window strides\n",
" 'SAME', # padding mode\n",
" (1,1), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \" <-- half the size of above\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
]
},
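{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Here is a quick sanity check on the shape above. It is a sketch that assumes a 200x198 image and a 3x3 kernel (which is what the printed shapes suggest), and it uses dummy arrays of ones rather than `img` and `kernel`. With `'SAME'` padding the output keeps `ceil(n / stride)` positions along each spatial dimension, whatever the kernel size. The names `same_out_size`, `x_chk`, `w_chk`, `dn_chk` and `out_same` are hypothetical and were chosen so they do not overwrite the notebook's `img`, `kernel`, `dn` or `out`."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import math\n",
"import jax.numpy as np\n",
"from jax import lax\n",
"\n",
"def same_out_size(n, stride):\n",
"  # 'SAME' padding: ceil(n / stride) outputs per spatial dim, independent of kernel size\n",
"  return math.ceil(n / stride)\n",
"\n",
"x_chk = np.ones((1, 200, 198, 3), np.float32)  # NHWC, assumed to match img's shape\n",
"w_chk = np.ones((3, 3, 3, 3), np.float32)      # HWIO, assumed 3x3 kernel, 3 in / 3 out channels\n",
"dn_chk = lax.conv_dimension_numbers(x_chk.shape, w_chk.shape, ('NHWC', 'HWIO', 'NHWC'))\n",
"out_same = lax.conv_general_dilated(x_chk, w_chk, (2, 2), 'SAME', (1, 1), (1, 1), dn_chk)\n",
"print(out_same.shape, (same_out_size(200, 2), same_out_size(198, 2)))  # (1, 100, 99, 3) (100, 99)"
]
},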
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "gPxttaiaFVHE"
},
"source": [
"#### VALID padding, unit stride, rhs (kernel) dilation ~ atrous convolution (exaggerated for illustration)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "_pGr0x6qFVHF",
"outputId": "a0f489eb-bab7-42c4-c030-fb756f63a4a4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 176, 174, 3)\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkUAAAJCCAYAAADOe7N5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGwpJREFUeJzt3X+spXd92Pn3Zz2xU2xtjeOGEtu7\ndm2TikTNBk1ZV9lWCXQbk0Zx/oiCadq4LZLlLk0TSjeBRFp2/0BK2qpuot2CvIHitAhMKQ1WlP6g\nLilaaQ2ZkITfhAECHsvERAm0OFoTO9/94x52r+yZ2sy512fmzuslje45z3nOPZ+HZ+b6zXOe59xZ\nawUAcKH7r3Y9AADAuUAUAQAkigAAKlEEAFCJIgCAShQBAFSiCACgOsQompmbZ+YTM3NyZl59WK8D\nAHAQ5jA+vHFmLqp+q/ofq1PVr1YvW2t99MBfDADgABw7pO/7wurkWuvTVTPztuqW6rRRdNFll65j\nV1xxSKMAABeyrzxw6nfXWn/iqdY7rCi6qnpg3/1T1X9/xiGuuKJvetWPHdIoAMCF7Ld/7O999ums\nt7MTrWfm9pk5MTMnHv/yI7saAwCgOrwoerC6Zt/9qzfL/j9rrbvWWsfXWscvuuzSQxoDAODpOawo\n+tXqxpm5bmYurm6t7j2k1wIA2NqhnFO01npsZv529W+ri6o3rbU+chivBQBwEA7rROvWWr9c/fJh\nfX8AgIPkE60BABJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEA\nQHWIvxD2a3HJA490wyvvP5DvdfLOm067/KC+/9k400wAwLnDkSIAgEQRAEAligAAKlEEAFCJIgCA\nShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBA\nJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAqo7teoCqR6+5\ntJOvuulQX+PknYf7/QGA85sjRQAAiSIAgEoUAQBUoggAoBJFAACVKAIAqM6RS/IveeCRbnjl/Qfy\nvc6nS+8/9dI37Oy1r7/njp29NgCcixwpAgBIFAEAVKIIAKDaIopm5pqZec/MfHRmPjIzP7pZfsXM\nvHtmPrn5+uyDGxcA4HBsc6TosepVa63nVzdVr5iZ51evru5ba91Y3be5DwBwTjvrKFprPbTW+sDm\n9n+uPlZdVd1S3b1Z7e7q+7cdEgDgsB3IOUUzc2317dX7questR7aPPT56jkH8RoAAIdp6yiamcuq\nf1n92FrrP+1/bK21qnWG590+Mydm5sQf9ui2YwAAbGWrKJqZr2sviN6y1nrnZvHvzMxzN48/t3r4\ndM9da9211jq+1jr+dV2yzRgAAFvb5uqzqd5YfWyt9Y/2PXRvddvm9m3Vu85+PACAZ8Y2v+bjO6q/\nVn1oZn5js+wnq5+u3j4zL68+W/3gdiMCABy+s46itdb/Vc0ZHn7x2X5fAIBd8InWAACJIgCAShQB\nAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoA\nACpRBABQiSIAgKqO7XqAqkevubSTr7pp12M8466/545djwAAbDhSBACQKAIAqEQRAEAligAAKlEE\nAFCdI1efPRM+9dI37Oy1z3SV2bk4EwBcqBwpAgBIFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpR\nBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUo\nAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKII
AKASRQAA1QFE0cxcNDO/PjO/tLl/3cy8b2ZOzsw9\nM3Px9mMCAByugzhS9KPVx/bd/5nqzrXWDdXvVy8/gNcAADhUx7Z58sxcXf3l6nXV352ZqV5U/ZXN\nKndX/2v1+m1e5yBcf88dux7hSc7FmQDgQrXtkaJ/XP149Ueb+99QfXGt9djm/qnqqtM9cWZun5kT\nM3Pi8S8/suUYAADbOesompnvrR5ea/3a2Tx/rXXXWuv4Wuv4RZdderZjAAAciG3ePvuO6vtm5nuq\nr6/+6+pnq8tn5tjmaNHV1YPbjwkAcLjO+kjRWus1a62r11rXVrdW/2Gt9UPVe6of2Kx2W/WuracE\nADhkh/E5RT/R3knXJ9s7x+iNh/AaAAAHaqurz75qrfUr1a9sbn+6euFBfF8AgGeKT7QGAEgUAQBU\noggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAq\nUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACV\nKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBK\nFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqLaMopm5\nfGbeMTMfn5mPzcyfm5krZubdM/PJzddnH9SwAACHZdsjRT9b/Zu11p+uvq36WPXq6r611o3VfZv7\nAADntLOOopn549VfqN5Ytdb6ylrri9Ut1d2b1e6uvn/bIQEADts2R4quq75Q/dOZ+fWZ+fmZubR6\nzlrroc06n6+ec7onz8ztM3NiZk48/uVHthgDAGB720TRseoF1evXWt9ePdIT3ipba61qne7Ja627\n1lrH11rHL7rs0i3GAADY3jZRdKo6tdZ63+b+O9qLpN+ZmedWbb4+vN2IAACH76yjaK31+eqBmfnm\nzaIXVx+t7q1u2yy7rXrXVhMCADwDjm35/B+p3jIzF1efrv5Ge6H19pl5efXZ6ge3fA0AgEO3VRSt\ntX6jOn6ah168zfcFAHim+URrAIBEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQR\nAEAligAAqu1/ISxckG545f07e+2Td9502uXn4kwA5xNHigAAEkUAAJUoAgCoRBEAQCWKAAAqUQQA\nUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIA\nqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUNWxXQ8A56OT\nd9606xGe5FycCeB84kgRAECiCACgEkUAAJUoAgCoRBEAQOXqMzgrN7zy/p299vl0ldmnXvqGnb32\n9ffcsbPXBs5PjhQBACSKAAAqUQQAUG0ZRTPzypn5yMx8eGbeOjNfPzPXzcz7ZubkzNwzMxcf1LAA\nAIflrKNoZq6q/k51fK31rdVF1a3Vz1R3rrVuqH6/evlBDAoAcJi2ffvsWPXHZuZY9azqoepF1Ts2\nj99dff+WrwEAcOjOOorWWg9W/7D6XHsx9KXq16ovrrUe26x2qrpq2yEBAA7bNm+fPbu6pbqu+qbq\n0urmr+H5t8/MiZk58fiXHznbMQAADsQ2b5/9xeoza60vrLX+sHpn9R3V5Zu306qurh483ZPXWnet\ntY6vtY5fdNmlW4wBALC9baLoc9VNM/OsmZnqxdVHq/dUP7BZ57bqXduNCABw+LY5p+h97Z1Q/YHq\nQ5vvdVf1E9XfnZmT1TdUbzyAOQEADtVWv/tsrfXa6rVPWPzp6oXbfF8AgGeaT7QGAEgUAQBUoggA\noBJFAACV
KAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQA\nUIkiAIBKFAEAVHVs1wP
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
" 'VALID', # padding mode\n",
" (1,1), # lhs/image dilation\n",
" (12,12), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape)\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
]
},
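{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"This cell shows where the shape above comes from. It is a sketch that assumes a 200x198 image and a 3x3 kernel, as the printed shapes suggest, and it uses dummy arrays of ones. Kernel (rhs) dilation by `d` inserts `d - 1` zeros between kernel taps, so a `k`-tap kernel spans `(k - 1) * d + 1` input positions. With `'VALID'` padding and unit stride, each spatial dimension therefore shrinks to `n - ((k - 1) * d + 1) + 1`. `dilated_kernel_size` and the `*_chk` arrays are hypothetical names, chosen so they do not clash with `img`, `kernel` or `dn`."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"from jax import lax\n",
"\n",
"def dilated_kernel_size(k, d):\n",
"  # a k-tap kernel dilated by d spans (k - 1) * d + 1 input positions\n",
"  return (k - 1) * d + 1\n",
"\n",
"x_chk = np.ones((1, 200, 198, 3), np.float32)  # NHWC, assumed to match img's shape\n",
"w_chk = np.ones((3, 3, 3, 3), np.float32)      # HWIO, assumed 3x3 kernel\n",
"dn_chk = lax.conv_dimension_numbers(x_chk.shape, w_chk.shape, ('NHWC', 'HWIO', 'NHWC'))\n",
"out_atrous = lax.conv_general_dilated(x_chk, w_chk, (1, 1), 'VALID', (1, 1), (12, 12), dn_chk)\n",
"k_eff = dilated_kernel_size(3, 12)  # 25\n",
"print(out_atrous.shape, (200 - k_eff + 1, 198 - k_eff + 1))  # (1, 176, 174, 3) (176, 174)"
]
},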
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "v-RhEeUfFVHI"
},
"source": [
"#### VALID padding, unit stride, lhs (input) dilation ~ transposed convolution"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "B9Ail8ppFVHJ",
"outputId": "03d00b5a-ec38-435a-81f7-79d4737239c0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 397, 393, 3) <-- larger than original!\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGYRJREFUeJzt3H+s5XV95/HXuzOIDZgC1SUU2JVF\nNsY26UimlKZN42pskX/GJq7iJpU1JiO7mKhxN6L/1CY1sZu1bMzuQmikYtMWiNZIGvqDVZKmf4CO\nFpEf2l78EZgdYetvMEvD9L1/3Dd6Z5xh7sy9557p3Mcjubnf8znfc8/nfPhe8pzzPd9b3R0AAJKf\nWPYEAABOFsIIAGAIIwCAIYwAAIYwAgAYwggAYCwsjKrqiqr6clWtVNV1i3oeAIDNUov4O0ZVtSPJ\n3yV5dZLHknw2yRu7+6FNfzIAgE2yqHeMLkuy0t1f6e5/THJrkj0Lei4AgE2xc0E/9/wkj665/ViS\nXzzazjvOPKN3nnPOgqYCAGxnz3zrWzn45FO1nn0XFUbHVFV7k+xNkh1nn52fedc7ljUVAOAU9n8+\n+N/Xve+iTqXtT3LhmtsXzNgPdfdN3b27u3fvOPOMBU0DAGD9FhVGn01ySVVdVFXPS3JVkjsW9FwA\nAJtiIafSuvuZqnpbkr9MsiPJzd394CKeCwBgsyzsM0bdfWeSOxf18wEANpu/fA0AMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEA\nwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABD\nGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEE\nADCEEQDAEEYAAEMYAQCMnRt5cFV9Lcn3kxxM8kx3766qc5LcluTFSb6W5PXd/e2NTRMAYPE24x2j\nf9vdu7p799y+LsmnuvuSJJ+a2wAAJ71FnErbk+SW2b4lyWsX8BwAAJtuo2HUSf6qqj5XVXtn7Nzu\nPjDb30hy7pEeWFV7q2pfVe07+ORTG5wGAMDGbegzRkl+pbv3V9W/SHJXVX1p7Z3d3VXVR3pgd9+U\n5KYkOf1fXnjEfQAAttKG3jHq7v3z/Ykkn0hyWZLHq+q8JJnvT2x0kgAAW+GEw6iqzqiqFzy7neTX\nkjyQ5I4kV89uVyf55EYnCQCwFTZyKu3cJJ+oqmd/zh93919U1WeT3F5Vb0ny9SSv3/g0AQAW74TD\nqLu/kuTnjzD+zSSv2sikAACWwV++BgAYwggAYAgjAIAhjAAAhjACABgb/cvXHMVL3nnPlj7fyvWX\nb+nzAcCpyDtGAADDO0YLtlXv5Dzyhhtz8W3XbMlzAcCpyjtGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABj57IncKpauf7yJMkjb7hxyTMBANZLGC3Yxbdds+wp\nAADr5FQaAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMI4ZRlV1c1U9\nUVUPrBk7p6ruqqq/n+9nz3hV1YeqaqWq7q+q
Sxc5eQCAzbSed4w+kuSKw8auS/Kp7r4kyafmdpK8\nJskl87U3yQ2bM00AgMU7Zhh1918n+dZhw3uS3DLbtyR57Zrxj/aqe5KcVVXnbdZkAQAW6UQ/Y3Ru\ndx+Y7W8kOXe2z0/y6Jr9HpuxH1NVe6tqX1XtO/jkUyc4DQCAzbPhD193dyfpE3jcTd29u7t37zjz\njI1OAwBgw040jB5/9hTZfH9ixvcnuXDNfhfMGADASe9Ew+iOJFfP9tVJPrlm/E1zddrlSb675pQb\nAMBJbeexdqiqP0nyiiQvrKrHkvxWkg8kub2q3pLk60leP7vfmeTKJCtJfpDkzQuYMwDAQhwzjLr7\njUe561VH2LeTXLvRSQEALIO/fA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAOOYYVRVN1fVE1X1wJqx91XV\n/qq6b76uXHPfe6pqpaq+XFW/vqiJAwBstvW8Y/SRJFccYfz67t41X3cmSVW9LMlVSX52HvO/qmrH\nZk0WAGCRjhlG3f3XSb61zp+3J8mt3f10d381yUqSyzYwPwCALbORzxi9rarun1NtZ8/Y+UkeXbPP\nYzP2Y6pqb1Xtq6p9B598agPTAADYHCcaRjckuTjJriQHknzweH9Ad9/U3bu7e/eOM884wWkAAGye\nEwqj7n68uw929z8l+f386HTZ/iQXrtn1ghkDADjpnVAYVdV5a27+RpJnr1i7I8lVVXV6VV2U5JIk\nn9nYFAEAtsbOY+1QVX+S5BVJXlhVjyX5rSSvqKpdSTrJ15K8NUm6+8Gquj3JQ0meSXJtdx9czNQB\nADbXMcOou994hOEPP8f+70/y/o1MCgBgGfzlawCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjC\nCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgLFz2RMA\nFu8l77xny55r5frLt+y5ADabMIJtYquC5ZE33Jgkufi2a7bk+QA2k1NpAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwDhmGFXVhVV1d1U9VFUPVtXbZ/ycqrqrqv5+vp89\n41VVH6qqlaq6v6ouXfSLAADYDDvXsc8zSd7V3Z+vqhck+VxV3ZXkPyT5VHd/oKquS3JdkncneU2S\nS+brF5PcMN+BJVm5/vI88oYblz0NgJPeMd8x6u4D3f352f5+koeTnJ9kT5JbZrdbkrx2tvck+Wiv\nuifJWVV13qbPHABgk63nHaMfqqoXJ3l5knuTnNvdB+aubyQ5d7bPT/Lomoc9NmMH1oylqvYm2Zsk\nO84++zinDRyvi2+7ZtlTADjprfvD11V1ZpKPJ3lHd39v7X3d3Un6eJ64u2/q7t3dvXvHmWccz0MB\nABZiXWFUVadlNYr+qLv/dIYff/YU2Xx/Ysb3J7lwzcMvmDEAgJPaeq5KqyQfTvJwd//emrvuSHL1\nbF+d5JNr
xt80V6ddnuS7a065AQCctNbzGaNfTvKbSb5YVffN2HuTfCDJ7VX1liRfT/L6ue/OJFcm\nWUnygyRv3tQZAwAsyDH
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
"                               ((0, 0), (0, 0)), # explicit zero padding, same as 'VALID'\n",
" (2,2), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \"<-- larger than original!\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
]
},
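{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Input (lhs) dilation works the other way around. Dilating by `d` inserts `d - 1` zeros between input elements, which leaves `(n - 1) * d + 1` positions for the kernel to slide over as usual. With zero padding and a 3x3 kernel, this predicts `(n - 1) * d + 1 - 3 + 1` outputs per spatial dimension. The check below is a minimal sketch that assumes a 200x198 input and a 3x3 kernel and uses dummy arrays of ones. `dilated_input_size` and the `*_chk` arrays are hypothetical names."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"from jax import lax\n",
"\n",
"def dilated_input_size(n, d):\n",
"  # inserting d - 1 zeros between n elements leaves (n - 1) * d + 1 positions\n",
"  return (n - 1) * d + 1\n",
"\n",
"x_chk = np.ones((1, 200, 198, 3), np.float32)  # NHWC, assumed to match img's shape\n",
"w_chk = np.ones((3, 3, 3, 3), np.float32)      # HWIO, assumed 3x3 kernel\n",
"dn_chk = lax.conv_dimension_numbers(x_chk.shape, w_chk.shape, ('NHWC', 'HWIO', 'NHWC'))\n",
"out_lhs = lax.conv_general_dilated(x_chk, w_chk, (1, 1), ((0, 0), (0, 0)), (2, 2), (1, 1), dn_chk)\n",
"print(out_lhs.shape, (dilated_input_size(200, 2) - 3 + 1, dilated_input_size(198, 2) - 3 + 1))  # (1, 397, 393, 3) (397, 393)"
]
},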
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "A-9OagtrVDyV"
},
"source": [
"The last example can be used, for instance, to implement _transposed convolutions_:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "5EYIj77-NdHE",
"outputId": "d2e82a42-9c8e-4973-f760-511a14805527"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 400, 396, 3) <-- transposed_conv\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXVJREFUeJzt3W2s5nV95/HPtzOIDWPKUF1Cgaws\nsDG2SUcypdO0aVyNFXkyNrGKD5QYk5FdTKrpbop9UpvUpN2ssjHZ1WCkYtMWiNZIDL1hlcT4ABQt\nIje1Pd4FZkfY1ls0S8P0uw/OFz3Dzs2ZOec61zjn9UpOzv/6Xf/rXL/rx/+Y91z/63+s7g4AAMlP\nLXsCAACnC2EEADCEEQDAEEYAAEMYAQAMYQQAMBYWRlV1VVV9uapWquqGRT0PAMBmqUX8HaOq2pHk\nH5K8IsljST6X5PXd/fCmPxkAwCZZ1DtGVyZZ6e6vdve/JLk1yf4FPRcAwKbYuaCfe2GSR9fcfizJ\nLx9r5x27zumd5523oKkAANvZ09/6Vg4/+YNaz76LCqMTqqoDSQ4kyY7du/Nzv/O2ZU0FADiD/e93\n//d177uoU2kHk1y85vZFM/Yj3X1Td+/t7r07dp2zoGkAAKzfosLoc0kur6pLquo5Sa5JcseCngsA\nYFMs5FRadz9dVW9N8jdJdiS5ubsfWsRzAQBsloV9xqi770xy56J+PgDAZvOXrwEAhjACABjCCABg\nCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYOzcyIOr6utJvp/kcJKnu3tvVZ2X5LYkL0zy9SSv7e5vb2yaAACLtxnvGP2H7t7T\n3Xvn9g1JPtndlyf55NwGADjtLeJU2v4kt8z2LUlevYDnAADYdBsNo07yt1X1+ao6MGPnd/eh2f5m\nkvM3+BwAAFtiQ58xSvJr3X2wqv5Nkruq6u/X3tndXVV9tAdOSB1Ikh27d29wGgAAG7ehd4y6++B8\nfyLJx5JcmeTxqrogSeb7E8d47E3dvbe79+7Ydc5GpgEAsClOOYyq6pyqet4z20l+I8mDSe5Icu3s\ndm2Sj290kgAAW2Ejp9LOT/Kxqnrm5/x5d/91VX0uye1V9eYk30jy2o1PEwBg8U45jLr7q0l+8Sjj\n/5zk5RuZFADAMvjL1wAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA2Oj/iSzH8ZXXvX9L\nn+/S267b0ucDgDONMNoCWxUsl739nqzcuG9LngsAzkROpQEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEAjJ3LnsCZ7NLbrkuSXPb2e5Y8EwBgPYTRFli5cd+ypwAA\nrINTaQAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQCME4ZRVd1cVU9U1YNrxs6rqruq\n6h/n++4Zr6p6b1WtVNUDVXXFIicPALCZ1vOO
0YeSXPWssRuSfLK7L0/yybmdJK9Kcvl8HUjyvs2Z\nJgDA4p0wjLr700m+9azh/Ulume1bkrx6zfiHe9U9Sc6tqgs2a7IAAIt0qp8xOr+7D832N5OcP9sX\nJnl0zX6PzRgAwGlvwx++7u5O0if7uKo6UFX3VdV9h5/8wUanAQCwYacaRo8/c4psvj8x4weTXLxm\nv4tm7P/T3Td1997u3rtj1zmnOA0AgM1zqmF0R5JrZ/vaJB9fM/7GuTptX5LvrjnlBgBwWtt5oh2q\n6i+SvDTJ86vqsSS/n+SPktxeVW9O8o0kr53d70xydZKVJD9M8qYFzBkAYCFOGEbd/fpj3PXyo+zb\nSa7f6KQAAJbBX74GABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABg\nCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIBxwjCqqpur6omqenDN2Dur6mBV3T9fV6+5\n7x1VtVJVX66qVy5q4gAAm2097xh9KMlVRxm/sbv3zNedSVJVL05yTZKfn8f8z6rasVmTBQBYpBOG\nUXd/Osm31vnz9ie5tbuf6u6vJVlJcuUG5gcAsGU28hmjt1bVA3OqbfeMXZjk0TX7PDZjAACnvVMN\no/cluTTJniSHkrz7ZH9AVR2oqvuq6r7DT/7gFKcBALB5TimMuvvx7j7c3f+a5AP58emyg0kuXrPr\nRTN2tJ9xU3fv7e69O3adcyrTAADYVKcURlV1wZqbv5nkmSvW7khyTVWdXVWXJLk8yWc3NkUAgK2x\n80Q7VNVfJHlpkudX1WNJfj/JS6tqT5JO8vUkb0mS7n6oqm5P8nCSp5Nc392HFzN1AIDNdcIw6u7X\nH2X4g8fZ/11J3rWRSQEALIO/fA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAGPnsicA\nbI2vvO79W/Zcl9523ZY9F8BmEkawjWxVsFz29nuycuO+LXkugM3kVBoAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAA44RhVFUXV9XdVfVwVT1UVb894+dV1V1V9Y/zffeMV1W9t6pWquqB\nqrpi0S8CAGAzrOcdo6eT/E53vzjJviTXV9WLk9yQ5JPdfXmST87tJHlVksvn60CS9236rAEAFuCE\nYdTdh7r7C7P9/SSPJLkwyf4kt8xutyR59WzvT/LhXnVPknOr6oJNnzkAwCbbeTI7V9ULk7wkyb1J\nzu/uQ3PXN5OcP9sXJnl0zcMem7FDAZbm0tuuy2Vvv2fZ0wA4ra07jKpqV5KPJnlbd3+vqn50X3d3\nVfXJPHFVHcjqqbbs2L37ZB4KnKKVG/ctewoAp7V1XZVWVWdlNYr+rLv/coYff+YU2Xx/YsYPJrl4\nzcMvmrEjdPdN3b23u/fu2HXOqc4fAGDTrOeqtErywSSPdPd71tx1R5JrZ/vaJB9fM/7GuTptX5Lv\nrjnlBgBw
2lrPqbRfTfKGJF+qqvtn7PeS/FGS26vqzUm+keS1c9+dSa5OspLkh0netKkzBgBYkBOG\nUXd/Jkkd4+6XH2X/TnL
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# The following is equivalent to TensorFlow's:\n",
"# N,H,W,C = img.shape\n",
"# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))\n",
"\n",
"# transposed conv = 180deg kernel rotation plus LHS dilation\n",
"# rotate kernel 180deg:\n",
"kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))\n",
"# custom padding so the output is exactly twice the input size:\n",
"padding = ((2, 1), (2, 1))\n",
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel_rot, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
" padding, # padding mode\n",
" (2,2), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \"<-- transposed_conv\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(onp.array(out)[0,:,:,0]);"
]
},
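{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The comment above describes a transposed convolution as a 180-degree kernel rotation plus LHS dilation. One way to check that claim is to compare it against the exact transpose (adjoint) of a strided convolution, obtained with `jax.vjp`. The sketch below does this in 1D with a single channel. The sizes `n_t = 9`, `k_t = 3` and `s_t = 2` are arbitrary choices with `(n_t - k_t)` divisible by `s_t`, so plain `(k - 1, k - 1)` padding recovers the input length. With several channels, the kernel's I and O axes would also have to be swapped to get an exact adjoint. All names here (`strided_conv`, `x_t`, `w_t`, `ybar`, ...) are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import numpy as onp\n",
"import jax\n",
"import jax.numpy as np\n",
"from jax import lax\n",
"\n",
"rng = onp.random.RandomState(0)\n",
"n_t, k_t, s_t = 9, 3, 2  # (n_t - k_t) is divisible by s_t, so no extra output padding is needed\n",
"x_t = np.asarray(rng.randn(1, n_t, 1), np.float32)  # NWC, one channel\n",
"w_t = np.asarray(rng.randn(k_t, 1, 1), np.float32)  # WIO, one input and one output channel\n",
"dn_t = lax.conv_dimension_numbers(x_t.shape, w_t.shape, ('NWC', 'WIO', 'NWC'))\n",
"prec = lax.Precision.HIGHEST  # avoid reduced-precision convolutions on accelerators\n",
"\n",
"def strided_conv(x):\n",
"  return lax.conv_general_dilated(x, w_t, (s_t,), 'VALID', (1,), (1,), dn_t,\n",
"                                  precision=prec)\n",
"\n",
"# exact transpose of the strided conv, via reverse-mode autodiff\n",
"y_t, conv_vjp = jax.vjp(strided_conv, x_t)\n",
"ybar = np.asarray(rng.randn(*y_t.shape), np.float32)\n",
"(xbar_vjp,) = conv_vjp(ybar)\n",
"\n",
"# the same map as a plain conv: flip the kernel, dilate the input by the stride, pad k - 1 on both sides\n",
"xbar_manual = lax.conv_general_dilated(ybar, np.flip(w_t, axis=0), (1,), ((k_t - 1, k_t - 1),),\n",
"                                       (s_t,), (1,), dn_t, precision=prec)\n",
"print(y_t.shape, xbar_manual.shape)  # (1, 4, 1) (1, 9, 1)\n",
"print(onp.allclose(xbar_vjp, xbar_manual, atol=1e-5))  # True"
]
},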
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "v8HsE-NCmUxx"
},
"source": [
"### 1D Convolutions"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WeP0rw0tm7HK"
},
"source": [
"You aren't limited to 2D convolutions; a simple 1D demo is below:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 680
},
"colab_type": "code",
"id": "jJ-jcAn3cig-",
"outputId": "614ed589-e097-4bfe-f596-3421e3492698"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"in shapes: (1, 200, 2) (3, 2, 2)\n",
"ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))\n",
"out shape: (1, 200, 2)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlYAAAEyCAYAAAA4KJ7OAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XuwJVd13/Hf6jOMEA9J4BkI1owY\nYQscxXEMNQFSYMfYciIpseTEiUuqpPwIsSoVk9jlR0ouEkzh5A/sipMiUWyLMuVH2cjgVyZluXBs\nMKacCGsAAXogGAvZGllIgwQSDghJp1f+6NN9j65n5uy9u/c99571/VSpZubOGU2fvrdvr9m/vVab\nuwsAAADjNes+AAAAgE1BYQUAADARCisAAICJUFgBAABMhMIKAABgIhRWAAAAE6GwAgAAmAiFFQAA\nwEQorAAAACayb11/8YEDB/zIkSPr+usBAACSfehDH/qsux9c9bq1FVZHjhzR8ePH1/XXAwAAJDOz\nP0t5HVEgAADARCisAAAAJkJhBQAAMBEKKwAAgIlQWAEAAEyEwgoAAGAiFFYAAAATWVlYmdk7zOwh\nM7v9DL9vZvY2MzthZh8zs1dMf5gAAAC7X8qK1S9Iuvwsv3+FpEsW/10n6WfGHxYAAMDes3Lyurv/\nkZkdOctLrpb0S+7ukm4xswvM7EXu/sBEx4gdcvv9j+rj9z+67sPY9UzS677mBXrhec/M/rN/8fkv\n6f2fPDX9Qe0RX/H5j+mCxz6Z9NrDz3+WXnR+/jneCNZIL7tSevaB/D/7yKelT79/+mPaRBf9Heng\ny9Z9FNgwUzzS5kJJ9y39+uTiY3+lsDKz69Staumiiy6a4K/GlH7k3R/VJz7zhXUfxp7wva85oh//\ntr+R/ef+23s/pXf+yX2rX7ihPrD/B3W4iVtYZvnG+6RvfmP+n3vvT0i3/8b0x7OJXvI66bt+e91H\ngQ2zo88KdPcbJd0oSUePHvWd/Lux2uNPzvWtl75QP3H11677UHa1f/C2D+jxJ9uiP/v4k62+8vxn\n6jf/9WsmPqq94eDPSV+66Dv0hW/4D2d93X/6nbv02b/8sn7lX75qh45sl3nby6WnvlT2Z598XDrw\nMgqGVd79PdJTj6/7KLCBpiis7pd0eOnXhxYfwx7TuvScc/bpr0WNXxLtm5natuzfBa279u9rAp/j\nVuc+5wKde+HFZ33Vl899RKf+319K533lDh3XLtPsk9qy4l0+l/adE/fcpXrGudITX1z3UWADTTFu\n4Zik71p0B75a0qPsr9qb5q3LbN1HsfvNzNR6WWE1b11N5JPsrdTMVr5s1pgKa9fNYLPuXJVIPMfh\njTnHwFmsXLEys3dK+iZJB8zspKQfl/QMSXL3n5V0s6QrJZ2Q9EVJ31vrYFFX665Z5Jt+IjPTvLCw\nat3VNIHPcdt2G7NXMFPxquBGMOtWnkq086RzHJ415ecYOIuUrsBrV/y+S/r+yY4Ia9O6axb5pp9o\n1pgK6yq1rWIXr952KwUrdCtWgQurZuSKVcI5Dm/MOQbOgn/WYDBvu9UYnF1jXaRXYu7B41afK+UE\nNCNWBTeCNd3KUwlnxSqJNeX72ICz4OrDoFuxWvdR7H5NMyIKbIOvCrbzpP0/jVnse57NRkSB7LFK\nQhSISriNYsAeqzQzM/mIPVahC6vkKFBEgaOiQL61r0QUiEq4+jDougID3/QTNWYjosDgcWtiTDXm\nHG+EMTEVUWCaMXErcBZcfRiEj6kSNY1pXnjPa1vXLOopdk8eBdBE37w+JqZKjFvDGxO3AmdBYYVB\n66KwSjBrRBRYoj9nKVGgBZ9jRVdgfUSBqITCCoPwHWuJxnSshY5b+9WBpCiwvPNyI9AVWB9dgaiE\nqw+DLqYKetPPMGb/T+gGgb5QaBIKq6b8sUEb
ga7A+ogCUQmFFQahY6oMowaERo5b+9glOQoMXFjR\nFVhf0xAFogquPkjq9gy10TvWEo0aEBr5eYw5UeCIWWEbga7A+ugKRCVcfZC0ta84bEyVYcweK4+8\nKjhEgQwIXYmuwPqIAlEJhRUkaSgUot7zc3RRYPkjbZqoxWtOFMiAULoCa6MrEJVQWEHSVrTVUFmt\nNGpAaCsKq9QBoZELK7oC66MrEJVw9UHSUhRIYbVSN7yy7M965Ocx9oVVSlegdQ0CpSuDe56NXLEi\nClxtzDkGziLqt3hsQxSYrrHymGreBo4C25w5Vt05CjtxYdQeK7oCk5ixxwpVcPVB0lIUGPWmn2E2\n6lmBHjduHboC0/ZYSYGHhDYzugJra2Z0BaIKrj5I2opciAJXGxcFBu68HKLAtGcFSoE3sNuIGUtE\ngWmIAlEJhRUksWKVozEVTwXvosCJD2ivKIoCIxdWI8YtsGK12phzDJwFVx8kLe2xCnvXTzcbMbxy\n3kaOAvMmr0vRo8AxXYGsWK3Uj1uIWryjGgorSGJAaI5mxONWPPKzAouiwJoHtIvRFVhfX3xSWGFi\nFFaQtBwFrvlA9oBuKjgDQrMNUeDq999/HYZ9EDNdgfX154g4EBPj6oOkrb0sYWOqDGOiwNYDn+Os\nrsBFFBh1NYGuwPr6eWp0BmJiXH2QtPU9POxqSoYxz7FrI29ez4kC2bxOFFjbEAXSGYhpUVhB0tbK\nQNip4BlGDQjlIcx5XYFR73l0BdZHFIhKuPogaSkKZMVqpVlTPiC0jTx5vWRAaNQVK7oC6+tX9YgC\nMTEKK0ja2iQc9qafwax8QGjrgc9xXyQlrKbYsGIVtLAaGwWyYrXasGIVdVkUtXD1QdJyFBj0pp9h\n1ox7VmDYuLVfGUh4CPMs/B6rWVlE1abvYwuPPVaoJOq3eGzD5vV0sxFzrNrQzwrMGBAafY5VUzjH\nKuMch9ewYoU6KKwgaXmP1ZoPZA+wEQ9hbiPPsfL0zev9KQo7ed2asp37nj4rLDxj3ALqoLCCpK0b\nGFHgarNmxIDQNvDk9SEKzFmxClxYFUWB6ec4vCEKpLDCtCisIIkBoTlmzcjN61HPccGzAsMWVkSB\n9TXssUIdFFaQxLiFHGZlYwDa6I8NyooCgz+E2ZqyiCrjHIdHFIhKuPogaWuTcNiYKsOs8FmBffEa\n9hxndKwNUWDUxYTirkCiwGR0BaISCitI4iHMObooML+wmkePW4eYKmHcQr+YQBSYZ5gVRmG1ElEg\nKqGwgqSlmCrqTT9DPyDUM2/64UdalESBUQsrugLr688RUSAmRmEFSUtRIIXVSlsbq/P+XBv9eYw5\nXYFMXicKrI2uQFQS9Vs8tpkzxypZaUw1j94gwIDQdHQF1kcUiEoorCCJZwXmKO1YC3+OGRCajq7A\n+ugKRCVcfZC0HFMFveln6M9R7vaf8HFrf8IyosDcfWwbw0auWBEFrkZXICqhsIKk5a7AoDf9DH1d\nlLuxOnznZZu+mtI3UYTevD5mjxUrVqv154jCChNLuvrM7HIzu9vMTpjZ9af5/YvM7H1m9hEz+5iZ\nXTn9oaImBoSma0qjwPDjFjIKq+gDQvsVp9zOQPZYpWuIAlHHyu9wZjaTdIOkKyRdKulaM7t028v+\nvaR3ufvLJV0j6X9MfaCoK3xMlWErCiwrrMIOCM2IqUrj1o1RGlMN55gVq5WIAlFJytX3Skkn3P0e\nd39C0k2Srt72Gpd03uLn50v6i+kOETshfEyVoXQ1JXzcmhMFht+8vjgBuXEgUWC6IQpkxQrT2pfw\nmgsl3bf065OSXrXtNW+W9Htm9m8kPVvSZZMcHXZM+JgqQ+n+n2FAaNRzPESBq1esmugDQocoMPOm\nn3GOwys9x8AKU/2z5lpJv+DuhyRdKemXzf7qP5nM7DozO25mx0+dOjXRX40phI+pMmx1rOX9ufAD\nQouiwKCF
1egokMJqJaJAVJLyLf5+SYeXfn1o8bFlr5f0Lkly9/8r6ZmSDmz/H7n7je5+1N2PHjx4\nsOyIUcU8+uNWMpTGVOE
"text/plain": [
"<Figure size 720x360 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlUAAAEyCAYAAADTHyXNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XmQJGl53/HfU2f2VT0sO+x9wXJo\nwcDiAaEDJBAWC5LAQiEZrAOMHBsKIYfQYYwChwx2KGxZssJhCVteB4SwQAJhaQ0hQAIUyAosrlm0\n7AnscggWL8vsLtNV3dN1v/4j6+iuzqzqrMqqrK73+4mY6Jnqruq3srorf/Pk+z6vOecEAACA2eSy\nHgAAAMAqIFQBAACkgFAFAACQAkIVAABACghVAAAAKSBUAQAApIBQBQAAkAJCFQAAQAoIVQAAACko\nZPFNL774Ynfttddm8a0BAAASue222x52zp2e9HWZhKprr71WZ8+ezeJbAwAAJGJmf3+cr+PyHwAA\nQAoIVQAAACkgVAEAAKSAUAUAAJACQhUAAEAKCFUAAAApIFQBAACkYOZQZWaBmX3azD5nZneb2VvS\nGBgAAMBJkkbzz4akFzrnds2sKOnjZvYh59wnU3hsAACAE2HmUOWcc5J2e/8s9v64WR8XGWnUpHve\nL3VbWY9kuTzuBumq56TzWM5J9/xvqb6TzuMto3xJuuHlUmkj65EAwMKksk2NmeUl3Sbpeklvdc59\nKuJrbpZ0syRdffXVaXxbzMMd75E+8CtZj2L5bF4i/eoX03msb94hvfc16TzWMnNOuvEnsx4FACxM\nKqHKOdeR9EwzOyXpVjN7mnPurpGvuUXSLZJ05swZKlnL6sKj4cfX3ynlMtkacvl8/D9LZ98WhgSz\n2R/vwiPhx3/yTumKfzj74y2b1r70u88aPk8A8ESqZ03n3Hkz+5ikmyTdNenrsYTqO1JxQzpFNXFg\n+wqp25ZaF9K5nNW/7HfRE6TK5bM/3rJxTrL8al/eBIAIaaz+O92rUMnM1iT9I0mfn/VxkZH6eSmo\nZD2K5VLuHY+0QkL/cVb1OJuFz41QBcAzaVSqLpP0jt68qpykP3HO/XkKj4ss1HekYDvrUSyX/vGo\n76RTWRqEqhU+zsE2oQqAd9JY/XeHpBtTGAuWAaHqqIOhKg31HclyUmkzncdbRoQqAB6iozoOI1Qd\nFZwKP6YZqoLtdCa9LytCFQAPEapwGKHqqHlUqlb9GBOqAHiIUIXD6tXVP+EnlXqo8uAYB9tSo5r1\nKABgoQhVGHIuDA7lFV2VNq1gDqv/Vv0Yl6lUAfAPoQpDzT3JdVa/ipJUoSwVAi7/JRFsS81dqdPO\neiQAsDCEKgz5sNR/WmnOEarvDCe/r6r+zxCXAAF4hFCFIUJVvNRD1Yof48E8tPPZjgMAFohQhSFC\nVby0QlWnJbX2Vv8Ypz25HwBOAEIVhvqXalb90tQ00gpVjdrw8VYZoQqAhwhVGKJSFS+tFgH9y2Gr\nfowHoYo5VQD8QajC0Kpv9DuLckobBPtyjNNuQwEAJwChCkP9Ksqq91CaRv/yn3OzPY4v1UAu/wHw\nEKEKQ/WdsB9TMch6JMsn2JY6Taldn+1xfAlVpS1JRqgC4BVCFYZ8WOo/rbQqL76EqlwuvARIqALg\nEUIVhnzYk25aqYWq/gpLD44zmyoD8AyhCkNUquL120zMupqtviPJepfHVhybKgPwDKEKQ4SqeGle\n/gsq4eWxVRecolIFwCsevLPj2Oo7rPyLM2gRMOO2K/UdqexJcE2rDQUAnBCEKgxRqYqXaqXKk2PM\nnCoAniFUIeScXyf8pAhVyRGqAHiGUIVQuy51W/6c8JMqBFK+NHtIaHi0wrI/Ub3byXokALAQhCqE\nfOmfNC2zdCovvlWqJFYAAvAGoQohQtVkabQI
8DFUsakyAE8QqhAiVE0262q2bqd3+c+TFZZsqgzA\nM4QqhAhVk816+a/hUTd1iU2VAXiHUIUQoWqyWUOVb8eYUAXAM4QqhPpNLX054U+DUJUMoQqAZwhV\nCPm00e+0Zg5Vnh1jQhUAzxCqEKrvhH2YCkHWI1lewXbYz6vdmO7+vlWqykxUB+AXQhVC/aX+ZlmP\nZHnN2iLAt1CVy4fBij5VADxBqEKIzZQnm/VyVv9+Ph1nNlUG4BFCFUI+NaWcFqEqOfb/A+CRmUOV\nmV1lZh8zs3vM7G4z+8U0BoYFI1RNNghV56e7f31HKm1J+UJ6Y1p2hCoAHkmjUtWW9CvOuRskPVfS\n68zshhQeF4vk00a/05q1UuXjMQ62pw+hAHDCzByqnHMPOuc+2/t7TdK9kq6Y9XGxYFSqJkvj8p9v\nx5hKFQCPpDqnysyulXSjpE+l+bhYAB9P+En1j8+0q9l8PMbBNhsqA/BGaqHKzDYl/amk1zvnjryL\nmtnNZnbWzM6eO3curW+LNLTqYf8l3074SRXXpVxhhkrVef+OcbAdhtBuN+uRAMDcpRKqzKyoMFC9\nyzn3Z1Ff45y7xTl3xjl35vTp02l8W6TFt41+p2U2W4uA+o4UeLTyTwqfr+tKzd2sRwIAc5fG6j+T\n9DZJ9zrnfmf2IWHhfGtKOYtZ5gj5evlPYl4VAC+kUan6Hkk/LemFZnZ7789LU3hcLIpve9LNYtpQ\n1e1KjZp/x5hQBcAjMzfMcc59XBJ7m5xk/SXvvp3wpzFtqGruhpfBfDvGhCoAHqGjOrj8l8S0ocrX\nY0yoAuARQhX8PeFPY9oWAb4e41nbUADACUKogp970k1r1kqVb8e4TKUKgD8IVQhPeJaXShtZj2T5\nBdtSa0/qtJLdz9tKVS9EEqoAeIBQheFSf2O9wUSDOUIJL2f5GqryRam4QagC4AVCFfzc6Hdag1CV\ncJPgQYPVU+mO5yRgU2UAniBUwc+mlNOadjXboFLl2ZwqiU2VAXiDUAVCVRKzhKriRng5zDeEKgCe\nIFSBUJXEtC0CfNxMuW/aNhQAcMIQquDnRr/TKk+5ms3nYxzMsAk1AJwghCr0TvgeTqCexiyX/7yu\nVBGqAKw+QpXvOi2pdcHfE35SpU3JclOEKo9XWPZDlXNZjwQA5opQ5bv+XBdfT/hJ5XLhJUAqVccX\nbEuuIzX3sh4JAMwVocp3/f5Bvp7wpzHN5SzfQ5XEJUAAK49Q5TtfO33PIulqNucIVRKbKgNYeYQq\n3xGqkktaqWruhZe/fD3GVKoAeIJQ5bv+ia7s6XL/aSQNVb4f4zKhCoAfCFW+azBRPbGkocr3Y0yl\nCoAnCFW+4/JfctNWqnw9xoQqAJ4gVPmuvhP2XSptZj2SkyPYlpo1qdM+3tcPQpWnDVb7neT7K00B\nYEURqnxX3wnn+uT4UTi2pKvZfK9UFcpSYY1KFYCVx5nUdz4v9Z8WoSo5NlUG4AFCle983uh3Wkk3\nVR40WPX4OLOpMgAPEKp8x2bKySWdeF3fkQpBeBnMV2yqDMADhCrf+bzR77QShyqOMaEKgA8IVb5j\nTlVy01SqfD/GhCoAHiBU+Y4TfnKEquQIVQA8QKjyWacd9lvy/YSfVLkiyQhVSfRDlXNZjwQA5oZQ\n5TPft0+ZVi4XBqvjtgggVIXPv9uS2vWsRwIAc0Oo8pnvG/3OIkmLgH6DVZ8lbUMBACcQocpnVKqm\nd9w5Qs6Fx9n3Y8z+fwA8QKjyGZ2+p3fcUNWuS50mx7jfC41QBWCFpRKqzOztZvYtM7srjcfDghCq\npnfcUMUxDlGpAuCBtCpVfyDpppQeC4vCCX96hKpkCFUAPFBI40Gcc39jZtem8VhYIE74iXW6Tq1O\nV4XSlvKNHdmkOwyOsb9bATXaHbnCpgLpWKHKOadGuytJyudMxTyzFCZptrvq9tpVlAs5mU38yQQw\nB6mEKpxQ
9R1Jxsq0Y2p1unreb35M36zW9UuFh/UvClW5Tke5fD7+Tp4H17d//Cv6t39+j8pq6guB\ndPeXv66nPnv8fW7+w9v
"text/plain": [
"<Figure size 720x360 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# 1D kernel - WIO layout\n",
"kernel = onp.array([[[1, 0, -1], [-1, 0, 1]], \n",
" [[1, 1, 1], [-1, -1, -1]]], \n",
" dtype=np.float32).transpose([2,1,0])\n",
"# 1D data - NWC layout\n",
"data = onp.zeros((1, 200, 2), dtype=np.float32)\n",
"for i in range(2):\n",
" for k in range(2):\n",
" x = 35*i + 30 + 60*k\n",
" data[0, x:x+30, k] = 1.0\n",
"\n",
"print(\"in shapes:\", data.shape, kernel.shape)\n",
"\n",
"plt.figure(figsize=(10,5))\n",
"plt.plot(data[0]);\n",
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n",
" ('NWC', 'WIO', 'NWC'))\n",
"print(dn)\n",
"\n",
"out = lax.conv_general_dilated(data,   # lhs = input signal tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,), # window strides\n",
" 'SAME', # padding mode\n",
" (1,), # lhs/image dilation\n",
" (1,), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape)\n",
"plt.figure(figsize=(10,5))\n",
"plt.plot(out[0]);"
]
},
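{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Note that `lax.conv_general_dilated` computes a cross-correlation: the kernel is slid over the input without being flipped, as is conventional in deep-learning libraries. The small sketch below checks this on a hypothetical toy input (5 samples, one channel) and a width-3 kernel, comparing against a hand-written NumPy reference. For a width-3 kernel at stride 1, `'SAME'` padding adds one zero on each side.\n",
"\n",
"```python\n",
"import numpy as onp\n",
"from jax import lax\n",
"\n",
"# Hypothetical toy input: batch 1, width 5, 1 channel (NWC layout).\n",
"data = onp.arange(5, dtype=onp.float32).reshape(1, 5, 1)\n",
"# Width-3 kernel with 1 input and 1 output channel (WIO layout).\n",
"kernel = onp.array([1., 2., 3.], dtype=onp.float32).reshape(3, 1, 1)\n",
"\n",
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape, ('NWC', 'WIO', 'NWC'))\n",
"out = lax.conv_general_dilated(data, kernel, (1,), 'SAME', (1,), (1,), dn)\n",
"\n",
"# NumPy reference: zero-pad one element per side, then slide the kernel\n",
"# over the signal WITHOUT flipping it (i.e. a cross-correlation).\n",
"padded = onp.pad(data[0, :, 0], 1)\n",
"ref = onp.array([padded[w:w + 3] @ kernel[:, 0, 0] for w in range(5)])\n",
"\n",
"print(ref)                              # [ 3.  8. 14. 20. 11.]\n",
"print(onp.allclose(out[0, :, 0], ref))  # True\n",
"```\n",
"\n",
"Reversing the kernel along its spatial axis (`kernel[::-1]`) before the call would turn this into a true convolution."
]
},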
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7XOgXqCTmaPa"
},
"source": [
"### 3D Convolutions"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 530
},
"colab_type": "code",
"id": "QNvSiq5-mcLd",
"outputId": "eecbad0f-f443-43c1-83d6-f8fba22c7383"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)\n",
"ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))\n",
"out shape: (1, 30, 30, 30, 1)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvcnLbWle7/l52tXt5m1PFxEZkZFN\niVkXi4K6WJC3qEFiUYLinagoWorgKMdOFMmBIjh3Ig6cJE5qpCNBKP8AKfRSpqhpNhFx2rfZ7eqe\ntgbrPZGNV0zNNE5ExvrA4exz1l6btfZmfddvfX/NI3LOzMzMzMx8MMhXfQAzMzMzHydm0Z2ZmZn5\nAJlFd2ZmZuYDZBbdmZmZmQ+QWXRnZmZmPkD0v7J9Lm2YmZmZ+bcj/qUNc6Q7MzMz8wEyi+7MzMzM\nB8gsujOvlM997nP85V/+5as+jJmZDwzxr3SkzZ7uzEeeX/mVX+H111/nd37nd171ocx8fJg93ZmZ\nmZkPA7PozrxS3nrrLf7iL/6CL33pS/zsz/4sv/zLv8xyueRzn/scf/VXf/Ud7/u93/s9fvRHf5TT\n01N+9Vd/lWEYAPjjP/5jPv/5z3/H5woh+OpXv8of/uEf8uUvf5nf//3fZ7FY8FM/9VMf6PnNzHw3\ns+jOfGj40z/9U37+53+e7XbLT//0T/PFL37xO7Z/+ctf5s///M/5p3/6J/7hH/7he7ILfv3Xf51f\n/MVf5Dd+4zc4Ho/82Z/92X/U4c/MfE/MojvzoeHzn/88P/mTP4lSil/6pV/ib/7mb75j+xe/+EXe\neOMNzs7O+M3f/E3+5E/+5BUd6czMv59ZdGc+NDx48OD913VdMwwDIYT3/++NN954//Wbb77JkydP\nPtDjm5n5QTCL7sxHhnfffff91++88w6PHj0CoGkauq57f9uzZ8++Yz8h/sVE8szMB84sujMfGf7g\nD/6A9957j9vbW373d3+Xn/u5nwPgx37sx/jbv/1b/vqv/5phGPjSl770Hfvdv3+fr33ta6/giGdm\n/jmz6M58ZPiFX/gFfuInfoK3336bT33qU/zWb/0WAJ/97Gf57d/+bb7whS/wmc985p9VMvzar/0a\nX/nKVzg5OeFnfuZnXsWhz8y8z9wcMfOR4K233uKP/uiP+MIXvvCqD2Vm5nthbo6YmZmZ+TAwi+7M\nzMzMB8hsL8zMzMz84JnthZmZmZkPA7PozszMzHyAzKI7MzMz8wEyi+7MzMzMB8gsujMzMzMfILPo\nzszMzHyA/GtLsM/M/JsYQ6DzHiEEjTEYpQDIOc+DZ2ZmmOt0Z34A5JxxMeJioHUeqyehjSmxLEpu\nxpYhOpSQPKhXlMrgUyDmiJEaJdQrPoOZmR84/2KEMYvuzL+bnDMpZ/bjgI+JvXOQEudNjUuR3Thw\nCANLa1naipAiPkWWVvN0uCbnjBKKtxePKJShDVtSTtR6iZXlqz69mZnvh1l0Z36wjCFwGEc673Ap\ncVHXHMeR7dDjUmTMHoXgNvaclzX3ygUbt+NZf4sXI59evkapLGN0+DhQ6p5j2CIyKGn49PJ/QuSe\ng/saQihW9jNYtX7Vpz0z870yi+7M90/Omd57hhA4upFlUTDEwK4fKJVmE3q+dnvNNgycFjUPqgVe\neG7GAwM9hRRYrRlSy3mx4txWXLlvcN0/Y2kkn17+J5QwtOEFMvcU4golCjIJISSvN/8n3v+/OP93\nSLlmUf0faPXg7tgGoJh945kPC7Pozvz7iSkRUuLoRkJKxJzZ9D0nZQUi85XrZ1x3HVJArS1KwZAD\nVimu/Q3P+z1RdJwVax6WZ3T5Gc+H5xSqY2UuWWiJ4ECtVizVgTE+wYWnrOwp96v/jCDS+7+npqMQ\nEaUfkVMHInNS/1di93+T0y1CnmGa/wupXyfH
a8gdyAuErF/1Vzjz8WMW3Zl/H2MIbIeBkCKbYeCy\nrpFC8OSwZzsO9NERQmZMjkMaOdU1o+j5x8MLutRSSMvD+hSlB3rfIUTHkLe4NGBlR63PeFiccgz/\nDRefslSg5QMurQYOlLKkFs/Iqcfm51j9Jovyx8nxihS+QiUSWn8Wqe6R0xZyxhT/C4z/DyBBlIjm\n10HdA/93QAD9SYQ8f8Xf7MwPObPoznzv5Jw5jCNDCOyGnpNqEtrr9ohPkUjmvcOOTd8RiTS64qwu\n+erxim/sr8kEam0R0iOkZF0W7MIzbv01mpZC1bze3EOJG4awQ4trFBkrByrZY9QlZ0rgw19TiANW\nGpR8wJkSxLSjlAqLQyAocUj9CG3+ZwhfhfANlKyQ9sdBNuS4JSMQcgnx8d2lUCIXXyQB3v0VIDDF\nf0aph6/2i5/5YWIW3Zl/nZASIUUOw0gkYYTiRddilabSmr+/veZZeyDkyFJb1mXBxre86Dr2Y88x\nDex8i1SCR/UKJT3f7J/Qpi1WKRaqYG0cSQQMkSyfQT5SyR6rGh6V98nxq4S0oRYtRhWspEOLiBYr\nKrGH9IxGJBAFSt6nkZKQbrEopLAoLFJKhLgg67eJ4SuI+AKhTlH2v4CsCOE9kijx2SOEJpMRSKrF\nF+nCY9rx/0PKkpPqf6fQrwGQckTOpW0z3zuz6M78y6Sc8XGyD3JO3LQdi6JgZS2PD3seHw/4GAHQ\nUjDg8TGSgW+0Gx7vb0FmVrZkZQ2HfORmOBBFT0yRwowoJbi0J3hesAvPMLKllIqFrjhTA54WIw6U\nsqOko1IeLRasVU1K72AZKIXDyIq1kqQMAosUR2QeqIQGoVHiDVI+EvIGLQSaBqnWQAKxYhRLXHxC\nTgeEOqcs/jeysHT+azhWBAqsfkTOIymPXCx+gWfDP3HwzzCi5I3mx1na+4QU8NmjhcZI80p/v5kP\nJbPozvxzcs7sxoE+BDZ9x8IUrMqCq7blMA5kMrf9yGEcGbPHIDlvKh53B/5xc8UYPUYqjFYYDS55\nrvo9t2GHEz1GwaltWBvBLl7T5wOFHDFScGI8WioqJUnpBVLsqMVAqWAhCmrd49OeUg5URBZqRAGG\nEiU1Ph+phKMASlVSUeAJxJTIwqPIGAFCWKR4kzY9xecjEkkhGoS6T8KRWbDLFSG1uDSi9eucVv8r\nPktux3+kz5eU5jXW9hE+9bjY8lrzX/hG+x4xR4QQfHrxGU7tKX1w+ByxUlOqWYg/5syiO/MtXIwM\n3nN0jpQTi2ISWhcjjTY8b1ve3W/xIWCV5mLRMMbAk8OOq74lkzn4EaGh0oqc4clww5NhT86BwmjW\n1qCUp0s9fTqA6DHSUyjBmS4RqsOlDVoMLJSjUoJTHUk5oERCscfKgVIEGhWppUGkHp89lYzUItGo\nQEKh0QxZQ/YY4bBSsJBLMpI+OhwehUSLRBIJJWqS+ARX/jEueaQwnOiKpD5LF4+EXHHMDZDockWj\n3+BR/T9yiIGn3Xt4TvhE82lWZoVPniH2PCzf5sVwfP9Ke7O55MTWtN4RcqZUilLPQvwxYhbdjzs5\nZ0JKuBjZDQNGSW77npQzF3XDTdfy9c0tLkSkFBRaY6ViMw6EGLkdep50O0JI1NZwWpUcfM+77YFb\ndyThkUqglWBhNX0euXUbhtxRqkhjDSdaEhkZ8h4lBkrlsSrQKEEtIeWWlEYK7TmTLY2GQkKIHiEC\nhQhY4VEis5YRJSxtCkQSSxFYqYgmE5CQBYdcIIgIApUULOQJ++Rpk6NPioUUIAwuK4yo6MUjrtwz\nxixRwnJpT4jqR9m6HT4X+LygUgWeMxbmjDfrT7D1jnePLyjkKT+yfoOlKSdffBy4X50xxogSgpgy\nl3VDYyyDD6ScMEpR6Hn8yQ8ps+h+nMk5sx0Geu/ZjwMCuFws2A8DL45HQkoMzuNSIjC19i6N4egc\n/7i5Ztf3
FEZTGkNpNYN3PO+OXI8tbfQgEuvSUmvDs3HD9XAgKY+1UApBYyRKRg6pI3MEAkvjWRuN\nlZIxHkiMrHVHpQNaQCU
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvVnMZVtd7v0b3WxX8663qW5X7QYQ\nUOCwP/X4fR/iufiUCzWgwYgK2EW8IeHKLiqJQY0GNeiVRsVAJAE0JhpNNBoU7KLGaDjm6AE3u6ld\ne1f7NqudzWi/i1nUkdCogJtmz19SSb1rrDnWqnrXfNacz3j+/yFSSoyMjIyMPDPIz/UbGBkZGXk2\nMYruyMjIyDPIKLojIyMjzyCj6I6MjIw8g4yiOzIyMvIMov+d8THaMDIyMvKfR3yygfFKd2RkZOQZ\nZBTdkZGRkWeQUXRHRkZGnkFG0R3h9a9/PRcvXmQ2m/H85z+ft7/97ffGPvCBDyClZDKZMJlMuHz5\nMq95zWv4+7//+8/hO/6v5cEHH+R973vf5+18I1/YjKI7wo/+6I/yxBNPsF6v+f3f/33e/OY38w//\n8A/3xi9dusR2u2Wz2fC3f/u3vPCFL+RrvuZr+NM//dPP4bseGfnCZBTdEV70oheR5zkAQgiEEDz6\n6KMf9zwhBJcvX+Ynf/InecMb3sCP/MiPfNI5/+qv/oqXvexl7O3tceXKFd75zncCsFqt+K7v+i6O\njo544IEH+Omf/mlijAC8853v5OUvfzk/+IM/yGKx4KGHHuKP/uiPAPit3/otvvIrv/JjXuMXf/EX\nedWrXvUJX//69eu86lWvYn9/n+c973n8+q//+r2x7/me7+HNb37zvZ8/8IEPcPnyZQC+8zu/kyef\nfJJXvvKVTCYTfu7nfo4nnngCIQS/9mu/xqVLl7h48SK/8Au/8GnPN/LsZhTdEQDe+MY3UlUVL3zh\nC7l48SLf8A3f8Cmf/+pXv5p//Md/ZLfbfdzY1atX+fqv/3re9KY3cefOHT74wQ/y8MMPA/CmN72J\n1WrFY489xp//+Z/zm7/5m7zjHe+4d+zf/d3f8YIXvIDj42N++Id/mO/7vu8jpcQrX/lKPvzhD/PI\nI4/ce+673/1uXvva137C9/ft3/7tXL58mevXr/M7v/M7/NiP/Rh/9md/9u/+P7zrXe/i/vvv5w/+\n4A/Ybrf88A//8L2x97///TzyyCP8yZ/8CW9961v/Q5bBp5pv5NnJKLojAPzyL/8ym82Gv/zLv+TV\nr371vSvfT8alS5dIKbFcLj9u7N3vfjdf93Vfx3d8x3dgjOHg4ICHH36YEALvfe97+dmf/Vmm0ykP\nPvggP/ADP8C73vWue8c+8MADfP/3fz9KKb77u7+bGzducOvWLaqq4pu+6Zt4z3veA8AjjzzChz70\noU94pXvt2jX++q//mre+9a0URcHDDz/MG97wBn7zN3/zM/o/+omf+AnquuYlL3kJ3/u933vvvYyM\n/GcYRXfkHkopXv7yl/PUU0/xK7/yK5/yuU8//TRCCPb29j5u7Nq1azz3uc/9uMePj49xzvHAAw/c\ne+yBBx7g6aefvvfzhQsX7v29qioAttstAK997WvvCd273/1uvvmbv/nec/4t169fZ39/n+l0+klf\n59PhypUrHzPf9evXP6P5Rp6djKI78nF47z+hp/tv+d3f/V2+/Mu/nLquP27sypUrn/D4w8NDjDFc\nvXr13mNPPvkk991333/ofb3iFa+4Z1e85z3v+aTWwqVLlzg9PWWz2XzC16nrmqZp7o3dvHnzY44X\n4hMXE127du1j5rt06dJnNN/Is5NRdJ/l3L59m/e+971st1tCCPzxH/8x73nPe/jar/3aj3tuSomn\nn36at7zlLbz97W/nZ37mZz7hnK973et43/ve
x2//9m/jvefk5IQPfvCDKKV4zWtew4//+I+z2Wy4\nevUqb3vb23j961//H3qvxhi+9Vu/lR/6oR/i9PSUV7ziFZ/weVeuXOFlL3sZP/qjP0rXdfzTP/0T\nv/Ebv3HvdR5++GH+8A//kNPTU27evMkv/dIvfczx58+f57HHHvu4eX/qp36Kpmn453/+Z97xjnfw\nbd/2bZ/RfCPPUlJKn+rPyBc5t2/fTv/jf/yPNJ/P03Q6TS9+8YvTr/3ar90bf//735+EEKmu61RV\nVbp48WL6lm/5lvQ3f/M3n3Lev/iLv0hf9VVflabTabp8+XJ65zvfmVJK6fT0NL3uda9Lh4eH6fLl\ny+ktb3lLCiGklFJ6xzvekb76q7/6Y+YB0iOPPPIx8wLpjW9846d8/WvXrqVv/MZvTIvFIj3nOc9J\nv/Irv3JvrG3b9JrXvCZNp9P0kpe8JL3tbW9L9913373x3/u930tXrlxJ8/k8/fzP/3x6/PHHE5B+\n9Vd/NV28eDGdP38+vfWtb/205xt5VvBJdVWkT71dz9jwZuRZzxNPPMFDDz2Ecw6t/70eUSMjwNjw\nZmRkZOTzg1F0R0ZGRp5BRnthZGRk5LPPaC+MjIyMfD4wiu7IyMjIM8gouiMjIyPPIKPojoyMjDyD\njKI7MjIy8gwyiu7IyMjIM8hYXjPyWaV1lq1zAMzynFwNH7GU0tj4ZWSEMac78lkgpoQNns57VrYl\n1wYS+BQ4LCZ03mODRwnJLM8xShFTIqaEFAI5ivHIFx+f9EM9iu7Ip01KiZASp11DHwLLbkdCcL4e\n+thue0tKiVmeU2hNiBEfI3WWsbMWAIFgXhRoKXExklJCS4mSo/M18gXNKLojn11aZznrO7a2p4ue\ni/WMdddx2u4QQlDctRUIgkVdURlNjJFV2w/Nz6sCJSUhRkJMaKXovUcIgQT2ypKUEjYEAAqtRyEe\n+UJiFN2Rz5yYEjtnab3jrG1YFBWNdxx3OyYqw6fAv54c40NinufMioKjYoL3HoQg12r4QCWo84xM\nKVwIrLqOQhtmZY4QAus9keFTK4W4d8xeUdCHQH/XqqizDH1XiEfPeOTzjFF0Rz59Qoy4GFj2HT4G\nfEjcbrccljVGKv738S02XY9Wkggc5SVGqsEmQLHrO6TUVJlhXuakmOitR0pJpiUI0FKTaYWS4GNk\n1ffM85zCGIQQ9N4TY7on3iFGUoJZUbCzPT5GlJTM8sGqCDGSAHV3d+ORkWeYUXRHPj067zlpd7gQ\nudPsuDidIJFcWy/Z2h6BZOstpEh0CZUp5llO3wRcCtQ642hSkitNEoOHSxJsbIcSktwo9oqSTd8T\nXMQYhVISqSRaCpSQCAEuRna9ZVGVGK1IKdFaRwRyrTBK4WMkxkRlDJu7nrEUgkVZooS4Z1WMnvHI\nM8AouiP/cWJKrPuOxjlOuob9vERJya3dBp8iuVQ8uVqx6nuMkAghuDyZsrWOJ87uIIRinhVIJbiQ\nT9mrCgiJXe8QCiZFzl5WEBGkFIhx2EesC54kErXJyLXmpGkRgNYSpSVKChQSJQQg6IIjJpjmGUpJ\nUkpsOgsiMckGq8LdFVopBC5GICGFZFGWwPClAoNnrEchHvnsMYruyL+PCwEfI8u+JcYhRXB9t6FQ\nmonOeHR1yrX1ilxKEolLkzkpws3NBusdAjhrexZ5ztF0Qq4U9BIbeso8o1CaIjeURpMLQ+8dLkZk\ngllVUBYZgUjwAZE0SOijQyCos5xE4HTXQ0ooJSgyQ7ybdiCBkNB6j0RQGoOUghgTq64lN5o6G7aV\nt96TBMMxd62HRGK/rHAh0HqHFJLaGIxSw/joGY/85xhFd+ST89Gc7WnbQoKbzYZZljPPcq6ul9zc\nbZER2uSptEYlRevdcMvuAk+tzpBCMSszMimZZSWEiEQQIkx1zqTImNUFzkesc/QuMi0MRmlKZUgK\nZEogBd4m
jJFUVYZRis4OwhtTRN31c6WUZEJhg2fV9QAoKZkWGTvnkAgkgiSGvLASklzrewt1W2up\nM3PPM+7ccMWbSGQfzRG
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# Fixed 3D kernel - HWDIO layout\n",
"kernel = onp.array([\n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
" dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]\n",
"\n",
"# 3D data - NHWDC layout\n",
"data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)\n",
"x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n",
"data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]\n",
"\n",
"print(\"in shapes:\", data.shape, kernel.shape)\n",
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n",
" ('NHWDC', 'HWDIO', 'NHWDC'))\n",
"print(dn)\n",
"\n",
"out = lax.conv_general_dilated(data, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,1,1), # window strides\n",
" 'SAME', # padding mode\n",
" (1,1,1), # lhs/image dilation\n",
" (1,1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers\n",
"print(\"out shape: \", out.shape)\n",
"\n",
"# Make some simple 3d density plots:\n",
"from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib\n",
"def make_alpha(cmap):\n",
"  # RGBA lookup table whose alpha ramps from transparent (low values) to opaque.\n",
"  my_cmap = cmap(onp.arange(cmap.N))\n",
"  my_cmap[:,-1] = onp.linspace(0, 1, cmap.N)**3\n",
"  return mpl.colors.ListedColormap(my_cmap)\n",
"my_cmap = make_alpha(plt.cm.viridis)\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)\n",
"ax.axis('off')\n",
"ax.set_title('input')\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)\n",
"ax.axis('off')\n",
"ax.set_title('3D conv output');"
]
},
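{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Because `dimension_numbers` only records where each logical axis lives, the same 3D convolution can be expressed in a different layout by transposing the operands to match. The sketch below (the random toy data and the channels-first strings `'NCHWD'`/`'OIHWD'` are our own illustrative choices) computes one convolution in the `NHWDC`/`HWDIO` layout used above and again in a channels-first layout, then compares the results.\n",
"\n",
"```python\n",
"import numpy as onp\n",
"from jax import lax\n",
"\n",
"rng = onp.random.RandomState(0)\n",
"data = rng.randn(1, 8, 8, 8, 1).astype(onp.float32)    # NHWDC\n",
"kernel = rng.randn(3, 3, 3, 1, 1).astype(onp.float32)  # HWDIO\n",
"\n",
"def conv3d(lhs, rhs, specs):\n",
"  # specs = (lhs layout, rhs layout, output layout) strings\n",
"  dn = lax.conv_dimension_numbers(lhs.shape, rhs.shape, specs)\n",
"  return lax.conv_general_dilated(lhs, rhs, (1, 1, 1), 'SAME',\n",
"                                  (1, 1, 1), (1, 1, 1), dn)\n",
"\n",
"out_last = conv3d(data, kernel, ('NHWDC', 'HWDIO', 'NHWDC'))\n",
"\n",
"# Same computation, channels-first: NHWDC -> NCHWD, HWDIO -> OIHWD.\n",
"out_first = conv3d(data.transpose(0, 4, 1, 2, 3),\n",
"                   kernel.transpose(4, 3, 0, 1, 2),\n",
"                   ('NCHWD', 'OIHWD', 'NCHWD'))\n",
"# Move the channel axis back to the end (NCHWD -> NHWDC) before comparing.\n",
"out_first = onp.asarray(out_first).transpose(0, 2, 3, 4, 1)\n",
"\n",
"print(onp.allclose(out_last, out_first, atol=1e-4))  # expected: True\n",
"```\n",
"\n",
"In this sketch the spatial letters `H`, `W` and `D` appear in all three strings and only their positions change; `lax.conv_dimension_numbers` pairs the spatial axes of the operands by letter."
]
},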
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DKTMw6tRZyK2"
},
"source": [
"## 🔪 NaNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ncS0NI4jZrwy"
},
"source": [
"### Debugging NaNs\n",
"\n",
"If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:\n",
"- setting the `JAX_DEBUG_NANS=True` environment variable.\n",
"- adding `from jax.config import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file.\n",
"- adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`.\n",
"\n",
"This will cause computations to error out immediately when a NaN is produced.\n",
"\n",
"⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!\n"
]
},
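{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A minimal sketch of the NaN-checker in action, assuming a JAX version where the options object is reachable as `jax.config` (older releases use `from jax.config import config` as in the list above). The flag is switched back off afterwards so the rest of the session is unaffected.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as np\n",
"\n",
"# Assumes jax.config.update is available (older JAX: from jax.config import config).\n",
"jax.config.update('jax_debug_nans', True)\n",
"try:\n",
"  np.log(-1.)  # log of a negative number is NaN\n",
"  caught = False\n",
"except FloatingPointError as err:\n",
"  caught = True\n",
"  print('NaN-checker raised:', err)\n",
"finally:\n",
"  # Turn the checker off again; it slows every computation down.\n",
"  jax.config.update('jax_debug_nans', False)\n",
"```\n",
"\n",
"Without the flag, the same call simply returns `nan`."
]
},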
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YTktlwTTMgFl"
},
"source": [
"## Double (64bit) precision\n",
"\n",
"At the moment, JAX by default enforces single-precision numbers to mitigate the NumPy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise: as shown below, explicitly requesting `float64` silently gives you `float32`!"
]
},
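{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The same silent narrowing applies to NumPy inputs: a `float64` NumPy array passed through `jax.numpy` comes back as `float32` unless x64 mode is on. A small sketch; it reads the current setting from `jax.config.jax_enable_x64`, which we assume your JAX version exposes.\n",
"\n",
"```python\n",
"import jax\n",
"import numpy as onp\n",
"import jax.numpy as np\n",
"\n",
"x64 = onp.linspace(0., 1., 5)    # NumPy defaults to float64\n",
"y = np.asarray(x64)              # converted to a JAX array\n",
"print(x64.dtype, '->', y.dtype)  # float64 -> float32 in the default mode\n",
"# Assumed attribute: reports whether x64 mode is currently enabled.\n",
"print('x64 mode:', jax.config.jax_enable_x64)\n",
"```"
]
},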
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "CNNGtzM3NDkO",
"outputId": "211d9880-4518-4a7d-f652-e3663274825f"
},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float32')"
]
},
"execution_count": 14,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n",
"x.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VcvqzobxNPbd"
},
"source": [
"To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__. \n",
"\n",
"There are a few ways to do this:\n",
"\n",
"1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n",
"\n",
"2. You can manually set the `jax_enable_x64` configuration flag at startup:\n",
"\n",
"```\n",
"# again, this only works on startup!\n",
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"```\n",
"\n",
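"3. If you parse command-line flags with `absl.app.run(main)`, call `config.config_with_absl()` so that JAX's options are exposed as absl flags (e.g. `--jax_enable_x64=True`):\n",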
"\n",
"```\n",
"from jax.config import config\n",
"config.config_with_absl()\n",
"```\n",
"\n",
"4. If you want JAX to run absl parsing for you (i.e. you don't call `absl.app.run(main)` yourself), you can instead use:\n",
"\n",
"```\n",
"from jax.config import config\n",
"if __name__ == '__main__':\n",
" # calls config.config_with_absl() *and* runs absl parsing\n",
" config.parse_flags_with_absl()\n",
"```\n",
"\n",
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",
"\n",
"We can then confirm that `x64` mode is enabled:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "HqGbBa9Rr-2g"
},
"outputs": [],
"source": [
"from jax import numpy as np, random\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n",
"x.dtype # --> dtype('float64') once x64 mode was enabled at startup"
]
},
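{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Option 1 can be checked from Python by starting a fresh interpreter with the environment variable already set, which sidesteps the requirement that the flag be in place before JAX starts up. A sketch, assuming JAX is importable by the current interpreter (`sys.executable`); the helper name `dtype_in_child` is hypothetical.\n",
"\n",
"```python\n",
"import os\n",
"import subprocess\n",
"import sys\n",
"\n",
"# Child process: request a float64 array and print the dtype JAX actually used.\n",
"code = 'import jax.numpy as np; print(np.zeros(3, dtype=np.float64).dtype)'\n",
"\n",
"def dtype_in_child(enable_x64):\n",
"  # Hypothetical helper: runs `code` with JAX_ENABLE_X64 set before JAX is imported.\n",
"  env = dict(os.environ, JAX_ENABLE_X64='True' if enable_x64 else 'False')\n",
"  result = subprocess.run([sys.executable, '-c', code], env=env,\n",
"                          capture_output=True, text=True, check=True)\n",
"  return result.stdout.strip()\n",
"\n",
"print(dtype_in_child(False))  # float32 (JAX may also print a warning to stderr)\n",
"print(dtype_in_child(True))   # float64\n",
"```"
]
},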
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6Cks2_gKsXaW"
},
"source": [
"### Caveats\n",
"⚠️ XLA doesn't support 64-bit convolutions on all backends!"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WAHjmL0E2XwO"
},
"source": [
"## Fin.\n",
"\n",
"If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory _advisos_!"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Common Gotchas in JAX",
"provenance": [],
"toc_visible": true,
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}