mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 03:56:08 +00:00

Clarify that the index is clamped to the bounds of the array when accessing out of bounds.
2808 lines
379 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "hjM_sV_AepYf"
|
|
},
|
|
"source": [
|
|
"# 🔪 JAX - The Sharp Bits 🔪"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "4k5PVzEo2uJO"
|
|
},
|
|
"source": [
|
|
"*levskaya@ mattjj@*\n",
|
|
"\n",
|
|
"When walking about the countryside of [Italy](https://iaml.it/blog/jax-intro), the people will not hesitate to tell you that __JAX__ has _\"una anima di pura programmazione funzionale\"_ (\"a soul of pure functional programming\").\n",
|
|
"\n",
|
|
"__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU). \n",
|
|
"JAX works great for many numerical and scientific programs, but __only if they are written with certain constraints__ that we describe below."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "GoK_PCxPeYcy"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"from jax import grad, jit\n",
|
|
"from jax import lax\n",
|
|
"from jax import random\n",
|
|
"import jax\n",
|
|
"import jax.numpy as jnp\n",
|
|
"import matplotlib as mpl\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"from matplotlib import rcParams\n",
|
|
"rcParams['image.interpolation'] = 'nearest'\n",
|
|
"rcParams['image.cmap'] = 'viridis'\n",
|
|
"rcParams['axes.grid'] = False"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "cxwbr3XK2_mK"
|
|
},
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "gX8CZU1g2agP"
|
|
},
|
|
"source": [
|
|
"## 🔪 Pure functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "2oHigBkW2dPT"
|
|
},
|
|
"source": [
|
|
"JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed in through the function parameters, and all the results are returned through the function's return value. A pure function always returns the same result when invoked with the same inputs.\n",
|
|
"\n",
|
|
"Here are some examples of functions that are not functionally pure, for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 121,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 102
|
|
},
|
|
"colab_type": "code",
|
|
"id": "A6R-pdcm4u3v",
|
|
"outputId": "389605df-a4d5-4d4b-8d74-64e9d5d39456"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Executing function\n",
|
|
"First call: 4.0\n",
|
|
"Second call: 5.0\n",
|
|
"Executing function\n",
|
|
"Third call, different type: [5.]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def impure_print_side_effect(x):\n",
|
|
" print(\"Executing function\") # This is a side-effect \n",
|
|
" return x\n",
|
|
"\n",
|
|
"# The side-effects appear during the first run \n",
|
|
"print (\"First call: \", jit(impure_print_side_effect)(4.))\n",
|
|
"\n",
|
|
"# Subsequent runs with parameters of same type and shape may not show the side-effect\n",
|
|
"# This is because JAX now invokes a cached compilation of the function\n",
|
|
"print (\"Second call: \", jit(impure_print_side_effect)(5.))\n",
|
|
"\n",
|
|
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
|
|
"print (\"Third call, different type: \", jit(impure_print_side_effect)(jnp.array([5.])))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 122,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
},
|
|
"colab_type": "code",
|
|
"id": "-N8GhitI2bhD",
|
|
"outputId": "f16ce914-1387-43b4-9b8a-1d6e3b97b11d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"First call: 4.0\n",
|
|
"Second call: 5.0\n",
|
|
"Third call, different type: [14.]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"g = 0.\n",
|
|
"def impure_uses_globals(x):\n",
|
|
" return x + g\n",
|
|
"\n",
|
|
"# JAX captures the value of the global during the first run\n",
|
|
"print (\"First call: \", jit(impure_uses_globals)(4.))\n",
|
|
"g = 10. # Update the global\n",
|
|
"\n",
|
|
"# Subsequent runs may silently use the cached value of the globals\n",
|
|
"print (\"Second call: \", jit(impure_uses_globals)(5.))\n",
|
|
"\n",
|
|
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
|
|
"# This will end up reading the latest value of the global\n",
|
|
"print (\"Third call, different type: \", jit(impure_uses_globals)(jnp.array([4.])))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 123,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 51
|
|
},
|
|
"colab_type": "code",
|
|
"id": "RTB6iFgu4DL6",
|
|
"outputId": "e93d2a70-1c18-477a-d69d-d09ed556305a"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"First call: 4.0\n",
|
|
"Saved global: Traced<ShapedArray(float32[], weak_type=True):JaxprTrace(level=-1/1)>\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"g = 0.\n",
|
|
"def impure_saves_global(x):\n",
|
|
" global g\n",
|
|
" g = x\n",
|
|
" return x\n",
|
|
"\n",
|
|
"# JAX runs the transformed function once, with special Traced values as arguments\n",
|
|
"print (\"First call: \", jit(impure_saves_global)(4.))\n",
|
|
"print (\"Saved global: \", g) # Saved global has an internal JAX value"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Mlc2pQlp6v-9"
|
|
},
|
|
"source": [
|
|
"A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "TP-Mqf_862C0",
|
|
"outputId": "78df2d95-2c6f-41c9-84a9-feda6329e75e"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def pure_uses_internal_state(x):\n",
|
|
" state = dict(even=0, odd=0)\n",
|
|
" for i in range(10):\n",
|
|
" state['even' if i % 2 == 0 else 'odd'] += x\n",
|
|
" return state['even'] + state['odd']\n",
|
|
"\n",
|
|
"print(jit(pure_uses_internal_state)(5.))"
|
|
]
|
|
},
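{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conversely, an impure function such as `impure_uses_globals` above can usually be made pure by passing the value it reads as an explicit argument. In the minimal sketch below, `pure_uses_arg` is a hypothetical name chosen for illustration. Because `g` is now an input, the `jit`-compiled function uses the value passed on each call rather than a value captured while tracing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# Hypothetical pure rewrite of `impure_uses_globals`: `g` is an explicit input, not a global\n",
"def pure_uses_arg(x, g):\n",
"  return x + g\n",
"\n",
"f = jax.jit(pure_uses_arg)\n",
"print(f(4., 0.))   # g = 0.\n",
"print(f(5., 10.))  # the new value of g is used; same shapes and dtypes, so no retrace"
]
},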
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a Python object that introduces state in order to retrieve the next element, which makes it incompatible with JAX's functional programming model. The code below shows some incorrect attempts to use iterators with JAX. Most of them raise an error, but some silently give unexpected results."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"45\n",
|
|
"0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/Users/igor/projects/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.\n",
|
|
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import jax.numpy as jnp\n",
|
|
"import jax.lax as lax\n",
|
|
"from jax import make_jaxpr\n",
|
|
"\n",
|
|
"# lax.fori_loop\n",
|
|
"array = jnp.arange(10)\n",
|
|
"print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45\n",
|
|
"iterator = iter(range(10))\n",
|
|
"print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0\n",
|
|
"\n",
|
|
"# lax.scan\n",
|
|
"def func11(arr, extra):\n",
|
|
" ones = jnp.ones(arr.shape) \n",
|
|
" def body(carry, aelems):\n",
|
|
" ae1, ae2 = aelems\n",
|
|
" return (carry + ae1 * ae2 + extra, carry)\n",
|
|
" return lax.scan(body, 0., (arr, ones)) \n",
|
|
"make_jaxpr(func11)(jnp.arange(16), 5.)\n",
|
|
"# make_jaxpr(func11)(iter(range(16)), 5.) # throws error\n",
|
|
"\n",
|
|
"# lax.cond\n",
|
|
"array_operand = jnp.array([0.])\n",
|
|
"lax.cond(True, array_operand, lambda x: x+1, array_operand, lambda x: x-1)\n",
|
|
"iter_operand = iter(range(10))\n",
|
|
"# lax.cond(True, iter_operand, lambda x: next(x)+1, iter_operand, lambda x: next(x)-1) # throws error"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "oBdKtkVW8Lha"
|
|
},
|
|
"source": [
|
|
"## 🔪 In-Place Updates"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "JffAqnEW4JEb"
|
|
},
|
|
"source": [
|
|
"In NumPy you're used to doing this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 125,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 153
|
|
},
|
|
"colab_type": "code",
|
|
"id": "om4xV7_84N9j",
|
|
"outputId": "733f901e-d433-4dc8-b5bb-0c23bf2b1306"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"original array:\n",
|
|
"[[0. 0. 0.]\n",
|
|
" [0. 0. 0.]\n",
|
|
" [0. 0. 0.]]\n",
|
|
"updated array:\n",
|
|
"[[0. 0. 0.]\n",
|
|
" [1. 1. 1.]\n",
|
|
" [0. 0. 0.]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"numpy_array = np.zeros((3,3), dtype=np.float32)\n",
|
|
"print(\"original array:\")\n",
|
|
"print(numpy_array)\n",
|
|
"\n",
|
|
"# In place, mutating update\n",
|
|
"numpy_array[1, :] = 1.0\n",
|
|
"print(\"updated array:\")\n",
|
|
"print(numpy_array)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "go3L4x3w4-9p"
|
|
},
|
|
"source": [
|
|
"If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 126,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 54
|
|
},
|
|
"colab_type": "code",
|
|
"id": "2AxeCufq4wAp",
|
|
"outputId": "d5d873db-cee0-49dc-981d-ec852347f7ca",
|
|
"tags": [
|
|
"raises-exception"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Exception '<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n",
|
|
"\n",
|
|
"# In place update of JAX's array will yield an error!\n",
|
|
"try:\n",
|
|
" jax_array[1, :] = 1.0\n",
|
|
"except Exception as e:\n",
|
|
" print(\"Exception {}\".format(e))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "7mo76sS25Wco"
|
|
},
|
|
"source": [
|
|
"__What gives?!__ \n",
|
|
"\n",
|
|
"Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. \n",
|
|
"\n",
|
|
"Instead, JAX offers the _functional_ update functions: [__index_update__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), [__index_add__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add), [__index_min__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_min.html#jax.ops.index_min), [__index_max__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_max.html#jax.ops.index_max), and the [__index__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html#jax.ops.index) helper.\n",
|
|
"\n",
|
|
"⚠️ Inside `jit`-compiled code and in `lax.while_loop` or `lax.fori_loop`, the __size__ of slices can't be a function of argument _values_, only of argument _shapes_; the slice start indices have no such restriction. See the __Control Flow__ section below for more information on this limitation."
|
|
]
|
|
},
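{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the distinction concrete, here is a minimal sketch using `lax.dynamic_slice` (the helper name `window` is hypothetical). The start index may be a traced value, while the slice size is a static Python tuple, so the shape of the result is known at trace time:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import lax\n",
"\n",
"@jax.jit\n",
"def window(x, start):\n",
"  # `start` is traced: its value is unknown while tracing, which is allowed.\n",
"  # The slice size (3,) is a static Python tuple, so the output shape is fixed.\n",
"  return lax.dynamic_slice(x, (start,), (3,))\n",
"\n",
"print(window(jnp.arange(10), 4))  # [4 5 6]"
]
},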
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "m5lg1RYq5D9p"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax.ops import index, index_add, index_update"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "X2Xjjvd-l8NL"
|
|
},
|
|
"source": [
|
|
"### index_update"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "eM6MyndXL2NY"
|
|
},
|
|
"source": [
|
|
"If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 128,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 221
|
|
},
|
|
"colab_type": "code",
|
|
"id": "ygUJT49b7BBk",
|
|
"outputId": "1a3511c4-a480-472f-cccb-5e01620cbe99"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"original array:\n",
|
|
"[[0. 0. 0.]\n",
|
|
" [0. 0. 0.]\n",
|
|
" [0. 0. 0.]]\n",
|
|
"old array unchanged:\n",
|
|
"[[0. 0. 0.]\n",
|
|
" [0. 0. 0.]\n",
|
|
" [0. 0. 0.]]\n",
|
|
"new array:\n",
|
|
"[[0. 0. 0.]\n",
|
|
" [1. 1. 1.]\n",
|
|
" [0. 0. 0.]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax_array = jnp.zeros((3, 3))\n",
|
|
"print(\"original array:\")\n",
|
|
"print(jax_array)\n",
|
|
"\n",
|
|
"new_jax_array = index_update(jax_array, index[1, :], 1.)\n",
|
|
"\n",
|
|
"print(\"old array unchanged:\")\n",
|
|
"print(jax_array)\n",
|
|
"\n",
|
|
"print(\"new array:\")\n",
|
|
"print(new_jax_array)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "7to-sF8EmC_y"
|
|
},
|
|
"source": [
|
|
"### index_add"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "iI5cLY1xMBLs"
|
|
},
|
|
"source": [
|
|
"If the __input values__ of __index_add__ aren't reused, __jit__-compiled code will perform these operations _in-place_."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 129,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 221
|
|
},
|
|
"colab_type": "code",
|
|
"id": "tsw2svao8FUp",
|
|
"outputId": "874acd15-a493-4d63-efe4-9f440d5d2a12"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"original array:\n",
|
|
"[[1. 1. 1. 1. 1. 1.]\n",
|
|
" [1. 1. 1. 1. 1. 1.]\n",
|
|
" [1. 1. 1. 1. 1. 1.]\n",
|
|
" [1. 1. 1. 1. 1. 1.]\n",
|
|
" [1. 1. 1. 1. 1. 1.]]\n",
|
|
"new array post-addition:\n",
|
|
"[[1. 1. 1. 8. 8. 8.]\n",
|
|
" [1. 1. 1. 1. 1. 1.]\n",
|
|
" [1. 1. 1. 8. 8. 8.]\n",
|
|
" [1. 1. 1. 1. 1. 1.]\n",
|
|
" [1. 1. 1. 8. 8. 8.]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"original array:\")\n",
|
|
"jax_array = jnp.ones((5, 6))\n",
|
|
"print(jax_array)\n",
|
|
"\n",
|
|
"new_jax_array = index_add(jax_array, index[::2, 3:], 7.)\n",
|
|
"print(\"new array post-addition:\")\n",
|
|
"print(new_jax_array)"
|
|
]
|
|
},
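{
"cell_type": "markdown",
"metadata": {},
"source": [
"Later JAX releases removed the `jax.ops.index_*` functions in favor of the `.at` indexed-update property on arrays. If the `jax.ops` import above fails in your environment, the two updates shown in this section can be written as in the following sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"# Same update as index_update(jax_array, index[1, :], 1.)\n",
"updated = jnp.zeros((3, 3)).at[1, :].set(1.)\n",
"print(updated)\n",
"\n",
"# Same update as index_add(jax_array, index[::2, 3:], 7.)\n",
"added = jnp.ones((5, 6)).at[::2, 3:].add(7.)\n",
"print(added)"
]
},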
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "oZ_jE2WAypdL"
|
|
},
|
|
"source": [
|
|
"## 🔪 Out-of-Bounds Indexing"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "btRFwEVzypdN"
|
|
},
|
|
"source": [
|
|
"In NumPy, you are used to errors being thrown when you index an array outside of its bounds, like this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 130,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "5_ZM-BJUypdO",
|
|
"outputId": "461f38cd-9452-4bcc-a44f-a07ddfa12f42",
|
|
"tags": [
|
|
"raises-exception"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Exception index 11 is out of bounds for axis 0 with size 10\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"try:\n",
|
|
" np.arange(10)[11]\n",
|
|
"except Exception as e:\n",
|
|
" print(\"Exception {}\".format(e))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "eoXrGARWypdR"
|
|
},
|
|
"source": [
|
|
"However, raising an error from code running on an accelerator (GPU/TPU) can be difficult. Therefore, JAX does not raise an error; instead, the index is clamped to the bounds of the array, so in this example the last value of the array is returned."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 131,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "cusaAD0NypdR",
|
|
"outputId": "48428ad6-6cde-43ad-c12d-2eb9b9fe59cf"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray(9, dtype=int32)"
|
|
]
|
|
},
|
|
"execution_count": 131,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jnp.arange(10)[11]"
|
|
]
|
|
},
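{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same clamping applies when the index is only known at run time, for example when it is a traced argument inside `jit`. In the sketch below, `take` and `clamped_take` are hypothetical helpers. The second one emulates the clamping rule in plain NumPy for non-negative indices, to make the semantics explicit:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"@jax.jit\n",
"def take(x, i):\n",
"  # `i` is traced, so no bounds check can happen at trace time;\n",
"  # an out-of-range index is clamped to the nearest valid one\n",
"  return x[i]\n",
"\n",
"print(take(jnp.arange(10), 11))  # 9, as in the static-index example above\n",
"\n",
"# Hypothetical NumPy emulation of the clamping rule (non-negative indices only)\n",
"def clamped_take(a, i):\n",
"  return a[np.clip(i, 0, a.shape[0] - 1)]\n",
"\n",
"print(clamped_take(np.arange(10), 11))  # 9"
]
},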
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Note that, for the same reason, `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting entirely of NaNs, whereas NumPy raises an error."
|
|
]
|
|
},
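{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch comparing the two behaviors on an all-NaN input:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"all_nan = np.array([np.nan, np.nan], dtype=np.float32)\n",
"print(jnp.nanargmin(jnp.asarray(all_nan)))  # -1 instead of an error\n",
"\n",
"try:\n",
"  np.nanargmin(all_nan)\n",
"except ValueError as e:\n",
"  print('NumPy raised ValueError:', e)"
]
},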
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "MUycRNh6e50W"
|
|
},
|
|
"source": [
|
|
"## 🔪 Random Numbers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "O8vvaVt3MRG2"
|
|
},
|
|
"source": [
|
|
"> _If all scientific papers whose results are in doubt because of bad \n",
|
|
"> `rand()`s were to disappear from library shelves, there would be a \n",
|
|
"> gap on each shelf about as big as your fist._ - Numerical Recipes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Qikt9pPW9L5K"
|
|
},
|
|
"source": [
|
|
"### RNGs and State\n",
|
|
"You're used to _stateful_ pseudorandom number generators (PRNGs) from NumPy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 132,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
},
|
|
"colab_type": "code",
|
|
"id": "rr9FeP41fynt",
|
|
"outputId": "849d84cf-04ad-4e8b-9505-a92f6c0d7a39"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.08960303423860538\n",
|
|
"0.6720478073539145\n",
|
|
"0.24536720985284477\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(np.random.random())\n",
|
|
"print(np.random.random())\n",
|
|
"print(np.random.random())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "ORMVVGZJgSVi"
|
|
},
|
|
"source": [
|
|
"Under the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "7Pyp2ajzfPO2"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"np.random.seed(0)\n",
|
|
"rng_state = np.random.get_state()\n",
|
|
"#print(rng_state)\n",
|
|
"# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
|
|
"# 2481403966, 4042607538, 337614300, ... 614 more numbers..., \n",
|
|
"# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "aJIxHVXCiM6m"
|
|
},
|
|
"source": [
|
|
"This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "GAHaDCYafpAF"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"_ = np.random.uniform()\n",
|
|
"rng_state = np.random.get_state()\n",
|
|
"#print(rng_state) \n",
|
|
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
|
|
"# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
|
|
"\n",
|
|
"# Let's exhaust the entropy in this PRNG statevector\n",
|
|
"for i in range(311):\n",
|
|
" _ = np.random.uniform()\n",
|
|
"rng_state = np.random.get_state()\n",
|
|
"#print(rng_state) \n",
|
|
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
|
|
"# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
|
|
"\n",
|
|
"# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
|
|
"_ = np.random.uniform()\n",
|
|
"rng_state = np.random.get_state()\n",
|
|
"# print(rng_state) \n",
|
|
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484, \n",
|
|
"# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "N_mWnleNogps"
|
|
},
|
|
"source": [
|
|
"The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n",
|
|
"\n",
|
|
"The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems: it has a large 2.5 kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928); it [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests; and it is generally slow."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Uvq7nV-j4vKK"
|
|
},
|
|
"source": [
|
|
"### JAX PRNG"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "COjzGBpO4tzL"
|
|
},
|
|
"source": [
|
|
"\n",
|
|
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/master/design_notes/prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
|
|
"\n",
|
|
"The random state is described by two unsigned 32-bit integers that we call a __key__:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 135,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "yPHE7KTWgAWs",
|
|
"outputId": "329e7757-2461-434c-a08c-fde80a2d10c9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray([0, 0], dtype=uint32)"
|
|
]
|
|
},
|
|
"execution_count": 135,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from jax import random\n",
|
|
"key = random.PRNGKey(0)\n",
|
|
"key"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "XjYyWYNfq0hW"
|
|
},
|
|
"source": [
|
|
"JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! \n",
|
|
"\n",
|
|
"Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 136,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 85
|
|
},
|
|
"colab_type": "code",
|
|
"id": "7zUdQMynoE5e",
|
|
"outputId": "50617324-b887-42f2-a7ff-2a10f92d876a"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[-0.20584226]\n",
|
|
"[0 0]\n",
|
|
"[-0.20584226]\n",
|
|
"[0 0]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(random.normal(key, shape=(1,)))\n",
|
|
"print(key)\n",
|
|
"# No no no!\n",
|
|
"print(random.normal(key, shape=(1,)))\n",
|
|
"print(key)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "hQN9van8rJgd"
|
|
},
|
|
"source": [
|
|
"Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 137,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
},
|
|
"colab_type": "code",
|
|
"id": "ASj0_rSzqgGh",
|
|
"outputId": "bcc2ed60-2e41-4ef8-e84f-c724654aa198"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"old key [0 0]\n",
|
|
" \\---SPLIT --> new key [4146024105 967050713]\n",
|
|
" \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"old key\", key)\n",
|
|
"key, subkey = random.split(key)\n",
|
|
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
|
|
"print(\" \\---SPLIT --> new key \", key)\n",
|
|
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "tqtFVE4MthO3"
|
|
},
|
|
"source": [
|
|
"We propagate the __key__ and make new __subkeys__ whenever we need a new random number:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 138,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
},
|
|
"colab_type": "code",
|
|
"id": "jbC34XLor2Ek",
|
|
"outputId": "6834a812-7160-4646-ee19-a246f683905a"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"old key [4146024105 967050713]\n",
|
|
" \\---SPLIT --> new key [2384771982 3928867769]\n",
|
|
" \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"old key\", key)\n",
|
|
"key, subkey = random.split(key)\n",
|
|
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
|
|
"print(\" \\---SPLIT --> new key \", key)\n",
|
|
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "0KLYUluz3lN3"
|
|
},
|
|
"source": [
|
|
"We can generate more than one __subkey__ at a time:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 139,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
},
|
|
"colab_type": "code",
|
|
"id": "lEi08PJ4tfkX",
|
|
"outputId": "3bb513de-8d14-4d37-ae57-51d6f5eaa762"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[-0.37533438]\n",
|
|
"[0.98645043]\n",
|
|
"[0.14553197]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"key, *subkeys = random.split(key, 4)\n",
|
|
"for subkey in subkeys:\n",
|
|
" print(random.normal(subkey, shape=(1,)))"
|
|
]
|
|
},
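{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because keys can be split, drawing many independent samples in parallel is straightforward. The following is a minimal sketch that uses `jax.vmap` to map a sampler over a batch of subkeys. This pattern is one illustration, not the only way to do it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"\n",
"key = random.PRNGKey(0)\n",
"# Fork one key into three subkeys: an array of shape (3, 2) for these uint32 keys\n",
"subkeys = random.split(key, 3)\n",
"\n",
"# Draw one standard normal per subkey, vectorized over the leading axis\n",
"sample = lambda k: random.normal(k)\n",
"samples = jax.vmap(sample)(subkeys)\n",
"print(samples.shape)  # (3,)\n",
"\n",
"# Same subkeys in, same numbers out: the PRNG state is explicit, not hidden\n",
"print(jnp.array_equal(samples, jax.vmap(sample)(subkeys)))  # True"
]
},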
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "rg4CpMZ8c3ri"
|
|
},
|
|
"source": [
|
|
"## 🔪 Control Flow"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "izLTvT24dAq0"
|
|
},
|
|
"source": [
|
|
"### ✔ Python control flow + autodiff ✔\n",
|
|
"\n",
|
|
"If you just want to apply `grad` to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or PyTorch or TF Eager)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 140,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 51
|
|
},
|
|
"colab_type": "code",
|
|
"id": "aAx0T3F8lLtu",
|
|
"outputId": "808cfa77-d924-4586-af19-35a8fd7d2238"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"12.0\n",
|
|
"-4.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" if x < 3:\n",
|
|
" return 3. * x ** 2\n",
|
|
" else:\n",
|
|
" return -4 * x\n",
|
|
"\n",
|
|
"print(grad(f)(2.)) # ok!\n",
|
|
"print(grad(f)(4.)) # ok!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "hIfPT7WMmZ2H"
|
|
},
|
|
"source": [
|
|
"### Python control flow + JIT\n",
|
|
"\n",
|
|
"Using control flow with `jit` is more complicated, and by default it has more constraints.\n",
|
|
"\n",
|
|
"This works:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 141,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "OZ_BJX0CplNC",
|
|
"outputId": "48ce004c-536a-44f5-b020-9267825e7e4d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"24\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" for i in range(3):\n",
|
|
" x = 2 * x\n",
|
|
" return x\n",
|
|
"\n",
|
|
"print(f(3))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "22RzeJ4QqAuX"
|
|
},
|
|
"source": [
|
|
"So does this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 142,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "pinVnmRWp6w6",
|
|
"outputId": "e3e6f2f7-ba59-4a98-cdfc-905c91b38ed1"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"6.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"@jit\n",
|
|
"def g(x):\n",
|
|
" y = 0.\n",
|
|
" for i in range(x.shape[0]):\n",
|
|
" y = y + x[i]\n",
|
|
" return y\n",
|
|
"\n",
|
|
"print(g(jnp.array([1., 2., 3.])))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "TStltU2dqf8A"
|
|
},
|
|
"source": [
|
|
"But this doesn't, at least by default:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 143,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 54
|
|
},
|
|
"colab_type": "code",
|
|
"id": "9z38AIKclRNM",
|
|
"outputId": "466730dd-df8b-4b80-ac5e-e55b5ea85ec7"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Exception Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" if x < 3:\n",
|
|
" return 3. * x ** 2\n",
|
|
" else:\n",
|
|
" return -4 * x\n",
|
|
"\n",
|
|
"# This will fail!\n",
|
|
"try:\n",
|
|
" f(2)\n",
|
|
"except Exception as e:\n",
|
|
" print(\"Exception {}\".format(e))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "pIbr4TVPqtDN"
|
|
},
|
|
"source": [
|
|
"__What gives!?__\n",
|
|
"\n",
|
|
"When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n",
|
|
"\n",
|
|
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
|
|
"\n",
|
|
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/master/jax/abstract_arrays.py), and different transformations use different abstraction levels.\n",
|
|
"\n",
|
|
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
|
|
"\n",
|
|
"But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
|
|
"\n",
|
|
"The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, the `static_argnums` argument to `jit` tells it to trace on the concrete values of the listed arguments, recompiling whenever those values change. Here's that example function again:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 144,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "-Tzp0H7Bt1Sn",
|
|
"outputId": "aba57a88-d8eb-40b0-ff22-7c266d892b13"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"12.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" if x < 3:\n",
|
|
" return 3. * x ** 2\n",
|
|
" else:\n",
|
|
" return -4 * x\n",
|
|
"\n",
|
|
"f = jit(f, static_argnums=(0,))\n",
|
|
"\n",
|
|
"print(f(2.))"
|
|
]
|
|
},
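  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this particular function, a common alternative to `static_argnums` is to express the branch with `jnp.where`. The sketch below assumes a reasonably recent JAX release and uses a swapped-in technique, not the only option: both branches are computed and `jnp.where` selects between them, so tracing never needs a concrete Python `bool` and a single compiled version serves every float scalar input. When only one branch should actually run, `lax.cond` (covered below) is the structured alternative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.jit\n",
    "def f_where(x):\n",
    "  # x < 3 stays a traced boolean; jnp.where picks between the two\n",
    "  # (already computed) branches without converting it to a Python bool\n",
    "  return jnp.where(x < 3, 3. * x ** 2, -4. * x)\n",
    "\n",
    "print(f_where(2.))  # first branch\n",
    "print(f_where(5.))  # same compiled code, second branch"
   ]
  },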
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "MHm1hIQAvBVs"
|
|
},
|
|
"source": [
|
|
"Here's another example, this time involving a loop:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 145,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "iwY86_JKvD6b",
|
|
"outputId": "1ec847ea-df2b-438d-c0a1-fabf7b93b73d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray(5., dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 145,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def f(x, n):\n",
|
|
" y = 0.\n",
|
|
" for i in range(n):\n",
|
|
" y = y + x[i]\n",
|
|
" return y\n",
|
|
"\n",
|
|
"f = jit(f, static_argnums=(1,))\n",
|
|
"\n",
|
|
"f(jnp.array([2., 3., 4.]), 2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "nSPTOX8DvOeO"
|
|
},
|
|
"source": [
|
|
"In effect, the loop is statically unrolled: with `n = 2`, the traced computation is simply two additions. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation."
|
|
]
|
|
},
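  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To look at the `ShapedArray` level of abstraction directly, one option is `jax.eval_shape`, which traces a function on shape/dtype specifications alone and returns the shape and dtype of the result without computing any values. The sketch below repeats the loop-summing function `g` from above under the illustrative name `g_sum`; its Python loop traces fine because `x.shape[0]` is concrete even though the values of `x` are not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def g_sum(x):\n",
    "  y = 0.\n",
    "  for i in range(x.shape[0]):  # the shape is concrete during tracing\n",
    "    y = y + x[i]\n",
    "  return y\n",
    "\n",
    "# An abstract argument: a shape and a dtype, but no values\n",
    "spec = jax.ShapeDtypeStruct((3,), jnp.float32)\n",
    "out = jax.eval_shape(g_sum, spec)\n",
    "print(out.shape, out.dtype)"
   ]
  },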
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "wWdg8LTYwCW3"
|
|
},
|
|
"source": [
|
|
"⚠️ **functions with argument-__value__ dependent shapes**\n",
|
|
"\n",
|
|
"These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is fine). As a trivial example, let's make a function whose output shape depends on the input variable `length`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 146,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 122
|
|
},
|
|
"colab_type": "code",
|
|
"id": "Tqe9uLmUI_Gv",
|
|
"outputId": "fe319758-9959-434c-ab9d-0926e599dbc0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[4. 4. 4. 4. 4.]\n",
|
|
"Exception Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>,).\n",
|
|
"If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n",
|
|
"[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n",
|
|
"[4. 4. 4. 4. 4.]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def example_fun(length, val):\n",
|
|
" return jnp.ones((length,)) * val\n",
|
|
"# un-jit'd works fine\n",
|
|
"print(example_fun(5, 4))\n",
|
|
"\n",
|
|
"bad_example_jit = jit(example_fun)\n",
|
|
"# this will fail:\n",
|
|
"try:\n",
|
|
" print(bad_example_jit(10, 4))\n",
|
|
"except Exception as e:\n",
|
|
" print(\"Exception {}\".format(e))\n",
|
|
"# static_argnums tells JAX to recompile on changes at these argument positions:\n",
|
|
"good_example_jit = jit(example_fun, static_argnums=(0,))\n",
|
|
"# first compile\n",
|
|
"print(good_example_jit(10, 4))\n",
|
|
"# recompiles\n",
|
|
"print(good_example_jit(5, 4))"
|
|
]
|
|
},
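  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see recompilation on each new static value in practice, here is a small sketch that counts how often JAX traces a function. The counter `n_traces` and the names `counted_fun`/`counted_jit` are illustrative, not part of JAX. The counter is a Python side effect, and such side effects run only while JAX traces the function, so it increments once per compilation rather than once per call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_traces = 0  # illustrative Python-side counter\n",
    "\n",
    "def counted_fun(length, val):\n",
    "  global n_traces\n",
    "  n_traces += 1  # executes only when JAX (re)traces the function\n",
    "  return jnp.ones((length,)) * val\n",
    "\n",
    "counted_jit = jax.jit(counted_fun, static_argnums=(0,))\n",
    "counted_jit(10, 4)  # traces and compiles for length=10\n",
    "counted_jit(10, 7)  # cache hit: same static length, val is traced\n",
    "counted_jit(5, 4)   # new static value, so JAX traces again\n",
    "print(n_traces)     # 2"
   ]
  },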
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "MStx_r2oKxpp"
|
|
},
|
|
"source": [
|
|
"`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot, since every distinct value triggers a new trace and compilation.\n",
|
|
"\n",
|
|
"Lastly, if your function has global side effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions; what gets printed is the abstract tracer, not the value:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 147,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
},
|
|
"colab_type": "code",
|
|
"id": "m2ABpRd8K094",
|
|
"outputId": "64da37a0-aa06-46a3-e975-88c676c5b9fa"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>\n",
|
|
"Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray(4, dtype=int32)"
|
|
]
|
|
},
|
|
"execution_count": 147,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"@jit\n",
|
|
"def f(x):\n",
|
|
" print(x)\n",
|
|
" y = 2 * x\n",
|
|
" print(y)\n",
|
|
" return y\n",
|
|
"f(2)"
|
|
]
|
|
},
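  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `print` is an ordinary Python side effect, it runs while JAX traces the function, not each time the compiled code executes. A minimal sketch of the consequence, using an illustrative list `trace_log`: calling the function again with an argument of the same shape and dtype reuses the compiled code, so the side effect does not happen a second time. (Recent JAX releases provide `jax.debug.print` for printing runtime values from inside `jit`.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "trace_log = []  # illustrative: records Python-side effects\n",
    "\n",
    "@jax.jit\n",
    "def double(x):\n",
    "  trace_log.append('traced')  # runs during tracing only\n",
    "  return 2 * x\n",
    "\n",
    "double(2)\n",
    "double(3)              # same abstract value (int32 scalar), so no retrace\n",
    "print(len(trace_log))  # 1\n",
    "print(double(3))       # the computation itself still runs: 6"
   ]
  },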
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "uCDcWG4MnVn-"
|
|
},
|
|
"source": [
|
|
"### Structured control flow primitives\n",
|
|
"\n",
|
|
"JAX offers more options for control flow. If you want to avoid recompilation but still need traceable control flow that doesn't unroll large loops, you can use these four structured control flow primitives:\n",
|
|
"\n",
|
|
" - `lax.cond` _differentiable_\n",
|
|
" - `lax.while_loop` __fwd-mode-differentiable__\n",
|
|
" - `lax.fori_loop` __fwd-mode-differentiable__\n",
|
|
" - `lax.scan` _differentiable_\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Sd9xrLMXeK3A"
|
|
},
|
|
"source": [
|
|
"#### cond\n",
|
|
"Python equivalent:\n",
|
|
"\n",
|
|
"```\n",
|
|
"def cond(pred, true_operand, true_fun, false_operand, false_fun):\n",
|
|
" if pred:\n",
|
|
" return true_fun(true_operand)\n",
|
|
" else:\n",
|
|
" return false_fun(false_operand)\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 148,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "SGxz9JOWeiyH",
|
|
"outputId": "b29da06c-037f-4b05-dbd8-ba52ac35a8cf"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray([-1.], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 148,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from jax import lax\n",
|
|
"\n",
|
|
"operand = jnp.array([0.])\n",
|
|
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
|
"# --> array([1.], dtype=float32)\n",
|
|
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
|
"# --> array([-1.], dtype=float32)"
|
|
]
|
|
},
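  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `lax.cond` takes the predicate as a traced value, it also resolves the `if x < 3` failure from earlier without `static_argnums`. The sketch below assumes the `lax.cond(pred, true_fun, false_fun, operand)` calling convention used by current JAX releases, which passes the operand to whichever branch runs; the five-argument form shown above is the older signature. Both branches must return values of the same shape and dtype, hence the float literals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import lax\n",
    "\n",
    "@jax.jit\n",
    "def f_cond(x):\n",
    "  # The predicate x < 3 stays traced; lax.cond stages out both branches\n",
    "  # and selects one at run time.\n",
    "  return lax.cond(x < 3,\n",
    "                  lambda v: 3. * v ** 2,  # branch for x < 3\n",
    "                  lambda v: -4. * v,      # branch otherwise\n",
    "                  x)\n",
    "\n",
    "print(f_cond(2.))\n",
    "print(f_cond(5.))"
   ]
  },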
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "xkOFAw24eOMg"
|
|
},
|
|
"source": [
|
|
"#### while_loop\n",
|
|
"\n",
|
|
"Python equivalent:\n",
|
|
"```\n",
|
|
"def while_loop(cond_fun, body_fun, init_val):\n",
|
|
" val = init_val\n",
|
|
" while cond_fun(val):\n",
|
|
" val = body_fun(val)\n",
|
|
" return val\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 149,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "jM-D39a-c436",
|
|
"outputId": "b9c97167-fecf-4559-9ca7-1cb0235d8ad2"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray(10, dtype=int32)"
|
|
]
|
|
},
|
|
"execution_count": 149,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"init_val = 0\n",
|
|
"cond_fun = lambda x: x<10\n",
|
|
"body_fun = lambda x: x+1\n",
|
|
"lax.while_loop(cond_fun, body_fun, init_val)\n",
|
|
"# --> array(10, dtype=int32)"
|
|
]
|
|
},
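  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike a Python `while` loop, `lax.while_loop` can run under `jit` with a loop condition that depends on traced argument values. A small sketch (the function `next_pow2` is illustrative): find the smallest power of two that is at least `n`. The carried value must keep the same shape and dtype on every iteration, which is why the initial value is an explicit `int32`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import lax\n",
    "\n",
    "@jax.jit\n",
    "def next_pow2(n):\n",
    "  # The trip count depends on the traced value of n, so it cannot be unrolled\n",
    "  return lax.while_loop(lambda p: p < n,  # cond_fun\n",
    "                        lambda p: p * 2,  # body_fun\n",
    "                        jnp.int32(1))     # init_val\n",
    "\n",
    "print(next_pow2(100))  # 128\n",
    "print(next_pow2(1))    # 1"
   ]
  },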
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "apo3n3HAeQY_"
|
|
},
|
|
"source": [
|
|
"#### fori_loop\n",
|
|
"Python equivalent:\n",
|
|
"```\n",
|
|
"def fori_loop(start, stop, body_fun, init_val):\n",
|
|
" val = init_val\n",
|
|
" for i in range(start, stop):\n",
|
|
" val = body_fun(i, val)\n",
|
|
" return val\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 150,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "dt3tUpOmeR8u",
|
|
"outputId": "864f2959-2429-4666-b364-4baf90a57482"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray(45, dtype=int32)"
|
|
]
|
|
},
|
|
"execution_count": 150,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"init_val = 0\n",
|
|
"start = 0\n",
|
|
"stop = 10\n",
|
|
"body_fun = lambda i,x: x+i\n",
|
|
"lax.fori_loop(start, stop, body_fun, init_val)\n",
|
|
"# --> array(45, dtype=int32)"
|
|
]
|
|
},
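  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### scan\n",
    "\n",
    "The list above also includes `lax.scan`, which threads a carry through a loop like `fori_loop` but also stacks a per-iteration output. Its semantics are roughly the following (a Python sketch of the documented behavior, not the implementation):\n",
    "\n",
    "```\n",
    "def scan(f, init, xs):\n",
    "  carry = init\n",
    "  ys = []\n",
    "  for x in xs:\n",
    "    carry, y = f(carry, x)\n",
    "    ys.append(y)\n",
    "  return carry, np.stack(ys)\n",
    "```\n",
    "\n",
    "A minimal sketch with illustrative function names: a running sum, followed by a gradient taken through a scan, since `lax.scan` is differentiable in both modes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import lax\n",
    "\n",
    "def cumsum_step(carry, x):\n",
    "  new_carry = carry + x\n",
    "  return new_carry, new_carry  # (next carry, output for this step)\n",
    "\n",
    "xs = jnp.array([1., 2., 3.])\n",
    "total, running = lax.scan(cumsum_step, jnp.float32(0.), xs)\n",
    "print(total, running)  # final carry 6.0, running sums 1, 3, 6\n",
    "\n",
    "def sum_of_squares(v):\n",
    "  # Returning None as the per-step output stacks nothing\n",
    "  final, _ = lax.scan(lambda c, x: (c + x ** 2, None), jnp.float32(0.), v)\n",
    "  return final\n",
    "\n",
    "print(jax.grad(sum_of_squares)(xs))  # gradient is 2 * xs"
   ]
  },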
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "SipXS5qiqk8e"
|
|
},
|
|
"source": [
|
|
"#### Summary\n",
|
|
"\n",
|
|
"$$\n",
|
|
"\\begin{array} {r|rr} \n",
|
|
"\\hline \\\n",
|
|
"\\textrm{construct} \n",
|
|
"& \\textrm{jit} \n",
|
|
"& \\textrm{grad} \\\\\n",
|
|
"\\hline \\\n",
|
|
"\\textrm{if} & ❌ & ✔ \\\\\n",
|
|
"\\textrm{for} & ✔* & ✔\\\\\n",
|
|
"\\textrm{while} & ✔* & ✔\\\\\n",
|
|
"\\textrm{lax.cond} & ✔ & ✔\\\\\n",
|
|
"\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n",
|
|
"\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n",
|
|
"\\textrm{lax.scan} & ✔ & ✔\\\\\n",
|
|
"\\hline\n",
|
|
"\\end{array}\n",
|
|
"$$\n",
|
|
"<center>$\\ast$ = the loop condition must not depend on argument __values__; the loop is unrolled</center>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "bxuUjFVG-v1h"
|
|
},
|
|
"source": [
|
|
"## 🔪 Convolutions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "0pcn2LeS-03b"
|
|
},
|
|
"source": [
|
|
"JAX and XLA offer the very general N-dimensional __conv_general_dilated__ function, but it's not obvious how to use it. We'll give some examples of common use cases.\n",
|
|
"\n",
|
|
"For the most common kinds of convolutions, see also the convenience functions `lax.conv` and `lax.conv_with_general_padding`, as well as `jax.numpy.convolve` and `jax.scipy.signal.convolve`/`jax.scipy.signal.convolve2d` for an interface similar to that of the numpy and scipy packages.\n",
|
|
"\n",
|
|
"For a survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285) is highly recommended reading!\n",
|
|
"\n",
|
|
"Let's define a simple diagonal edge kernel:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 151,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 286
|
|
},
|
|
"colab_type": "code",
|
|
"id": "Yud1Y3ss-x1K",
|
|
"outputId": "5aacee92-2769-4f10-d9a6-475cded80981"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Edge Conv kernel:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ8AAAD8CAYAAABpXiE9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAN7klEQVR4nO3df6yeZX3H8fdnLWAmTCol0pQqP6Nz\nbgY8QZTFNEMTJIYukSXwh4LRdDjJlGgy1AQTk2XqHy5jGkkDRFgMNoKB41JDYMBwWYpUUiiFIIW4\ntLUTLK7IdLKy7/44N+bxcH71eu7zPM/R9yt58lz3fV/nvr692nx6/2xTVUjSkfq9cRcgaWUyPCQ1\nMTwkNTE8JDUxPCQ1MTwkNRkqPJK8NsldSZ7svtfM0++lJDu7z/QwY0qaDBnmOY8kXwKeq6ovJLka\nWFNVfzNHvxeq6tgh6pQ0YYYNjyeAjVV1IMk64L6qeuMc/QwP6bfMsOHxX1V1fNcO8LOXl2f1Owzs\nBA4DX6iq2+fZ32ZgM8Crfz9ve9MZRzfXJu362YnjLmHivbh330+rqmmiVi/WIcndwElzbPrs4EJV\nVZL5kugNVbU/yWnAPUl2VdVTsztV1RZgC8DUW19V379zw6K/AGk+p2+9YtwlTLwffeJT/9H6s4uG\nR1W9e75tSX6SZN3Aacsz8+xjf/f9dJL7gLOAV4SHpJVj2Fu108BlXfsy4I7ZHZKsSXJM114LnAc8\nNuS4ksZs2PD4AvCeJE8C7+6WSTKV5Pquzx8CO5I8DNzLzDUPw0Na4RY9bVlIVR0Ezp9j/Q7gI137\n34E/HmYcSZPHJ0wlNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1\nMTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUx\nPCQ16SU8klyQ5Ikke5JcPcf2Y5Js7bY/kOSUPsaVND5Dh0eSVcBXgfcCbwYuTfLmWd0+DPysqs4A\n/h744rDjShqvPo48zgH2VNXTVfUi8E1g06w+m4CbuvatwPlJ0sPYksakj/BYD+wdWN7XrZuzT1Ud\nBg4BJ/QwtqQxmagLpkk2J9mRZMezB18adzmSFtBHeOwHNgwsn9ytm7NPktXAa4CDs3dUVVuqaqqq\npk48YVUPpUlaLn2Ex4PAmUlOTXI0cAkwPavPNHBZ174YuKeqqoexJY3J6mF3UFWHk1wJ3AmsAm6s\nqt1JPg/sqKpp4Abgn5LsAZ5jJmAkrWBDhwdAVW0Dts1ad81A+3+Av+hjLEmTYaIumEpaOQwPSU0M\nD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwP\nSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTXoJjyQXJHkiyZ4k\nV8+x/fIkzybZ2X0+0se4ksZn9bA7SLIK+CrwHmAf8GCS6ap6bFbXrVV15bDjSZoMfRx5nAPsqaqn\nq+pF4JvAph72K2mCDX3kAawH9g4s7wPePke/9yd5F/BD4Kqq2ju7Q5LNwGaA16/vo7TfXqdvvWLc\nJUy8M67aPu4SJt6PhvjZUV0w/Q5wSlX9CXAXcNNcnapqS1VNVdXUiSesGlFpklr0ER77gQ0Dyyd3\n636tqg5W1a+6xeuBt/UwrqQx6iM8HgTOTHJqkqOBS4DpwQ5J1g0sXgQ83sO4ksZo6AsLVXU4yZXA\nncAq4Maq2p3k88COqpoG/jrJRcBh4Dng8mHHlTRevVyVrKptwLZZ664ZaH8a+HQfY0maDD5hKqmJ\n4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGp
ieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnh\nIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqUkv4ZHkxiTPJHl0\nnu1Jcm2SPUkeSXJ2H+NKGp++jjy+DlywwPb3Amd2n83A13oaV9KY9BIeVXU/8NwCXTYBN9eM7cDx\nSdb1Mbak8RjVNY/1wN6B5X3dut+QZHOSHUl2PHvwpRGVJqnFRF0wraotVTVVVVMnnrBq3OVIWsCo\nwmM/sGFg+eRunaQValThMQ18sLvrci5wqKoOjGhsSctgdR87SXILsBFYm2Qf8DngKICqug7YBlwI\n7AF+AXyoj3EljU8v4VFVly6yvYCP9TGWpMkwURdMJa0choekJoaHpCaGh6QmhoekJoaHpCaGh6Qm\nhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaG\nh6QmhoekJoaHpCaGh6QmhoekJoaHpCa9hEeSG5M8k+TRebZvTHIoyc7uc00f40oan17+o2vg68BX\ngJsX6PO9qnpfT+NJGrNejjyq6n7guT72JWll6OvIYynekeRh4MfAp6pq9+wOSTYDmwFWrVnD6Vuv\nGGF5K8sZV20fdwn6HTeqC6YPAW+oqrcC/wjcPlenqtpSVVNVNbXq2FePqDRJLUYSHlX1fFW90LW3\nAUclWTuKsSUtj5GER5KTkqRrn9ONe3AUY0taHr1c80hyC7ARWJtkH/A54CiAqroOuBj4aJLDwC+B\nS6qq+hhb0nj0Eh5Vdeki27/CzK1cSb8lfMJUUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1IT\nw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPD\nQ1ITw0NSE8NDUhPDQ1ITw0NSk6HDI8mGJPcmeSzJ7iQfn6NPklybZE+SR5KcPey4ksarj//o+jDw\nyap6KMlxwA+S3FVVjw30eS9wZvd5O/C17lvSCjX0kUdVHaiqh7r2z4HHgfWzum0Cbq4Z24Hjk6wb\ndmxJ49PrNY8kpwBnAQ/M2rQe2DuwvI9XBoykFaS38EhyLHAb8Imqer5xH5uT7Eiy46UX/ruv0iQt\ng17CI8lRzATHN6rq23N02Q9sGFg+uVv3G6pqS1VNVdXUqmNf3UdpkpZJH3dbAtwAPF5VX56n2zTw\nwe6uy7nAoao6MOzYksanj7st5wEfAHYl2dmt+wzweoCqug7YBlwI7AF+AXyoh3EljdHQ4VFV/wZk\nkT4FfGzYsSRNDp8wldTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ\n1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDU\nxPCQ1MTwkNRk6PBIsiHJvUkeS7I7ycfn6LMxyaEkO7vPNcOOK2m8Vvewj8PAJ6vqoSTHAT9IcldV\nPTar3/eq6n09jCdpAgx95FFVB6rqoa79c+BxYP2w+5U02VJV/e0sOQW4H3hLVT0/sH4jcBuwD/gx\n8Kmq2j3Hz28GNneLbwEe7a24fqwFfjruIgZYz8ImrR6YvJreWFXHtfxgb+GR5FjgX4G/rapvz9r2\nB8D/VdULSS4E/qGqzlxkfzuqaqqX4noyaTVZz8ImrR6YvJqGqaeXuy1JjmLmyOIbs4MDoKqer6oX\nuvY24Kgka/sYW9J49HG3JcANwONV9eV5+pzU9SPJOd24B4cdW9L49HG35TzgA8CuJDu7dZ8BXg9Q\nVdcBFwMfTXIY+CVwSS1+vrSlh9r6Nmk1Wc/CJq0emLyamuvp9YKppN8dPmEqqYnhIanJxIRHktcm\nuSvJk933
mnn6vTTwmPv0MtRxQZInkuxJcvUc249JsrXb/kD3bMuyWkJNlyd5dmBePrKMtdyY5Jkk\ncz6DkxnXdrU+kuTs5arlCGoa2esRS3xdY6RztGyvkFTVRHyALwFXd+2rgS/O0++FZaxhFfAUcBpw\nNPAw8OZZff4KuK5rXwJsXeZ5WUpNlwNfGdHv07uAs4FH59l+IfBdIMC5wAMTUNNG4J9HND/rgLO7\n9nHAD+f4/RrpHC2xpiOeo4k58gA2ATd17ZuAPx9DDecAe6rq6ap6EfhmV9egwTpvBc5/+Tb0GGsa\nmaq6H3hugS6bgJtrxnbg+CTrxlzTyNTSXtcY6RwtsaYjNknh8bqqOtC1/xN43Tz9XpVkR5LtSfoO\nmPXA3oHlfbxykn/dp6oOA4eAE3qu40hrAnh/dwh8a5INy1jPYpZa76i9I8nDSb6b5I9GMWB3SnsW\n8MCsTWObowVqgiOcoz6e81iyJHcDJ82x6bODC1VVSea7h/yGqtqf5DTgniS7quqpvmtdYb4D3FJV\nv0ryl8wcGf3ZmGuaJA8x8+fm5dcjbgcWfD1iWN3rGrcBn6iB97zGaZGajniORnrkUVXvrqq3zPG5\nA/jJy4du3fcz8+xjf/f9NHAfMynal/3A4N/aJ3fr5uyTZDXwGpb3adlFa6qqg1X1q27xeuBty1jP\nYpYyhyNVI349YrHXNRjDHC3HKySTdNoyDVzWtS8D7pjdIcmaJMd07bXMPN06+98NGcaDwJlJTk1y\nNDMXRGff0Rms82LgnuquOC2TRWuadb58ETPntOMyDXywu6NwLnBo4HR0LEb5ekQ3zoKvazDiOVpK\nTU1zNIor0Eu8InwC8C/Ak8DdwGu79VPA9V37ncAuZu447AI+vAx1XMjM1eingM926z4PXNS1XwV8\nC9gDfB84bQRzs1hNfwfs7ublXuBNy1jLLcAB4H+ZOVf/MHAFcEW3PcBXu1p3AVMjmJ/FarpyYH62\nA+9cxlr+FCjgEWBn97lwnHO0xJqOeI58PF1Sk0k6bZG0ghgekpoYHpKaGB6SmhgekpoYHpKaGB6S\nmvw/0ikHROf6cwcAAAAASUVORK5CYII=\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# 2D kernel - HWIO layout\n",
|
|
"kernel = np.zeros((3, 3, 3, 3), dtype=jnp.float32)\n",
|
|
"kernel += np.array([[1, 1, 0],\n",
|
|
" [1, 0,-1],\n",
|
|
" [0,-1,-1]])[:, :, np.newaxis, np.newaxis]\n",
|
|
"\n",
|
|
"print(\"Edge Conv kernel:\")\n",
|
|
"plt.imshow(kernel[:, :, 0, 0]);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "dITPaPdh_cMI"
|
|
},
|
|
"source": [
|
|
"And we'll make a simple synthetic image:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 152,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 286
|
|
},
|
|
"colab_type": "code",
|
|
"id": "cpbGsIGa_Qyx",
|
|
"outputId": "e27385e6-8fa2-498d-f952-7d8e04775856"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Original Image:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAANBklEQVR4nO3df6jd9X3H8edrcfWPTlCnC6J2iZIW\ntIw7K7awKnZbW5Wx6P5wkbFmrSwKBjYYDO1gle2fsdUJZa0lsmCE1R9sWENx1SyM+s9cTdrgr2qN\nNmKymEwd2q2lbeJ7f5zvXY7Xe5fcc8435977eT7gy/l+P+fH9/Px3Jff7/eck887VYWk9vzctDsg\naToMv9Qowy81yvBLjTL8UqMMv9So3sKf5KokLyTZm+TWvvYjaTTp43v+JKuA7wOfBPYDTwI3VNVz\nE9+ZpJH0deS/DNhbVS9X1U+B+4H1Pe1L0ghO6el1zwVeHdreD3x0oQcn8WeGUn9er6qz5zb2Ff7j\nSrIJ2DSt/UsNeWW+xr7CfwA4f2j7vK7t/1TVFmALeOSXpqGva/4ngXVJ1iZ5H7AB2N7TviSNoJcj\nf1UdSbIZeBRYBWytqmf72Jek0fTyVd+iO+Fpv9Sn3VV16dxGf+EnNcrwS40y/FKjDL/UKMMvNcrw\nS40y/FKjDL/UKMMvNcrwS40y/FKjDL/UKMMvNcrwS40y/FKjDL/UKMMvNcrwS40aOfxJzk/yr0me\nS/Jskj/q2m9PciDJnm65ZnLdlTQp40zgeQT4k6r6TpLTgN1JdnT33VlVXxy/e5L6MnL4q+ogcLBb\n/2GS7zGo1CNpGZjINX+SNcCvAv/eNW1O8lSSrUnOmMQ+JE3W2OFP8gvAPwF/XFVvA3cBFwIzDM4M\n7ljgeZuS7Eqya9w+SFq8sebtT/LzwDeAR6vqb+e5fw3wjar68HFex3n7pf5Mdt7+JAH+HvjecPCT\nnDP0sOuAZ0bdh6T+jPNp/68Bvw88nWRP1/Z54IYkM0AB+4CbxuqhpF5Yrkta+eY97e+rRPeSsJj/\no6S3XkhLkz/vlRpl+KVGGX6pUYZfapThlxpl+KVGGX6pUYZfapThlxpl+KVGreif9/qTXWlhHvml\nRhl+qVGGX2qU4ZcaZfilRhl+qVGGX2rU2N/zJ9kH/BA4ChypqkuTnAk8AKxhMInn9VX1X+PuS9Lk\nTOrI/4mqmhmaJPBWYGdVrQN2dtuSlpC+TvvXA9u69W3AtT3tR9KIJhH+Ah5LsjvJpq5tdVfIE+A1\nYPXcJ1muS5quSfy2/+NVdSDJLwE7kjw/fGdV1Xzz8lfVFmALOG+/NA1jH/mr6kB3exh4CLgMODRb\ntqu7PTzufiRN1ljhT/L+JKfNrgOfYlCbbzuwsXvYRuDhcfYjafLGPe1fDTw0qNnJKcDXquqbSZ4E\nHkxyI/AKcP2Y+5E0Ydbqk1a+yZbolrS8GX6pUYZfapThlxpl+KVGGX6pUYZfapThlxpl+KVGGX6p\nUYZfapThlxpl+KVGGX6pUYZfapThlxpl+KVGGX6pUSPP4ZfkQwxKcs26APhz4HTgD4H/7No/X1WP\njNxDSb2YyBx+SVYBB4CPAp8F/ruqvriI5zuHn9SfXufw+w3gpap6ZUKvJ6lnkwr/BuC+oe3NSZ5K\nsjXJGfM9wXJd0nSNfdqf5H3AfwAXV9WhJKuB1xnU8PtL4Jyq+txxXsPTfqk/vZ32Xw18p6oOAVTV\noao6WlXvAHczKN8laYmZRPhvYOiUf7ZGX+c6BuW7JC0xY5Xr6urzfRK4aaj5r5PMMDjt3zfnPklL\nhOW6pJXPcl2SjjH8UqMMv9Qowy81yvBLjTL8UqMMv9Qowy81yvBLjTL8UqPG+m2/loHF/HA6vfVC\nS5BHfqlRhl9qlOGXGmX4pUYZfqlRhl9qlOGX
GnVC4e/m3z+c5JmhtjOT7EjyYnd7RteeJF9Ksreb\nu/+SvjovaXQneuS/B7hqTtutwM6qWgfs7LZhMJX3um7ZBNw1fjclTdoJhb+qHgfenNO8HtjWrW8D\nrh1qv7cGngBOnzOdt6QlYJxr/tVVdbBbfw1Y3a2fC7w69Lj9XZumIYtY1JSJ/La/qmqx028n2cTg\nskDSFIxz5D80ezrf3R7u2g8A5w897ryu7V2qaktVXTrffOKS+jdO+LcDG7v1jcDDQ+2f6T71/xjw\n1tDlgaSloqqOuzCoxXcQ+BmDa/gbgV9k8Cn/i8C/AGd2jw3wZeAl4Gng0hN4/XJxcelt2TVf7izX\nJa18luuSdIzhlxpl+KVGGX6pUYZfapThlxpl+KVGGX6pUYZfapThlxpl+KVGGX6pUYZfapThlxpl\n+KVGGX6pUYZfapThlxp13PAvUKrrb5I835XjeijJ6V37miQ/TrKnW77aZ+clje5Ejvz38N5SXTuA\nD1fVrwDfB24buu+lqprplpsn001Jk3bc8M9XqquqHquqI93mEwzm5pe0jEzimv9zwD8Pba9N8t0k\n30py+QReX1IPxirXleTPgCPAP3RNB4EPVNUbST4CfD3JxVX19jzPtVyXNEUjH/mT/AHwW8Dv1Wzl\njaqfVNUb3fpuBoU7Pjjf8y3XJU3XSOFPchXwp8BvV9WPhtrPTrKqW78AWAe8PImOSpqs4572J7kP\nuBI4K8l+4AsMPt0/FdiRBOCJ7pP9K4C/SPIz4B3g5qp6c94XljRVluuSVj7LdUk6xvBLjTL8UqMM\nv9Qowy81yvBLjTL8UqMMv9Qowy81yvBLjTL8UqMMv9Qowy81yvBLjTL8UqMMv9Qowy81yvBLjRq1\nXNftSQ4MleW6Zui+25LsTfJCkk/31XFJ4xm1XBfAnUNluR4BSHIRsAG4uHvOV2Zn85W0tIxUruv/\nsR64v5u//wfAXuCyMfonqSfjXPNv7qr0bk1yRtd2LvDq0GP2d23SlNUiljaMGv67gAuBGQYluu5Y\n7Ask2ZRkV5JdI/ZB0hhGCn9VHaqqo1X1DnA3x07tDwDnDz30vK5tvtewXJc0RaOW6zpnaPM6YPab\ngO3AhiSnJlnLoFzXt8froqQ+jFqu68okMwwukPYBNwFU1bNJHgSeY1C995aqOtpP1yWNw3JdasRi\n/sTSWy+mxHJdko4x/FKjDL/UKMMvNcrwS4067ld90sqw4j7BH5tHfqlRhl9qlOGXGmX4pUYZfqlR\nhl9qlOGXGmX4pUYZfqlRhl9qlOGXGmX4pUaNWq7rgaFSXfuS7Ona1yT58dB9X+2z85JGdyL/qu8e\n4O+Ae2cbqup3Z9eT3AG8NfT4l6pqZlIdlNSP44a/qh5Psma++5IEuB749cl2S1Lfxr3mvxw4VFUv\nDrWtTfLdJN9KcvmYry+pJ+NO5nEDcN/Q9kHgA1X1RpKPAF9PcnFVvT33iUk2AZvG3L+kEY185E9y\nCvA7wAOzbV113je69d3AS8AH53u+5bqk6RrntP83geerav9sQ5Kzk6zq1i9gUK7r5fG6KKkPJ/JV\n333AvwEfSrI/yY3dXRt49yk/wBXAU91Xf/8I3FxVb06yw5Imw3Jd0spnuS5Jxxh+qVGGX2qU4Zca\nZfilRhl+qVGGX2qU4ZcaZfilRhl+qVGGX2qU4ZcaZfilRhl+qVGGX2qU4ZcaZfilRhl+qVGGX2qU\n4ZcaZfilRhl+qVHjluualNeB/+luV5qzWJnjgpU7tpU2rl+er3FJzNsPkGTXSizdtVLHBSt3bCt1\nXHN52i81yvBLjVpK4d8y7Q70ZKWOC1bu2FbquN5lyVzzSzq5ltKRX9JJNPXwJ7kqyQtJ9ia5ddr9\nGVeSfUmeTrInya6u7cwkO5K82N2eMe1+Hk+SrUkOJ3lmqG3ecWTgS917+FSSS6bX8+NbYGy3JznQ\nvW97klwz
dN9t3dheSPLp6fR68qYa/iSrgC8DVwMXATckuWiafZqQT1TVzNDXRbcCO6tqHbCz217q\n7gGumtO20DiuBtZ1yybgrpPUx1Hdw3vHBnBn977NVNUjAN3f4wbg4u45X+n+bpe9aR/5LwP2VtXL\nVfVT4H5g/ZT71If1wLZufRtw7RT7ckKq6nHgzTnNC41jPXBvDTwBnJ7knJPT08VbYGwLWQ/cX1U/\nqaofAHsZ/N0ue9MO/7nAq0Pb+7u25ayAx5LsTrKpa1tdVQe79deA1dPp2tgWGsdKeR83d5ctW4cu\nzVbK2N5j2uFfiT5eVZcwOBW+JckVw3fW4OuVZf8Vy0oZx5C7gAuBGeAgcMd0u9O/aYf/AHD+0PZ5\nXduyVVUHutvDwEMMThEPzZ4Gd7eHp9fDsSw0jmX/PlbVoao6WlXvAHdz7NR+2Y9tIdMO/5PAuiRr\nk7yPwQcr26fcp5EleX+S02bXgU8BzzAY08buYRuBh6fTw7EtNI7twGe6T/0/Brw1dHmwLMz5jOI6\nBu8bDMa2IcmpSdYy+FDz2ye7f32Y6r/qq6ojSTYDjwKrgK1V9ew0+zSm1cBDSWDw3/ZrVfXNJE8C\nDya5EXgFuH6KfTwhSe4DrgTOSrIf+ALwV8w/jkeAaxh8GPYj4LMnvcOLsMDYrkwyw+BSZh9wE0BV\nPZvkQeA54AhwS1UdnUa/J81f+EmNmvZpv6QpMfxSowy/1CjDLzXK8EuNMvxSowy/1CjDLzXqfwFe\nIOBcSsg4NQAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# NHWC layout\n",
|
|
"img = np.zeros((1, 200, 198, 3), dtype=jnp.float32)\n",
|
|
"for k in range(3):\n",
|
|
" x = 30 + 60*k\n",
|
|
" y = 20 + 60*k\n",
|
|
" img[0, x:x+10, y:y+10, k] = 1.0\n",
|
|
"\n",
|
|
"print(\"Original Image:\")\n",
|
|
"plt.imshow(img[0]);"
|
|
]
|
|
},
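  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before turning to the `lax` functions, here is a sketch of the numpy/scipy-style interface mentioned above, assuming `jax.scipy.signal.convolve2d` with `mode='same'` follows the behavior of `scipy.signal.convolve2d`. Convolving a single-pixel impulse with the edge kernel reproduces the kernel itself, centered on the impulse, which is a quick way to check orientation and padding. This function works on plain 2D arrays, so the batch and channel dimensions of `img` and `kernel` above do not apply; the names `edge`, `impulse` and `response` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax.scipy.signal as jsp_signal\n",
    "\n",
    "# The same diagonal edge kernel as above, as a plain 2D array\n",
    "edge = jnp.array([[1., 1., 0.],\n",
    "                  [1., 0., -1.],\n",
    "                  [0., -1., -1.]])\n",
    "\n",
    "# A 5x5 image with a single bright pixel at the center\n",
    "impulse = np.zeros((5, 5), dtype=np.float32)\n",
    "impulse[2, 2] = 1.0\n",
    "\n",
    "# 'same' keeps the output the size of the first argument\n",
    "response = jsp_signal.convolve2d(impulse, edge, mode='same')\n",
    "print(response)  # the kernel appears in the central 3x3 block"
   ]
  },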
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "_m90y74OWorG"
|
|
},
|
|
"source": [
|
|
"### lax.conv and lax.conv_with_general_padding"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Pv9_QPDnWssM"
|
|
},
|
|
"source": [
|
|
"These are simple convenience functions for convolutions.\n",
|
|
"\n",
|
|
"⚠️ The convenience functions `lax.conv` and `lax.conv_with_general_padding` assume __NCHW__ images and __OIHW__ kernels."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 153,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 628
|
|
},
|
|
"colab_type": "code",
|
|
"id": "kppxbxpZW0nb",
|
|
"outputId": "0d72fdd9-19d7-45ae-891b-b19df819620f"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"out shape: (1, 3, 200, 198)\n",
|
|
"First output channel:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAJBCAYAAACqM9quAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAaP0lEQVR4nO3df6zld13n8dd7Z7SJUzdtd7pNhUIL\nKQ1qulXHSqIQdpHaEkPFP6CNUVSyhQSIzZooaLIQExNXrXZ1s5gSGiCpBVysNqZVuqwr2cQqU2zG\nQhmZYhumO7YdCmKLYe3w3j/mzHBnuMNM594zZ5j345Hc3O/5nB/fz3zzvdPnfD/n3FZ3BwBgon+1\n6gkAAKyKEAIAxhJCAMBYQggAGEsIAQBjCSEAYKylhVBVXV1Vu6tqT1W9bVn7AQA4WbWM3yNUVVuS\n/F2SVybZm+TjSa7v7k9t+s4AAE7S1iW97pVJ9nT3Z5Okqj6Q5Nok64bQlrO39dbzzlvSVACAyZ55\n8skceOrpWu++ZYXQc5J8bs3tvUl+4FgP3nreefmOn79xSVMBACb7vzfdfMz7VvZm6aq6oap2VtXO\nA089vappAACDLSuEHk1y0Zrbz12MHdbdt3T3ju7eseXsbUuaBgDAsS0rhD6e5NKquqSqvjXJdUnu\nXNK+AABOylLeI9Tdz1TVW5L8WZItSW7t7k8uY18AACdrWW+WTnffleSuZb0+AMBG+c3SAMBYQggA\nGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICx\nhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsI\nAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAA\nMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBj\nCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAY510CFXVRVX151X1qar6ZFX93GL8\nnVX1aFXdv/h61eZNFwBg82zdwHOfSfLz3f2Jqvr2JPdV1T2L+367u39z49MDAFiekw6h7t6XZN9i\n+5+q6sEkz9msiQEALNumvEeoqi5O8j1J/mox9Jaq2lVVt1bVuZuxDwCAzbbhEKqqs5N8OMmN3f2l\nJO9K8sIkV+TgFaObjvG8G6pqZ1XtPPDU0xudBgDAs7ahEKqqb8nBCLqtu/8wSbr7se4+0N1fTfLu\nJFeu99zuvqW7d3T3ji1nb9vINAAATspGPjVWSd6T5MHu/q014xeuedhrkjxw8tMDAFiejXxq7AeT\n/GSSv62q+xdjv5Tk+qq6IkkneTjJGzc0QwCAJdnIp8b+T5Ja5667Tn46AACnjt8sDQCMJYQAgLGE\nEAAwlhACAMbayKfGOIZzHjz4HvLtu5b/iyL3X/6138H0xRf30vcHAGcSV4QAgLGEEAAwlqWxJTi8\nJHbvrq8NvuTypexr7XLY+ZftT5I8sXv7UvYFAGcaV4QAgLFcEVqmNVeB9rzu25ayixuvuvvrxm7e\nfc1S9gUAZxpXhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYa+uqJ3Am2n/5\ntiTJF1/ch8duvOrupezrrec+cnj7d7/w/KXsAwDOVK4IAQBjCSEAYCxLY0twaEns/Mv2L31fa5fD\nbnvk+5e+PwA4k7giBACMJYQAgLEsjS3RE7u3
H96+efc1K5wJALAeV4QAgLGEEAAwlhACAMYSQgDA\nWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwl\nhACAsYQQADCWEAIAxhJCAMBYQggAGGvrRl+gqh5O8k9JDiR5prt3VNV5ST6Y5OIkDyd5bXd/YaP7\nAgDYTJt1Rejfd/cV3b1jcfttST7a3Zcm+ejiNgDAaWVZS2PXJnnfYvt9SX5sSfsBADhpmxFCneQj\nVXVfVd2wGLugu/cttv8hyQVHP6mqbqiqnVW188BTT2/CNAAAnp0Nv0coyQ9196NV9W+T3FNVn157\nZ3d3VfXRT+ruW5LckiRnPe+ir7sfAGDZNnxFqLsfXXx/PMkdSa5M8lhVXZgki++Pb3Q/AACbbUMh\nVFXbqurbD20nuSrJA0nuTPL6xcNen+SPN7IfAIBl2OjS2AVJ7qiqQ6/1+939p1X18SQfqqo3JHkk\nyWs3uB8AgE23oRDq7s8m+XfrjH8+ySs28toAAMvmN0sDAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMtfVkn1hVlyX54JqhFyT5z0nOSfIfkzyxGP+l7r7rpGcIALAkJx1C\n3b07yRVJUlVbkjya5I4kP5Pkt7v7NzdlhgAAS7JZS2OvSPJQdz+ySa8HALB0mxVC1yW5fc3tt1TV\nrqq6tarOXe8JVXVDVe2sqp0Hnnp6k6YBAHDiNhxCVfWtSV6d5A8WQ+9K8sIcXDbbl+Sm9Z7X3bd0\n947u3rHl7G0bnQYAwLO2GVeErknyie5+LEm6+7HuPtDdX03y7iRXbsI+AAA23WaE0PVZsyxWVReu\nue81SR7YhH0AAGy6k/7UWJJU1bYkr0zyxjXDv15VVyTpJA8fdR8AwGljQyHU3U8n+TdHjf3khmYE\nAHCK+M3SAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAw1tZVTwBYnnMerMPb23c9vdR97b982+Ht\nL764l7ovgM3iihAAMJYQAgDGsjQGZ7AjlsPu3XXw+0suX8q+1i6HnX/Z/sPbT+zevpT9AWwGV4QA\ngLGEEAAwlqUxmGKxJLbndd+2lJe/8aq71x2/efc1S9kfwGZwRQgAGEsIAQBjCSEAYCwhBACMJYQA\ngLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAY\nSwgBAGMJIQBgLCEEAIy1ddUTAJZn/+XbDm9/8cWdJLnxqruXsq+3nvvI4e3f/cLzl7IPgM12QleE\nqurWqnq8qh5YM3ZeVd1TVZ9ZfD93MV5V9TtVtaeqdlXV9y5r8gAAG3GiS2PvTXL1UWNvS/LR7r40\nyUcXt5PkmiSXLr5uSPKujU8TAGDzndDSWHd/rKouPmr42iQvX2y/L8n/TvKLi/H3d3cnubeqzqmq\nC7t732ZM
GDhxh5bDkuT8y/YvdV9rl8Nue+T7l7ovgM2ykTdLX7Ambv4hyQWL7eck+dyax+1djB2h\nqm6oqp1VtfPAU09vYBoAACdnUz41trj608d94JHPuaW7d3T3ji1nbzv+EwAANtlGPjX22KElr6q6\nMMnji/FHk1y05nHPXYwBK/TE7u1Jkpt3X7PimQCcPjZyRejOJK9fbL8+yR+vGf+pxafHXpLkH70/\nCAA4HZ3QFaGquj0H3xi9var2JnlHkl9L8qGqekOSR5K8dvHwu5K8KsmeJF9O8jObPGcAgE1xop8a\nu/4Yd71incd2kjdvZFIAAKeC/8UGADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGE\nEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgB\nAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAw\nlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJ\nIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWMcNoaq6taoer6oH1oz9RlV9uqp2VdUdVXXOYvziqvrnqrp/8fV7y5w8AMBGnMgVofcm\nufqosXuSfHd3X57k75K8fc19D3X3FYuvN23ONAEANt9xQ6i7P5bkyaPGPtLdzyxu3pvkuUuYGwDA\nUm3Ge4R+Nsnda25fUlV/U1V/UVUvPdaTquqGqtpZVTsPPPX0JkwDAODZ2bqRJ1fVLyd5Jslti6F9\nSZ7X3Z+vqu9L8kdV9V3d/aWjn9vdtyS5JUnOet5FvZF5AACcjJO+IlRVP53kR5P8RHd3knT3V7r7\n84vt+5I8lORFmzBPAIBNd1IhVFVXJ/mFJK/u7i+vGT+/qrYstl+Q5NIkn92MiQIAbLbjLo1V1e1J\nXp5ke1XtTfKOHPyU2FlJ7qmqJLl38QmxlyX5lar6lyRfTfKm7n5y3RcGAFix44ZQd1+/zvB7jvHY\nDyf58EYnBQBwKvjN0gDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAY\nSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGE\nEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgB\nAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAw\nlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYa+uqJwDA\nme+cB+vw9vZdTy99f/sv35Yk+eKLe+n74pvbca8IVdWtVfV4VT2wZuydVfVoVd2/+HrVmvveXlV7\nqmp3Vf3IsiYOALBRJ7I09t4kV68z/tvdfcXi664kqarvTHJdku9aPOe/V9WWzZosAMBmOu7SWHd/\nrKouPsHXuzbJB7r7K0n+vqr2JLkyyV+e9AwB+KZ3xHLYvbu+tv2Sy5eyv0NLYudftv/w2BO7ty9l\nX3xz28ibpd9SVbsWS2fnLsaek+Rzax6zdzH2darqhqraWVU7Dzy1/PViAICjnWwIvSvJC5NckWRf\nkpue7Qt09y3dvaO7d2w5e9tJTgMA4OSd1KfGuvuxQ9tV9e4kf7K4+WiSi9Y89LmLMQA4aM1y2J7X\nfdtSdnHjVXd/3djNu69Zyr745nZSV4Sq6sI1N1+T5NAnyu5Mcl1VnVVVly
S5NMlfb2yKAADLcdwr\nQlV1e5KXJ9leVXuTvCPJy6vqiiSd5OEkb0yS7v5kVX0oyaeSPJPkzd19YDlTBwDYmBP51Nj16wy/\n5xs8/leT/OpGJgUAcCr4X2wAAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBj\nCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAY21d9QQAOPPtv3zb4e0vvrgPb994\n1d1L2d9bz30kSfK7X3j+Ul6fM4crQgDAWEIIABjL0hgAS7d2Oez8y/YvfX+HlsRue+T7l74vvrm5\nIgQAjCWEAICxLI0BcEo9sXv74e2bd1+zwpmAK0IAwGBCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQ\nADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEA\nYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCW\nEAIAxhJCAMBYQggAGEsIAQBjHTeEqurWqnq8qh5YM/bBqrp/8fVwVd2/GL+4qv55zX2/t8zJAwBs\nxNYTeMx7k/y3JO8/NNDdrzu0XVU3JfnHNY9/qLuv2KwJAgAsy3FDqLs/VlUXr3dfVVWS1yb5D5s7\nLQCA5dvoe4RemuSx7v7MmrFLqupvquovquqlx3piVd1QVTuraueBp57e4DQAAJ69E1ka+0auT3L7\nmtv7kjyvuz9fVd+X5I+q6ru6+0tHP7G7b0lyS5Kc9byLeoPzAAB41k76ilBVbU3y40k+eGisu7/S\n3Z9fbN+X5KEkL9roJAEAlmEjS2M/nOTT3b330EBVnV9VWxbbL0hyaZLPbmyKAADLcSIfn789yV8m\nuayq9lbVGxZ3XZcjl8WS5GVJdi0+Tv8/krypu5/czAkDAGyWE/nU2PXHGP/pdcY+nOTDG58WAMDy\n+c3SAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjFXd\nveo5pKqeSPJIku1J9q94OqcTx+NIjseRHI+vcSyO5HgcyfE40sTj8fzuPn+9O06LEDqkqnZ2945V\nz+N04XgcyfE4kuPxNY7FkRyPIzkeR3I8jmRpDAAYSwgBAGOdbiF0y6oncJpxPI7keBzJ8fgax+JI\njseRHI8jOR5rnFbvEQIAOJVOtytCAACnzGkRQlV1dVXtrqo9VfW2Vc/nVKuqi6rqz6vqU1X1yar6\nucX4O6vq0aq6f/H1qlXP9VSpqoer6m8Xf+6di7HzquqeqvrM4vu5q57nqVBVl605B+6vqi9V1Y2T\nzo+qurWqHq+qB9aMrXs+1EG/s/j7ZFdVfe/qZr4cxzgev1FVn178me+oqnMW4xdX1T+vOU9+b3Uz\nX45jHI9j/nxU1dsX58fuqvqR1cx6eY5xPD645lg8XFX3L8bP+PPjeFa+NFZVW5L8XZJXJtmb5ONJ\nru/uT610YqdQVV2Y5MLu/kRVfXuS+5L8WJLXJnmqu39zpRNcgap6OMmO7t6/ZuzXkzzZ3b+2COZz\nu/sXVzXHVVj8vDya5AeS/EyGnB9V9b
IkTyV5f3d/92Js3fNh8R+8tyZ5VQ4ep//a3T+wqrkvwzGO\nx1VJ/ld3P1NV/yVJFsfj4iR/cuhxZ6JjHI93Zp2fj6r6ziS3J7kyyXck+Z9JXtTdB07ppJdoveNx\n1P03JfnH7v6VCefH8ZwOV4SuTLKnuz/b3f8vyQeSXLviOZ1S3b2vuz+x2P6nJA8mec5qZ3VaujbJ\n+xbb78vBWJzmFUke6u5HVj2RU6m7P5bkyaOGj3U+XJuD/wHo7r43yTmLf2ycMdY7Ht39ke5+ZnHz\n3iTPPeUTW5FjnB/Hcm2SD3T3V7r775PsycH/Dp0xvtHxqKrKwX9k335KJ3UaOx1C6DlJPrfm9t4M\njoBFnX9Pkr9aDL1lcan71ilLQQud5CNVdV9V3bAYu6C79y22/yHJBauZ2kpdlyP/Apt6fiTHPh/8\nnZL8bJK719y+pKr+pqr+oqpeuqpJrcB6Px/Tz4+XJnmsuz+zZmzq+ZHk9AghFqrq7CQfTnJjd38p\nybuSvDDJFUn2JblphdM71X6ou783yTVJ3ry41HtYH1zTHfWRx6r61iSvTvIHi6HJ58cRJp4Px1JV\nv5zkmSS3LYb2JXled39Pkv+U5Per6l+van6nkJ+P9V2fI/8xNfX8OOx0CKFHk1y05vZzF2OjVNW3\n5GAE3dbdf5gk3f1Ydx/o7q8meXfOsMu330h3P7r4/niSO3Lwz/7YoSWOxffHVzfDlbgmySe6+7Fk\n9vmxcKzzYezfKVX100l+NMlPLOIwiyWgzy+270vyUJIXrWySp8g3+PmYfH5sTfLjST54aGzq+bHW\n6RBCH09yaVVdsvgX73VJ7lzxnE6pxZrte5I82N2/tWZ87fsaXpPkgaOfeyaqqm2LN42nqrYluSoH\n/+x3Jnn94mGvT/LHq5nhyhzxL7mp58caxzof7kzyU4tPj70kB98Uum+9FziTVNXVSX4hyau7+8tr\nxs9fvMk+VfWCJJcm+exqZnnqfIOfjzuTXFdVZ1XVJTl4PP76VM9vRX44yae7e++hgannx1pbVz2B\nxScc3pLkz5JsSXJrd39yxdM61X4wyU8m+dtDH2lM8ktJrq+qK3Lwkv/DSd64mumdchckueNgH2Zr\nkt/v7j+tqo8n+VBVvSHJIzn4hr8RFkH4yhx5Dvz6lPOjqm5P8vIk26tqb5J3JPm1rH8+3JWDnxjb\nk+TLOfjpujPKMY7H25OcleSexc/Ovd39piQvS/IrVfUvSb6a5E3dfaJvLP6mcIzj8fL1fj66+5NV\n9aEkn8rBJcQ3n0mfGEvWPx7d/Z58/XsMkwHnx/Gs/OPzAACrcjosjQEArIQQAgDGEkIAwFhCCAAY\nSwgBAGMJIQBgLCEEAIwlhACAsf4/QXaJowhPRekAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv(jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
" (1, 1), # window strides\n",
" 'SAME') # padding mode\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(np.array(out)[0,0,:,:]);"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 628
},
"colab_type": "code",
"id": "aonr1tWvYCW9",
"outputId": "1b61e1b7-331d-4b60-b524-73a0fbad3ed9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 3, 202, 200)\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAJBCAYAAACqM9quAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAbkklEQVR4nO3df6xnd13n8dd7Z7SJUzdtd7pNLS0t\npDSg6RYcSxOFsIt0W2Ko+Ae0MViV7EACxGbdKGCyEBMTVkVZ3SxmCA0lqQXcWmlMq3RZV7KJVabY\nDP3ByBQ7YSZD2+GH0GLQDp/9456Z+c5wpzNz7/d773jfj0dyc8/38/1xPnN67vQ553zP99YYIwAA\nHf2r9Z4AAMB6EUIAQFtCCABoSwgBAG0JIQCgLSEEALS1sBCqquuqandV7amqdy5qPQAAK1WL+Byh\nqtqU5O+SvCbJviSfTXLTGOORua8MAGCFNi/oda9OsmeM8aUkqaqPJbkhybIhtOnsLWPzeectaCoA\nQHf/9OV9B8cY5x8/vqgQuijJl2du70vy8hM9ePN55+WHfvmWBU0FAOju8Vv+y97lxtftzdJVtb2q\ndlbVzkNPP7Ne0wAAGltUCO1PcvHM7edNY0eMMXaMMbaNMbZtOnvLgqYBAHBiiwqhzya5vKouq6rv\nT3JjkrsXtC4AgBVZyHuExhjPVtXbk/x5kk1Jbh1jPLyIdQEArNSi3iydMcY9Se5Z1OsDAKyWT5YG\nANoSQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCg\nLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANoS\nQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEE\nALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANoSQgBA\nW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFsrDqGquriq/qKqHqmqh6vql6bx\n91bV/qp6cPp67fymCwAwP5tX8dxnk/zyGONzVfWDSR6oqvum+353jPHbq58eAMDirDiExhgHkhyY\nlr9VVY8muWheEwMAWLS5vEeoqi5N8tIkfz0Nvb2qdlXVrVV17jzWAQAwb6sOoao6O8mdSW4ZY3wz\nyQeTvDDJVVk6YvT+Ezxve1XtrKqdh55+ZrXTAAA4basKoar6vixF0O1jjD9OkjHGE2OMQ2OM7yb5\nUJKrl3vuGGPHGGPbGGPbprO3rGYaAAArspqrxirJh5M8Osb4nZnxC2ce9vokD618egAAi7Oaq8Z+\nPMmbkny+qh6cxt6d5KaquirJSPJ4kresaoYAAAuymqvG/l+SWuaue1Y+HQCAteOTpQGAtoQQANCW\nEAIA2hJCAEBbQggAaEsIAQBtCSEAoK3VfKAiJ3HOo0c/ZmnrrsX+PrWDVx79NSXfePFY6LoAYKNw\nRAgAaMsRoQU65ijQ/buWvl9z5ULWNXsU6PwrDiZJntq9dSHrAoCNwhEhAKAtIQQAtOXU2FqZTont\neeMPLOTlb7n23u8Z+8Du6xeyLgDYKBwRAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQ\nANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANoSQgBAW0IIAGhLCAEA\nbQkhAKCtzes9gY3s4JVbjix/48UjSXLLtfcuZF3vOHfvkeXf//rzF7IOANhoHBECANpyRGiBDh8F\nSpLzrzi40HXNHgW6fe+PLXRdALBROCIEALQl
hACAtpwaWyNP7d6aJPnA7uvXeSYAwGGOCAEAbQkh\nAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA\n2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0tXm1L1BVjyf5VpJDSZ4dY2yrqvOSfDzJpUkeT/KGMcbX\nV7suAIB5mtcRoX8/xrhqjLFtuv3OJJ8eY1ye5NPTbQCAM8qiTo3dkOS2afm2JD+9oPUAAKzYPEJo\nJPlUVT1QVdunsQvGGAem5a8kueD4J1XV9qraWVU7Dz39zBymAQBwelb9HqEkPzHG2F9V/zbJfVX1\nhdk7xxijqsbxTxpj7EiyI0nOuuTi77kfAGDRVn1EaIyxf/r+ZJK7klyd5ImqujBJpu9PrnY9AADz\ntqoQqqotVfWDh5eTXJvkoSR3J7l5etjNST65mvUAACzCak+NXZDkrqo6/Fp/OMb4s6r6bJJPVNWb\nk+xN8oZVrgcAYO5WFUJjjC8l+XfLjH81yatX89oAAIvmk6UBgLaEEADQlhACANoSQgBAW0IIAGhL\nCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQ\nANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANoSQgBAW0IIAGhLCAEA\nbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCW\nEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANoSQgBAW0IIAGhLCAEAbQkh\nAKAtIQQAtCWEAIC2hBAA0NbmlT6xqq5I8vGZoRck+a9Jzknyn5I8NY2/e4xxz4pnCACwICsOoTHG\n7iRXJUlVbUqyP8ldSX4hye+OMX57LjMEAFiQeZ0ae3WSx8YYe+f0egAACzevELoxyR0zt99eVbuq\n6taqOne5J1TV9qraWVU7Dz39zJymAQBw6lYdQlX1/Ulel+SPpqEPJnlhlk6bHUjy/uWeN8bYMcbY\nNsbYtunsLaudBgDAaZvHEaHrk3xujPFEkowxnhhjHBpjfDfJh5JcPYd1AADM3TxC6KbMnBarqgtn\n7nt9kofmsA4AgLlb8VVjSVJVW5K8JslbZoZ/s6quSjKSPH7cfQAAZ4xVhdAY45kk/+a4sTetakYA\nAGvEJ0sDAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQA\ngLaEEADQlhACANoSQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABo\nSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtbV7vCQCLdc6jlSTZuuuZ\nha/r4JVbkiTfePFY+LoA5sERIQCgLSEEALTl1BhscEdOid2/6+jgNVcuZF2HT4mdf8XBI2NP7d66\nkHUBzIMjQgBAW0IIAGjLqTHoYuZ02J43/sBCVnHLtfd+z9gHdl+/kHUBzIMjQgBAW0IIAGhLCAEA\nbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCW\nEAIA2hJCAEBbQggAaOuUQqiqbq2qJ6vqoZmx86rqvqr64vT93Gm8qur3qmpPVe2qqpctavIAAKtx\nqkeEPpLkuuPG3pnk02OMy5N8erqdJNcnuXz62p7kg6ufJgDA/G0+lQeNMT5TVZceN3xDkldNy7cl\n+b9JfnUa/+gYYyS5v6rOqaoLxxgH5jFh4PQcvHJLkuQbLx5Hxm659t6FrOsd5+5Nkvz+15+/kNcH\nmLfVvEfo
gpm4+UqSC6bli5J8eeZx+6YxAIAzylzeLD0d/RknfeCMqtpeVTurauehp5+ZxzQAAE7L\nKZ0aO4EnDp/yqqoLkzw5je9PcvHM4543jR1jjLEjyY4kOeuSi08rooBTd/iU2PlXHFz4ug6fErt9\n748tfF0A87CaI0J3J7l5Wr45ySdnxn9uunrsmiT/4P1BAMCZ6JSOCFXVHVl6Y/TWqtqX5D1J3pfk\nE1X15iR7k7xhevg9SV6bZE+Sbyf5hTnPGQBgLk71qrGbTnDXq5d57EjyttVMCpi/p3ZvPbL8gd3X\nr+NMAM4cPlkaAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCg\nLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANoS\nQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEE\nALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANoSQgBA\nW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgrZOGUFXd\nWlVPVtVDM2O/VVVfqKpdVXVXVZ0zjV9aVf9YVQ9OX3+wyMkDAKzGqRwR+kiS644buy/Jj4wxrkzy\nd0neNXPfY2OMq6avt85nmgAA83fSEBpjfCbJ144b+9QY49np5v1JnreAuQEALNQ83iP0i0nunbl9\nWVX9bVX9ZVW94kRPqqrtVbWzqnYeevqZOUwDAOD0bF7Nk6vq15I8m+T2aehAkkvGGF+tqh9N8idV\n9cNjjG8e/9wxxo4kO5LkrEsuHquZBwDASqz4iFBV/XySn0rys2OMkSRjjO+MMb46LT+Q5LEkL5rD\nPAEA5m5FIVRV1yX5lSSvG2N8e2b8/KraNC2/IMnlSb40j4kCAMzbSU+NVdUdSV6VZGtV7Uvynixd\nJXZWkvuqKknun64Qe2WSX6+qf07y3SRvHWN8bdkXBgBYZycNoTHGTcsMf/gEj70zyZ2rnRQAwFrw\nydIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0h\nBAC0JYQAgLaEEADQlhACANoSQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIA\nQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0\nJYQAgLaEEADQlhACANoSQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtC\nCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaOukIVRVt1bVk1X10MzYe6tq\nf1U9OH29dua+d1XVnqraXVX/cVETBwBYrVM5IvSRJNctM/67Y4yrpq97kqSqXpLkxiQ/PD3nf1bV\npnlNFgBgnjaf7AFjjM9U1aWn+Ho3JPnYGOM7Sf6+qvYkuTrJX614hgD8i3bOo3VkeeuuZxa6roNX\nbjmy/I0Xj4Wui41hNe8RentV7ZpOnZ07jV2U5Mszj9k3jX2PqtpeVTurauehpxf7gwEAsJyVhtAH\nk7wwyVVJDiR5/+m+wBhjxxhj2xhj26azt5z8CQAAc3bSU2PLGWM8cXi5qj6U5E+nm/uTXDzz0OdN\nYwA0dczpsPt3HV2+5sq5r2v2dNj5Vxw8svzU7q1zXxcbw4qOCFXVhTM3X5/k8BVldye5sarOqqrL\nklye5G9WN0UAgMU46RGhqrojyauSbK2qfUnek+RVVXVVkpHk8SRvSZIxxs
NV9YkkjyR5NsnbxhiH\nFjN1AP7FmTkKtOeNPzD3l7/l2nuXHf/A7uvnvi42hlO5auymZYY//ByP/40kv7GaSQEArAWfLA0A\ntCWEAIC2hBAA0JYQAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBb\nQggAaEsIAQBtCSEAoC0hBAC0JYQAgLaEEADQlhACANravN4TAGBjO3jlliPL33jxOLJ8y7X3zn1d\n7zh375Hl3//68+f++mw8jggBAG0JIQCgLafGAFio2dNh519xcKHrmj0ddvveH1voutgYHBECANpy\nRAiANfPU7q1Hlj+w+/p1nAkscUQIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFtC\nCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0hBAC0JYQA\ngLaEEADQlhACANoSQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIAQFsnDaGq\nurWqnqyqh2bGPl5VD05fj1fVg9P4pVX1jzP3/cEiJw8AsBqbT+ExH0nyP5J89PDAGOONh5er6v1J\n/mHm8Y+NMa6a1wQBABblpCE0xvhMVV263H1VVUnekOQ/zHdaAACLt9r3CL0iyRNjjC/OjF1WVX9b\nVX9ZVa840ROrantV7ayqnYeefmaV0wAAOH2ncmrsudyU5I6Z2weSXDLG+GpV/WiSP6mqHx5jfPP4\nJ44xdiTZkSRnXXLxWOU8AABO24qPCFXV5iQ/k+Tjh8fGGN8ZY3x1Wn4gyWNJXrTaSQIALMJqTo39\nZJIvjDH2HR6oqvOratO0/IIklyf50uqmCACwGKdy+fwdSf4qyRVVta+q3jzddWOOPS2WJK9Msmu6\nnP5/JXnrGONr85wwAMC8nMpVYzedYPznlxm7M8mdq58WAMDi+WRpAKAtIQQAtCWEAIC2hBAA0JYQ\nAgDaEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEA\noC0hBAC0JYQAgLaEEADQlhACANoSQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDa\nEkIAQFtCCABoSwgBAG0JIQCgLSEEALQlhACAtoQQANCWEAIA2hJCAEBbQggAaEsIAQBtCSEAoC0h\nBAC0JYQAgLaEEADQlhACANoSQgBAW0IIAGhLCAEAbQkhAKAtIQQAtCWEAIC2hBAA0JYQAgDaEkIA\nQFtCCABoSwgBAG0JIQCgLSEEALR10hCqqour6i+q6pGqeriqfmkaP6+q7quqL07fz53Gq6p+r6r2\nVNWuqnrZov8QAAArcSpHhJ5N8stjjJckuSbJ26rqJUnemeTTY4zLk3x6up0k1ye5fPranuSDc581\nAMAcnDSExhgHxhifm5a/leTRJBcluSHJbdPDbkvy09PyDUk+Opbcn+Scqrpw7jMHAFil03qPUFVd\nmuSlSf46yQVjjAPTXV9JcsG0fFGSL888bd80BgBwRjnlEKqqs5PcmeSWMcY3Z+8bY4wk43RWXFXb\nq2pnVe089PQzp/NUAIC5OKUQqqrvy1IE3T7G+ONp+InDp7ym709O4/uTXDzz9OdNY8cYY+wYY2wb\nY2zbdPaWlc4fAGDFTuWqsUry4SSPjjF+Z+auu5PcPC3fnOSTM+M/N109dk2Sf5g5hQYAcMbYfAqP\n+fEkb0ry+ap6cBp7d5L3JflEVb05yd4kb5juuyfJa5PsSfLtJL8w1xkDAMxJLb29Z50nUfVUlmJq\na5KD6zydM4VtcZRtcSzb4yjb4ijb4ijb4li2x5LnjzHOP37wjAihw6pq5xhj23rP40xgWxxlWxzL\n9jjKtjjKtjjKtjiW7fHc/IoNAKAtIQ
QAtHWmhdCO9Z7AGcS2OMq2OJbtcZRtcZRtcZRtcSzb4zmc\nUe8RAgBYS2faESEAgDVzRoRQVV1XVburak9VvfPkz9g4quriqvqLqnqkqh6uql+axt9bVfur6sHp\n67XrPde1UlWPV9Xnpz/3zmnsvKq6r6q+OH0/d73nuWhVdcXMf/8Hq+qbVXVLl32jqm6tqier6qGZ\nsWX3g+kDXH9v+jtkV1W9bP1mvhgn2B6/VVVfmP7Md1XVOdP4pVX1jzP7yB+s38zn7wTb4oQ/F1X1\nrmnf2F1V/3F9Zr0YJ9gWH5/ZDo8f/gzAjb5frNS6nxqrqk1J/i7Ja7L0C1o/m+SmMcYj6zqxNTL9\nepILxxifq6ofTPJAkp/O0gdUPj3G+O11neA6qKrHk2wbYxycGfvNJF8bY7xviuVzxxi/ul5zXGvT\nz8n+JC/P0oeUbvh9o6pemeTpJB8dY/zINLbsfjD9T+8dWfow15cn+e9jjJev19wX4QTb49ok/2eM\n8WxV/bckmbbHpUn+9PDjNpoTbIv3Zpmfi6p6SZI7klyd5IeS/O8kLxpjHFrTSS/IctviuPvfn6Xf\n8PDrG32/WKkz4YjQ1Un2jDG+NMb4pyQfS3LDOs9pzYwxDowxPjctfyvJo0kuWt9ZnZFuSHLbtHxb\nlmKxk1cneWyMsXe9J7JWxhifSfK144ZPtB/ckKX/EYwxxv1Jzpn+kbFhLLc9xhifGmM8O928P0u/\n23HDO8G+cSI3JPnYGOM7Y4y/z9JvPbh6YZNbY8+1LaZfkfWGLIUgJ3AmhNBFSb48c3tfmobAVOsv\nTfLX09Dbp0Pet3Y4FTRjJPlUVT1QVdunsQtmfmfdV5JcsD5TWzc35ti/zLruGyfaD/w9kvxikntn\nbl9WVX9bVX9ZVa9Yr0mtseV+LjrvG69I8sQY44szYx33i+d0JoQQSarq7CR3JrlljPHNJB9M8sIk\nVyU5kOT96zi9tfYTY4yXJbk+ydumQ79HjKXzuW0ud6yq70/yuiR/NA113jeO6LYfPJeq+rUkzya5\nfRo6kOSSMcZLk/znJH9YVf96vea3RvxcfK+bcuw/oDruFyd1JoTQ/iQXz9x+3jTWRlV9X5Yi6PYx\nxh8nyRjjiTHGoTHGd5N8KBvoUO7JjDH2T9+fTHJXlv7sTxw+1TF9f3L9Zrjmrk/yuTHGE0nvfSMn\n3g/a/j1SVT+f5KeS/OwUh5lOA311Wn4gyWNJXrRuk1wDz/Fz0XLfqKrNSX4myccPj3XcL07FmRBC\nn01yeVVdNv3L98Ykd6/znNbMdA73w0keHWP8zsz47PsbXp/koeOfuxFV1ZbpTeOpqi1Jrs3Sn/3u\nJDdPD7s5ySfXZ4br4ph/1XXdNyYn2g/uTvJz09Vj12TpzaEHlnuBjaSqrkvyK0leN8b49sz4+dMb\n7FNVL0hyeZIvrc8s18Zz/FzcneTGqjqrqi7L0rb4m7We3zr4ySRfGGPsOzzQcb84FZvXewLT1Q5v\nT/LnSTYluXWM8fA6T2st/XiSNyX5/OFLHJO8O8lNVXVVlg79P57kLeszvTV3QZK7lvowm5P84Rjj\nz6rqs0k+UVVvTrI3S28A3PCmGHxNjv3v/5sd9o2quiPJq5Jsrap9Sd6T5H1Zfj+4J0tXjO1J8u0s\nXVm3oZxge7wryVlJ7pt+Zu4fY7w1ySuT/HpV/XOS7yZ56xjjVN9cfMY7wbZ41XI/F2OMh6vqE0ke\nydLpw7dtlCvGkuW3xRjjw/ne9xUmG3y/WKl1v3weAGC9nAmnxgAA1oUQAgDaEkIAQFtCCABoSwgB\nAG0JIQCgLSEEALQlhACAtv4/KbjyvSJfXEAAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_with_general_padding(\n",
" jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n",
" jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor\n",
" (1, 1), # window strides\n",
" ((2,2),(2,2)), # general padding 2x2\n",
" (1,1), # lhs/image dilation\n",
" (1,1)) # rhs/kernel dilation\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(np.array(out)[0,0,:,:]);"
]
},
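{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The `(1, 3, 202, 200)` shape above follows from the usual output-size arithmetic. With stride `s`, padding `lo`/`hi` and kernel extent `k`, each spatial dimension becomes `(n + lo + hi - k) // s + 1`, so `200 + 2 + 2 - 3 + 1 = 202` and `198 + 2 + 2 - 3 + 1 = 200`. Below is a minimal sketch that checks this. The helper `conv_out_size` is hypothetical, and the all-ones arrays merely stand in for this notebook's `img` and `kernel` (assuming a 3x3 kernel with 3 input and 3 output channels):\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax import lax\n",
"\n",
"def conv_out_size(n, k, stride=1, pad_lo=0, pad_hi=0, dilation=1):\n",
"    # Hypothetical helper: spatial extent of the (possibly dilated) kernel.\n",
"    k_eff = (k - 1) * dilation + 1\n",
"    return (n + pad_lo + pad_hi - k_eff) // stride + 1\n",
"\n",
"assert conv_out_size(200, 3, pad_lo=2, pad_hi=2) == 202\n",
"assert conv_out_size(198, 3, pad_lo=2, pad_hi=2) == 200\n",
"\n",
"# Cross-check with lax on dummy NCHW image / OIHW kernel arrays.\n",
"x = jnp.ones((1, 3, 200, 198))\n",
"w = jnp.ones((3, 3, 3, 3))\n",
"out = lax.conv_with_general_padding(x, w, (1, 1), ((2, 2), (2, 2)), (1, 1), (1, 1))\n",
"print(out.shape)  # (1, 3, 202, 200)\n",
"```"
]
},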
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lyOwGRez_ycJ"
},
"source": [
"### Dimension Numbers define dimensional layout for conv_general_dilated\n",
"\n",
"The key argument is the 3-tuple of axis layout strings\n",
"(Input Layout, Kernel Layout, Output Layout), built from these letters:\n",
" - __N__ - batch dimension\n",
" - __H__ - spatial height\n",
" - __W__ - spatial width\n",
" - __C__ - channel dimension\n",
" - __I__ - kernel _input_ channel dimension\n",
" - __O__ - kernel _output_ channel dimension\n",
"\n",
"⚠️ To demonstrate the flexibility of dimension numbers, we choose an __NHWC__ image and __HWIO__ kernel convention for `lax.conv_general_dilated` below."
]
},
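{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To see what the dimension numbers buy us, here is a minimal sketch. It uses small random arrays as stand-ins for `img` and `kernel`, and checks that an NHWC/HWIO `conv_general_dilated` matches `lax.conv` applied to the same data transposed to its default NCHW/OIHW layout:\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax import lax, random\n",
"\n",
"k1, k2 = random.split(random.PRNGKey(0))\n",
"x = random.normal(k1, (1, 8, 6, 3))  # NHWC\n",
"w = random.normal(k2, (3, 3, 3, 4))  # HWIO: 3 input, 4 output channels\n",
"\n",
"dn = lax.conv_dimension_numbers(x.shape, w.shape, ('NHWC', 'HWIO', 'NHWC'))\n",
"out_nhwc = lax.conv_general_dilated(x, w, (1, 1), 'SAME', (1, 1), (1, 1), dn)\n",
"\n",
"# lax.conv assumes NCHW inputs and OIHW kernels, so transpose explicitly.\n",
"out_nchw = lax.conv(jnp.transpose(x, [0, 3, 1, 2]),\n",
"                    jnp.transpose(w, [3, 2, 0, 1]), (1, 1), 'SAME')\n",
"\n",
"print(out_nhwc.shape)  # (1, 8, 6, 4)\n",
"# Move channels last before comparing the two results.\n",
"print(jnp.allclose(out_nhwc, jnp.transpose(out_nchw, [0, 2, 3, 1]), atol=1e-3))\n",
"```\n",
"\n",
"The sketch uses 3 input and 4 output channels rather than equal counts. That way, mixing up the `I` and `O` axes raises a shape error instead of going unnoticed."
]
},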
{
"cell_type": "code",
"execution_count": 155,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "oXKebfCb_i2B",
"outputId": "0243bbe8-ac5a-4923-8c6f-454a8d28f04b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n"
]
}
],
"source": [
"dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n",
" kernel.shape, # only ndim matters, not shape \n",
" ('NHWC', 'HWIO', 'NHWC')) # the important bit\n",
"print(dn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "elZys_HzFVG6"
},
"source": [
"#### SAME padding, no stride, no dilation"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 628
},
"colab_type": "code",
"id": "rgb2T15aFVG6",
"outputId": "2dae283f-21a6-4ca6-bf10-eaa247e579e7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 200, 198, 3)\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAJBCAYAAACqM9quAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAaP0lEQVR4nO3df6zld13n8dd7Z7SJUzdtd7pNhUIL\nKQ1qulXHSqIQdpHaEkPFP6CNUVSyhQSIzZooaLIQExNXrXZ1s5gSGiCpBVysNqZVuqwr2cQqU2zG\nQhmZYhumO7YdCmKLYe3w3j/mzHBnuMNM594zZ5j345Hc3O/5nB/fz3zzvdPnfD/n3FZ3BwBgon+1\n6gkAAKyKEAIAxhJCAMBYQggAGEsIAQBjCSEAYKylhVBVXV1Vu6tqT1W9bVn7AQA4WbWM3yNUVVuS\n/F2SVybZm+TjSa7v7k9t+s4AAE7S1iW97pVJ9nT3Z5Okqj6Q5Nok64bQlrO39dbzzlvSVACAyZ55\n8skceOrpWu++ZYXQc5J8bs3tvUl+4FgP3nreefmOn79xSVMBACb7vzfdfMz7VvZm6aq6oap2VtXO\nA089vappAACDLSuEHk1y0Zrbz12MHdbdt3T3ju7eseXsbUuaBgDAsS0rhD6e5NKquqSqvjXJdUnu\nXNK+AABOylLeI9Tdz1TVW5L8WZItSW7t7k8uY18AACdrWW+WTnffleSuZb0+AMBG+c3SAMBYQggA\nGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICx\nhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsI\nAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAA\nMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBj\nCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAY510CFXVRVX151X1qar6ZFX93GL8\nnVX1aFXdv/h61eZNFwBg82zdwHOfSfLz3f2Jqvr2JPdV1T2L+367u39z49MDAFiekw6h7t6XZN9i\n+5+q6sEkz9msiQEALNumvEeoqi5O8j1J/mox9Jaq2lVVt1bVuZuxDwCAzbbhEKqqs5N8OMmN3f2l\nJO9K8sIkV+TgFaObjvG8G6pqZ1XtPPDU0xudBgDAs7ahEKqqb8nBCLqtu/8wSbr7se4+0N1fTfLu\nJFeu99zuvqW7d3T3ji1nb9vINAAATspGPjVWSd6T5MHu/q014xeuedhrkjxw8tMDAFiejXxq7AeT\n/GSSv62q+xdjv5Tk+qq6IkkneTjJGzc0QwCAJdnIp8b+T5Ja5667Tn46AACnjt8sDQCMJYQAgLGE\nEAAwlhACAMbayKfGOIZzHjz4HvLtu5b/iyL3X/6138H0xRf30vcHAGcSV4QAgLGEEAAwlqWxJTi8\nJHbvrq8NvuTypexr7XLY+ZftT5I8sXv7UvYFAGcaV4QAgLFcEVqmNVeB9rzu25ayixuvuvvrxm7e\nfc1S9gUAZxpXhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYa+uqJ3Am2n/5\ntiTJF1/ch8duvOrupezrrec+cnj7d7/w/KXsAwDOVK4IAQBjCSEAYCxLY0twaEns/Mv2L31fa5fD\nbnvk+5e+PwA4k7giBACMJYQAgLEsjS3RE7u3
H96+efc1K5wJALAeV4QAgLGEEAAwlhACAMYSQgDA\nWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwl\nhACAsYQQADCWEAIAxhJCAMBYQggAGGvrRl+gqh5O8k9JDiR5prt3VNV5ST6Y5OIkDyd5bXd/YaP7\nAgDYTJt1Rejfd/cV3b1jcfttST7a3Zcm+ejiNgDAaWVZS2PXJnnfYvt9SX5sSfsBADhpmxFCneQj\nVXVfVd2wGLugu/cttv8hyQVHP6mqbqiqnVW188BTT2/CNAAAnp0Nv0coyQ9196NV9W+T3FNVn157\nZ3d3VfXRT+ruW5LckiRnPe+ir7sfAGDZNnxFqLsfXXx/PMkdSa5M8lhVXZgki++Pb3Q/AACbbUMh\nVFXbqurbD20nuSrJA0nuTPL6xcNen+SPN7IfAIBl2OjS2AVJ7qiqQ6/1+939p1X18SQfqqo3JHkk\nyWs3uB8AgE23oRDq7s8m+XfrjH8+ySs28toAAMvmN0sDAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMtfVkn1hVlyX54JqhFyT5z0nOSfIfkzyxGP+l7r7rpGcIALAkJx1C\n3b07yRVJUlVbkjya5I4kP5Pkt7v7NzdlhgAAS7JZS2OvSPJQdz+ySa8HALB0mxVC1yW5fc3tt1TV\nrqq6tarOXe8JVXVDVe2sqp0Hnnp6k6YBAHDiNhxCVfWtSV6d5A8WQ+9K8sIcXDbbl+Sm9Z7X3bd0\n947u3rHl7G0bnQYAwLO2GVeErknyie5+LEm6+7HuPtDdX03y7iRXbsI+AAA23WaE0PVZsyxWVReu\nue81SR7YhH0AAGy6k/7UWJJU1bYkr0zyxjXDv15VVyTpJA8fdR8AwGljQyHU3U8n+TdHjf3khmYE\nAHCK+M3SAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAw1tZVTwBYnnMerMPb23c9vdR97b982+Ht\nL764l7ovgM3iihAAMJYQAgDGsjQGZ7AjlsPu3XXw+0suX8q+1i6HnX/Z/sPbT+zevpT9AWwGV4QA\ngLGEEAAwlqUxmGKxJLbndd+2lJe/8aq71x2/efc1S9kfwGZwRQgAGEsIAQBjCSEAYCwhBACMJYQA\ngLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAY\nSwgBAGMJIQBgLCEEAIy1ddUTAJZn/+XbDm9/8cWdJLnxqruXsq+3nvvI4e3f/cLzl7IPgM12QleE\nqurWqnq8qh5YM3ZeVd1TVZ9ZfD93MV5V9TtVtaeqdlXV9y5r8gAAG3GiS2PvTXL1UWNvS/LR7r40\nyUcXt5PkmiSXLr5uSPKujU8TAGDzndDSWHd/rKouPmr42iQvX2y/L8n/TvKLi/H3d3cnubeqzqmq\nC7t732ZM
GDhxh5bDkuT8y/YvdV9rl8Nue+T7l7ovgM2ykTdLX7Ambv4hyQWL7eck+dyax+1djB2h\nqm6oqp1VtfPAU09vYBoAACdnUz41trj608d94JHPuaW7d3T3ji1nbzv+EwAANtlGPjX22KElr6q6\nMMnji/FHk1y05nHPXYwBK/TE7u1Jkpt3X7PimQCcPjZyRejOJK9fbL8+yR+vGf+pxafHXpLkH70/\nCAA4HZ3QFaGquj0H3xi9var2JnlHkl9L8qGqekOSR5K8dvHwu5K8KsmeJF9O8jObPGcAgE1xop8a\nu/4Yd71incd2kjdvZFIAAKeC/8UGADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGE\nEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgB\nAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAw\nlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJ\nIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWMcNoaq6taoer6oH1oz9RlV9uqp2VdUdVXXOYvziqvrnqrp/8fV7y5w8AMBGnMgVofcm\nufqosXuSfHd3X57k75K8fc19D3X3FYuvN23ONAEANt9xQ6i7P5bkyaPGPtLdzyxu3pvkuUuYGwDA\nUm3Ge4R+Nsnda25fUlV/U1V/UVUvPdaTquqGqtpZVTsPPPX0JkwDAODZ2bqRJ1fVLyd5Jslti6F9\nSZ7X3Z+vqu9L8kdV9V3d/aWjn9vdtyS5JUnOet5FvZF5AACcjJO+IlRVP53kR5P8RHd3knT3V7r7\n84vt+5I8lORFmzBPAIBNd1IhVFVXJ/mFJK/u7i+vGT+/qrYstl+Q5NIkn92MiQIAbLbjLo1V1e1J\nXp5ke1XtTfKOHPyU2FlJ7qmqJLl38QmxlyX5lar6lyRfTfKm7n5y3RcGAFix44ZQd1+/zvB7jvHY\nDyf58EYnBQBwKvjN0gDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAY\nSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGE\nEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgB\nAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAw\nlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYa+uqJwDA\nme+cB+vw9vZdTy99f/sv35Yk+eKLe+n74pvbca8IVdWtVfV4VT2wZuydVfVoVd2/+HrVmvveXlV7\nqmp3Vf3IsiYOALBRJ7I09t4kV68z/tvdfcXi664kqarvTHJdku9aPOe/V9WWzZosAMBmOu7SWHd/\nrKouPsHXuzbJB7r7K0n+vqr2JLkyyV+e9AwB+KZ3xHLYvbu+tv2Sy5eyv0NLYudftv/w2BO7ty9l\nX3xz28ibpd9SVbsWS2fnLsaek+Rzax6zdzH2darqhqraWVU7Dzy1/PViAICjnWwIvSvJC5NckWRf\nkpue7Qt09y3dvaO7d2w5e9tJTgMA4OSd1KfGuvuxQ9tV9e4kf7K4+WiSi9Y89LmLMQA4aM1y2J7X\nfdtSdnHjVXd/3djNu69Zyr745nZSV4Sq6sI1N1+T5NAnyu5Mcl1VnVVVly
S5NMlfb2yKAADLcdwr\nQlV1e5KXJ9leVXuTvCPJy6vqiiSd5OEkb0yS7v5kVX0oyaeSPJPkzd19YDlTBwDYmBP51Nj16wy/\n5xs8/leT/OpGJgUAcCr4X2wAAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBj\nCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAY21d9QQAOPPtv3zb4e0vvrgPb994\n1d1L2d9bz30kSfK7X3j+Ul6fM4crQgDAWEIIABjL0hgAS7d2Oez8y/YvfX+HlsRue+T7l74vvrm5\nIgQAjCWEAICxLI0BcEo9sXv74e2bd1+zwpmAK0IAwGBCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQ\nADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEA\nYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCW\nEAIAxhJCAMBYQggAGEsIAQBjHTeEqurWqnq8qh5YM/bBqrp/8fVwVd2/GL+4qv55zX2/t8zJAwBs\nxNYTeMx7k/y3JO8/NNDdrzu0XVU3JfnHNY9/qLuv2KwJAgAsy3FDqLs/VlUXr3dfVVWS1yb5D5s7\nLQCA5dvoe4RemuSx7v7MmrFLqupvquovquqlx3piVd1QVTuraueBp57e4DQAAJ69E1ka+0auT3L7\nmtv7kjyvuz9fVd+X5I+q6ru6+0tHP7G7b0lyS5Kc9byLeoPzAAB41k76ilBVbU3y40k+eGisu7/S\n3Z9fbN+X5KEkL9roJAEAlmEjS2M/nOTT3b330EBVnV9VWxbbL0hyaZLPbmyKAADLcSIfn789yV8m\nuayq9lbVGxZ3XZcjl8WS5GVJdi0+Tv8/krypu5/czAkDAGyWE/nU2PXHGP/pdcY+nOTDG58WAMDy\n+c3SAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjFXd\nveo5pKqeSPJIku1J9q94OqcTx+NIjseRHI+vcSyO5HgcyfE40sTj8fzuPn+9O06LEDqkqnZ2945V\nz+N04XgcyfE4kuPxNY7FkRyPIzkeR3I8jmRpDAAYSwgBAGOdbiF0y6oncJpxPI7keBzJ8fgax+JI\njseRHI8jOR5rnFbvEQIAOJVOtytCAACnzGkRQlV1dVXtrqo9VfW2Vc/nVKuqi6rqz6vqU1X1yar6\nucX4O6vq0aq6f/H1qlXP9VSpqoer6m8Xf+6di7HzquqeqvrM4vu5q57nqVBVl605B+6vqi9V1Y2T\nzo+qurWqHq+qB9aMrXs+1EG/s/j7ZFdVfe/qZr4cxzgev1FVn178me+oqnMW4xdX1T+vOU9+b3Uz\nX45jHI9j/nxU1dsX58fuqvqR1cx6eY5xPD645lg8XFX3L8bP+PPjeFa+NFZVW5L8XZJXJtmb5ONJ\nru/uT610YqdQVV2Y5MLu/kRVfXuS+5L8WJLXJnmqu39zpRNcgap6OMmO7t6/ZuzXkzzZ3b+2COZz\nu/sXVzXHVVj8vDya5AeS/EyGnB9V9b
IkTyV5f3d/92Js3fNh8R+8tyZ5VQ4ep//a3T+wqrkvwzGO\nx1VJ/ld3P1NV/yVJFsfj4iR/cuhxZ6JjHI93Zp2fj6r6ziS3J7kyyXck+Z9JXtTdB07ppJdoveNx\n1P03JfnH7v6VCefH8ZwOV4SuTLKnuz/b3f8vyQeSXLviOZ1S3b2vuz+x2P6nJA8mec5qZ3VaujbJ\n+xbb78vBWJzmFUke6u5HVj2RU6m7P5bkyaOGj3U+XJuD/wHo7r43yTmLf2ycMdY7Ht39ke5+ZnHz\n3iTPPeUTW5FjnB/Hcm2SD3T3V7r775PsycH/Dp0xvtHxqKrKwX9k335KJ3UaOx1C6DlJPrfm9t4M\njoBFnX9Pkr9aDL1lcan71ilLQQud5CNVdV9V3bAYu6C79y22/yHJBauZ2kpdlyP/Apt6fiTHPh/8\nnZL8bJK719y+pKr+pqr+oqpeuqpJrcB6Px/Tz4+XJnmsuz+zZmzq+ZHk9AghFqrq7CQfTnJjd38p\nybuSvDDJFUn2JblphdM71X6ou783yTVJ3ry41HtYH1zTHfWRx6r61iSvTvIHi6HJ58cRJp4Px1JV\nv5zkmSS3LYb2JXled39Pkv+U5Per6l+van6nkJ+P9V2fI/8xNfX8OOx0CKFHk1y05vZzF2OjVNW3\n5GAE3dbdf5gk3f1Ydx/o7q8meXfOsMu330h3P7r4/niSO3Lwz/7YoSWOxffHVzfDlbgmySe6+7Fk\n9vmxcKzzYezfKVX100l+NMlPLOIwiyWgzy+270vyUJIXrWySp8g3+PmYfH5sTfLjST54aGzq+bHW\n6RBCH09yaVVdsvgX73VJ7lzxnE6pxZrte5I82N2/tWZ87fsaXpPkgaOfeyaqqm2LN42nqrYluSoH\n/+x3Jnn94mGvT/LHq5nhyhzxL7mp58caxzof7kzyU4tPj70kB98Uum+9FziTVNXVSX4hyau7+8tr\nxs9fvMk+VfWCJJcm+exqZnnqfIOfjzuTXFdVZ1XVJTl4PP76VM9vRX44yae7e++hgannx1pbVz2B\nxScc3pLkz5JsSXJrd39yxdM61X4wyU8m+dtDH2lM8ktJrq+qK3Lwkv/DSd64mumdchckueNgH2Zr\nkt/v7j+tqo8n+VBVvSHJIzn4hr8RFkH4yhx5Dvz6lPOjqm5P8vIk26tqb5J3JPm1rH8+3JWDnxjb\nk+TLOfjpujPKMY7H25OcleSexc/Ovd39piQvS/IrVfUvSb6a5E3dfaJvLP6mcIzj8fL1fj66+5NV\n9aEkn8rBJcQ3n0mfGEvWPx7d/Z58/XsMkwHnx/Gs/OPzAACrcjosjQEArIQQAgDGEkIAwFhCCAAY\nSwgBAGMJIQBgLCEEAIwlhACAsf4/QXaJowhPRekAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
" 'SAME', # padding mode\n",
" (1,1), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
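{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"With stride 1, `'SAME'` pads each spatial dimension so that the output keeps the input's height and width. For an odd window such as 3x3, that amounts to `k // 2 = 1` pixel on each side. The minimal sketch below illustrates this with random stand-in arrays in the same NHWC/HWIO layout as above:\n",
"\n",
"```python\n",
"from jax import lax, random\n",
"\n",
"k1, k2 = random.split(random.PRNGKey(1))\n",
"x = random.normal(k1, (1, 10, 7, 3))  # NHWC\n",
"w = random.normal(k2, (3, 3, 3, 2))   # HWIO, 3x3 window\n",
"dn = lax.conv_dimension_numbers(x.shape, w.shape, ('NHWC', 'HWIO', 'NHWC'))\n",
"\n",
"same = lax.conv_general_dilated(x, w, (1, 1), 'SAME', (1, 1), (1, 1), dn)\n",
"# Explicit (low, high) padding of one pixel per spatial dimension.\n",
"explicit = lax.conv_general_dilated(x, w, (1, 1), ((1, 1), (1, 1)), (1, 1), (1, 1), dn)\n",
"\n",
"print(same.shape)  # (1, 10, 7, 2)\n",
"print(float(abs(same - explicit).max()))\n",
"```"
]
},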
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "E4i3TI5JFVG9"
},
"source": [
"#### VALID padding, no stride, no dilation"
]
},
{
"cell_type": "code",
"execution_count": 157,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 628
},
"colab_type": "code",
"id": "1HQwudKVFVG-",
"outputId": "a141ffd2-9c7c-4633-b752-7cd345632fdf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 198, 196, 3) DIFFERENT from above!\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAJBCAYAAACqM9quAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAaOklEQVR4nO3df6zld13n8dd7Z5TEqZu2O92mQqGF\nlAY13apjJVEIu0htiaHiH9DGKCrZQgLEZk0UNFmIiYmrVru6WUwJDZDUAorVZtMqXdaVbGKVKTZj\noYy02IZ2a9uhILYY1g7v/WPODGc6d+h07j1zhnk/HsnN/Z7P+fH9zDffO33O98dtdXcAACb6V+ue\nAADAugghAGAsIQQAjCWEAICxhBAAMJYQAgDGWlkIVdVlVbW3qu6tqrevaj0AAMerVvF7hKpqW5K/\nS/KqJA8m+USSq7r701u+MgCA47R9RZ97SZJ7u/tzSVJVH0xyRZINQ2jbaTt6+5lnrmgqAMBkTz3+\nePY/8WRt9NyqQui5ST6/9PjBJD9wtBdvP/PMfMfPX7OiqQAAk/3fa6876nNru1i6qq6uqt1VtXv/\nE0+uaxoAwGCrCqGHkpy79Ph5i7FDuvv67t7V3bu2nbZjRdMAADi6VYXQJ5JcUFXnV9W3JrkyyS0r\nWhcAwHFZyTVC3f1UVb01yZ8l2Zbkhu7+1CrWBQBwvFZ1sXS6+9Ykt67q8wEANstvlgYAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWE\nAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggA\nGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICx\nhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsI\nAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYx13CFXVuVX151X16ar6VFX9\n3GL8XVX1UFXdtfh69dZNFwBg62zfxHufSvLz3f3Jqvr2JHdW1e2L5367u39z89MDAFid4w6h7n44\nycOL5X+qqnuSPHerJgYAsGpbco1QVZ2X5HuS/NVi6K1VtaeqbqiqM47ynqurandV7d7/xJNbMQ0A\ngGdl0yFUVacl+UiSa7r7y0neneRFSS7OgSNG1270vu6+vrt3dfeubaft2Ow0AACetU2FUFV9Sw5E\n0I3d/UdJ0t2PdPf+7v5akvckuWTz0wQA2HqbuWuskrw3yT3d/VtL4+csvey1Se4+/ukBAKzOZu4a\n+8EkP5nkb6vqrsXYLyW5qqouTtJJ7k/ypk3NEABgRTZz19j/SVIbPHXr8U8HAODE8ZulAYCxNnNq\njCWn3/P1g2M796z+1wHsu+jrd9p96SW98vUBwKnIESEAYCwhBACM5dTYFjnsdNgde76+/NKLVrK+\n5dNhZ124L0ny2N6dK1kXAJyqHBECAMZyRGgVlo4C3fv6b1vJKq659LYjxq7be/lK1gUApypHhACA\nsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhL\nCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYa/u6J3Cq2HfRjkPLX3pJH1q+5tLb\nVrK+t53xwKHl3/3iC1ayDgA41TkiBACMJYQAgLGcGtsiy6fDzrpw38rXt3w67MYHvn/l6wOAU5Ej\nQgDAWEIIABjLqbEVeGzvzkPL1+29fI0zAQC+
EUeEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJ\nIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjLV9sx9QVfcn+ack+5M81d27qurMJB9Kcl6S+5O8rru/\nuNl1AQBspa06IvTvu/vi7t61ePz2JB/r7guSfGzxGADgpLKqU2NXJHn/Yvn9SX5sResBADhuWxFC\nneSjVXVnVV29GDu7ux9eLP9DkrO3YD0AAFtq09cIJfmh7n6oqv5tktur6jPLT3Z3V1U//U2LaLo6\nSbadccYWTAMA4NnZ9BGh7n5o8f3RJDcnuSTJI1V1TpIsvj+6wfuu7+5d3b1r22k7NjsNAIBnbVMh\nVFU7qurbDy4nuTTJ3UluSfKGxcvekORPNrMeAIBV2OypsbOT3FxVBz/r97v7T6vqE0k+XFVvTPJA\nktdtcj0AAFtuUyHU3Z9L8u82GP9Ckldu5rMBAFbNb5YGAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWE\nAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggA\nGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICx\nhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGNtP943VtWFST60NPTCJP85yelJ/mOSxxbjv9Tdtx73DAEA\nVuS4Q6i79ya5OEmqaluSh5LcnORnkvx2d//mlswQAGBFturU2CuT3NfdD2zR5wEArNxWhdCVSW5a\nevzWqtpTVTdU1RkbvaGqrq6q3VW1e/8TT27RNAAAjt2mQ6iqvjXJa5L8wWLo3UlelAOnzR5Ocu1G\n7+vu67t7V3fv2nbajs1OAwDgWduKI0KXJ/lkdz+SJN39SHfv7+6vJXlPkku2YB0AAFtuK0Loqiyd\nFquqc5aee22Su7dgHQAAW+647xpLkqrakeRVSd60NPzrVXVxkk5y/9OeAwA4aWwqhLr7yST/5mlj\nP7mpGQEAnCB+szQAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGNtX/cEgNU4/Z46tLxzz5MrX9++i3YcWv7SS3rl\n6wPYCo4IAQBjCSEAYCynxuAUddjpsDv2fH35pRetZH3Lp8POunBfkuSxvTtXsi6AreKIEAAwliNC\nMMHSUaB7X/9tK1nFNZfedsTYdXsvX8m6ALaKI0IAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQ\nADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEA\nYwkhAGAsIQQAjLV93RMAVmPfRTsOLX/pJX1o+ZpLb1vJ+t52xgOHln/3iy9YyToAtpojQgDAWEII\nABjLqTE4RS2fDjvrwn0rX9/y6bAbH/j+la8PYCsc0xGhqrqhqh6tqruXxs6sqtur6rOL72csxquq\nfqeq7q2qPVX1vauaPADAZhzrqbH3JbnsaWNvT/Kx7r4gyccWj5Pk8iQXLL6uTvLuzU8TAGDrHdOp\nse7+eFWd
97ThK5K8YrH8/iT/O8kvLsY/0N2d5I6qOr2qzunuh7diwsCz99jenYeWr9t7+RpnAnBy\n2czF0mcvxc0/JDl7sfzcJJ9fet2Di7HDVNXVVbW7qnbvf+LJTUwDAOD4bMldY4ujP/2MLzz8Pdd3\n967u3rXttB3P/AYAgC22mRB6pKrOSZLF90cX4w8lOXfpdc9bjAEAnFQ2E0K3JHnDYvkNSf5kafyn\nFnePvTTJP7o+CAA4GR3TxdJVdVMOXBi9s6oeTPLOJL+W5MNV9cYkDyR53eLltyZ5dZJ7k3wlyc9s\n8ZwBALbEsd41dtVRnnrlBq/tJG/ZzKQAAE4E/4sNAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhC\nCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQA\ngLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAY\nSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGE\nEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgB\nAGMJIQBgLCEEAIwlhACAsZ4xhKrqhqp6tKruXhr7jar6TFXtqaqbq+r0xfh5VfXPVXXX4uv3Vjl5\nAIDNOJYjQu9LctnTxm5P8t3dfVGSv0vyjqXn7uvuixdfb96aaQIAbL1nDKHu/niSx5829tHufmrx\n8I4kz1vB3AAAVmorrhH62SS3LT0+v6r+pqr+oqpetgWfDwCwEts38+aq+uUkTyW5cTH0cJLnd/cX\nqur7kvxxVX1Xd395g/deneTqJNl2xhmbmQYAwHE57iNCVfXTSX40yU90dydJd3+1u7+wWL4zyX1J\nXrzR+7v7+u7e1d27tp2243inAQBw3I4rhKrqsiS/kOQ13f2VpfGzqmrbYvmFSS5I8rmtmCgAwFZ7\nxlNjVXVTklck2VlVDyZ5Zw7cJfacJLdXVZLcsbhD7OVJfqWq/iXJ15K8ubsf3/CDAQDW7BlDqLuv\n2mD4vUd57UeSfGSzkwIAOBH8ZmkAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGCs7eue\nAAAznH5PHVreuefJla9v30U7Di1/6SW98vXxzckRIQBgLEeEADghDjsKdMeery+/9KKVrG/5KNBZ\nF+5Lkjy2d+dK1sU3r2c8IlRVN1TVo1V199LYu6rqoaq6a/H16qXn3lFV91bV3qr6kVVNHABgs47l\n1Nj7kly2wfhvd/fFi69bk6SqvjPJlUm+a/Ge/15V27ZqsgAAW+kZT41198er6rxj/Lwrknywu7+a\n5O+r6t4klyT5y+OeIQCnnqXTYfe+/ttWsoprLr3tiLHr9l6+knXxzWszF0u/tar2LE6dnbEYe26S\nzy+95sHF2BGq6uqq2l1Vu/c/sfq7BwAAnu54Q+jdSV6U5OIkDye59tl+QHdf3927unvXttN2PPMb\nAAC22HGFUHc/0t37u/trSd6TA6e/kuShJOcuvfR5izEAgJPOcYVQVZ2z9P
C1SQ7eUXZLkiur6jlV\ndX6SC5L89eamCACwGs94sXRV3ZTkFUl2VtWDSd6Z5BVVdXGSTnJ/kjclSXd/qqo+nOTTSZ5K8pbu\n3r+aqQMAbM6x3DV21QbD7/0Gr//VJL+6mUkBAJwI/hcbAMBYQggAGEsIAQBjCSEAYCwhBACMJYQA\ngLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMNb2dU8AgBn2XbTj0PKX\nXtKHlq+59LaVrO9tZzxwaPl3v/iClayDb36OCAEAYwkhAGAsp8YAOCGWT4eddeG+la9v+XTYjQ98\n/8rXxzcnR4QAgLGEEAAwllNjAJxwj+3deWj5ur2Xr3EmTOeIEAAwlhACAMYSQgDAWEIIABhLCAEA\nYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCW\nEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgrGcMoaq6oaoeraq7l8Y+VFV3\nLb7ur6q7FuPnVdU/Lz33e6ucPADAZmw/hte8L8l/S/KBgwPd/fqDy1V1bZJ/XHr9fd198VZNEABg\nVZ4xhLr741V13kbPVVUleV2S/7C10wIAWL3NXiP0siSPdPdnl8bOr6q/qaq/qKqXHe2NVXV1Ve2u\nqt37n3hyk9MAAHj2juXU2DdyVZKblh4/nOT53f2Fqvq+JH9cVd/V3V9++hu7+/ok1yfJc55/bm9y\nHgAAz9pxHxGqqu1JfjzJhw6OdfdXu/sLi+U7k9yX5MWbnSQAwCps5tTYDyf5THc/eHCgqs6qqm2L\n5RcmuSDJ5zY3RQCA1TiW2+dvSvKXSS6sqger6o2Lp67M4afFkuTlSfYsbqf/wyRv7u7Ht3LCAABb\n5VjuGrvqKOM/vcHYR5J8ZPPTAgBYPb9ZGgAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWE\nAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY1d3r\nnkOq6rEkDyTZmWTfmqdzMrE9jmSbHMk2OZztcSTb5Ei2yeFO9e3xgu4+a6MnTooQOqiqdnf3rnXP\n42RhexzJNjmSbXI42+NItsmRbJPDTd4eTo0BAGMJIQBgrJMthK5f9wROMrbHkWyTI9kmh7M9jmSb\nHMk2OdzY7XFSXSMEAHAinWxHhAAATpiTIoSq6rKq2ltV91bV29c9n3WoqnOr6s+r6tNV9amq+rnF\n+Luq6qGqumvx9ep1z/VEqar7q+pvF3/u3YuxM6vq9qr67OL7Geue54lSVRcu7Qd3VdWXq+qaaftI\nVd1QVY9W1d1LYxvuF3XA7yz+btlTVd+7vpmvxlG2x29U1WcWf+abq+r0xfh5VfXPS/vK761v5qtz\nlG1y1J+TqnrHYh/ZW1U/sp5Zr9ZRtsmHlrbH/VV112J8xH5y0NpPjVXVtiR/l+RVSR5M8okkV3X3\np9c6sROsqs5Jck53f7Kqvj3JnUl+LMnrkjzR3b+51gmuQVXdn2RXd+9bGvv1JI93968tovmM7v7F\ndc1xXRY/Nw8l+YEkP5NB+0hVvTzJE0
k+0N3fvRjbcL9Y/MfubUlenQPb6r929w+sa+6rcJTtcWmS\n/9XdT1XVf0mSxfY4L8n/OPi6U9VRtsm7ssHPSVV9Z5KbklyS5DuS/M8kL+7u/Sd00iu20TZ52vPX\nJvnH7v6VKfvJQSfDEaFLktzb3Z/r7v+X5INJrljznE647n64uz+5WP6nJPckee56Z3VSuiLJ+xfL\n78+BWJzolUnu6+4H1j2RE627P57k8acNH22/uCIH/uLv7r4jyemLf3ScMjbaHt390e5+avHwjiTP\nO+ETW6Oj7CNHc0WSD3b3V7v775PcmwP/XTqlfKNtUlWVA//ovumETuokcTKE0HOTfH7p8YMZHgCL\nGv+eJH+1GHrr4hD3DZNOBSXpJB+tqjur6urF2Nnd/fBi+R+SnL2eqa3dlTn8L62p+8hBR9sv/P2S\n/GyS25Yen19Vf1NVf1FVL1vXpNZko58T+0jysiSPdPdnl8bG7CcnQwixpKpOS/KRJNd095eTvDvJ\ni5JcnOThJNeucXon2g919/cmuTzJWxaHdg/pA+d1x932WFXfmuQ1Sf5gMTR5HznC1P1iI1X1y0me\nSnLjYujhJM/v7u9J8p+S/H5V/et1ze8E83NydFfl8H9YjdpPToYQeijJuUuPn7cYG6eqviUHIujG\n7v6jJOnuR7p7f3d/Lcl7cgoesj2a7n5o8f3RJDfnwJ/9kYOnNhbfH13fDNfm8iSf7O5Hktn7yJKj\n7Rdj/36pqp9O8qNJfmIRh1mc/vnCYvnOJPclefHaJnkCfYOfk7H7SJJU1fYkP57kQwfHpu0nJ0MI\nfSLJBVV1/uJfulcmuWXNczrhFudo35vknu7+raXx5esZXpvk7qe/91RUVTsWF42nqnYkuTQH/uy3\nJHnD4mVvSPIn65nhWh32r7ep+8jTHG2/uCXJTy3uHntpDlwM+vBGH3AqqarLkvxCktd091eWxs9a\nXGifqnphkguSfG49szyxvsHPyS1Jrqyq51TV+TmwTf76RM9vjX44yWe6+8GDA9P2k+3rnsDiroa3\nJvmzJNuS3NDdn1rztNbhB5P8ZJK/PXgLY5JfSnJVVV2cA4f670/ypvVM74Q7O8nNB/ow25P8fnf/\naVV9IsmHq+qNSR7IgQv8xlhE4aty+H7w65P2kaq6KckrkuysqgeTvDPJr2Xj/eLWHLhj7N4kX8mB\nO+xOKUfZHu9I8pwkty9+hu7o7jcneXmSX6mqf0nytSRv7u5jvaj4m8ZRtskrNvo56e5PVdWHk3w6\nB04jvuVUu2Ms2XibdPd7c+T1hsmQ/eSgtd8+DwCwLifDqTEAgLUQQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMNb/BwN3joMOviDcAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img,     # lhs = image tensor\n",
"                               kernel,  # rhs = conv kernel tensor\n",
"                               (1,1),   # window strides\n",
"                               'VALID', # padding mode\n",
"                               (1,1),   # lhs/image dilation\n",
"                               (1,1),   # rhs/kernel dilation\n",
"                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \"DIFFERENT from above!\")\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
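{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"With `'VALID'` padding and unit stride, the kernel is only placed where it fits entirely inside the image. Each spatial dimension of size $n$ therefore shrinks to $n - k + 1$ for a kernel of size $k$. As a worked example, assume a $200 \\times 198$ input image and a $3 \\times 3$ kernel. These sizes are an assumption here, but they reproduce the shapes printed by the following cells. The output size is then:\n",
"\n",
"$$H_{out} = 200 - 3 + 1 = 198, \\qquad W_{out} = 198 - 3 + 1 = 196.$$"
]
},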
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VYKZdqLIFVHB"
},
"source": [
"#### SAME padding, (2,2) stride, no dilation"
]
},
{
"cell_type": "code",
"execution_count": 158,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 627
},
"colab_type": "code",
"id": "mKq2-zmmFVHC",
"outputId": "065e2f69-3f1d-4d19-864d-28ef59f1b0f8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 100, 99, 3) <-- half the size of above\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAJACAYAAACJ77wgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAUcUlEQVR4nO3dXYymZ13H8d/fDi3sqrRF2dS2yhqI\nhpgodkt4scRsNUEgdg8I0qgppqYnKigaWz0hJphoQng5MJiGoj0gBVJJ2xijIbQk9aTptiUptCpN\nkXY3fTNYMK0BJv49mEdZ2Ck7OzPPvPzz+ZzM3NfzzFxXcufefPd+rueZ6u4AAEz2A7u9AACAZRM8\nAMB4ggcAGE/wAADjCR4AYDzBAwCMt6Xgqao3V9W/VtUjVXXDdi0KAGA71WY/h6eqzknyb0l+OcmJ\nJPcmubq7H3qhn1k5cLBf9NILNzUfAMD38+2vfy2rzz9X6z22soXf+9okj3T3o0lSVZ9MclWSFwye\nF730whx+13u3MCUAwPq+8rcffMHHtvKS1sVJHj/l+MRiDABgT1n6puWquq6qjlfV8dXnn1v2dAAA\np9lK8JxMcukpx5csxr5Ld9/Y3Ue6+8jKgYNbmA4AYHO2Ejz3JnlVVR2uqnOTvDPJHduzLACA7bPp\nTcvdvVpVv5vkn5Kck+Tj3f2lbVsZAMA22cq7tNLd/5DkH7ZpLQAAS+GTlgGA8QQPADCe4AEAxhM8\nAMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMA\njCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4\nggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4\nAIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcA\nGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4AIDx\nBA8AMJ7gAQDGEzwAwHgru72A/ebQvd88bWzlzvuWMtfq0cvWHX/q8vOWMh8ATOUODwAwnuABAMYT\nPADAeIIHABhP8AAA43mX1lla7x1ZJ69/w1LmOnb13euO33bLFUuZDwCmcocHABhP8AAA4wkeAGA8\nwQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8\nAMB4ggcAGE/wAADjCR4AYDzBAwCMt7LbC9hvVo9edtrYsavvXspc73/5g+uO35YrljIfAEzlDg8A\nMJ7gAQDGEzwAwHiCBwAY74zBU1WXVtVdVfVQVX2pqt6zGL+wqj5bVV9efL1g+csFADh71d3f/wlV\nFyW5qLvvr6ofSnJfkmNJ3pXka939F1V1Q5ILuvv67/e7XnLRpX34Xe/dnpUDAJziK3/7wfz3E4/X\neo+d8Q5Pdz/R3fcvvv+vJA8nuTjJVUluXjzt5qxFEADAnnNWe3iq6hVJXpPkniSHuvuJxUNPJjm0\nrSsDANgmGw6eqvrBJH+X5Pe7+xunPtZrr4ut+9pYVV1XVcer6vjq889tabEAAJuxoeCpqhdlLXY+\n0d2fWQw/tdjf83/7fJ5e72e7+8buPtLdR1YOHNyONQMAnJWNvEurktyU5OHu/uApD92R5JrF99ck\nuX37lwcAsHUb+Vtab0zym0kerKovLMb+NMlfJPl0VV2b5KtJ3rGcJQIAbM0Zg6e7/znJum/xSnLl\n9i4HAGD7+aRlAGA8wQMAjCd4AIDxBA8AMJ7g
AQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA\n8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP\n8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQP\nADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA\n4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe\n4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wke\nAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEA\nxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8\nwQMAjCd4AIDxBA8AMN6Gg6eqzqmqB6rq7xfHh6vqnqp6pKo+VVXnLm+ZAACbdzZ3eN6T5OFTjv8y\nyYe6+5VJ/jPJtdu5MACA7bKh4KmqS5K8NcnHFseV5GiSWxdPuTnJsWUsEABgqzZ6h+fDSf44yf8s\njl+W5NnuXl0cn0hy8Xo/WFXXVdXxqjq++vxzW1osAMBmnDF4quptSZ7u7vs2M0F339jdR7r7yMqB\ng5v5FQAAW7Kygee8McmvVtVbkrw4yQ8n+UiS86tqZXGX55IkJ5e3TACAzTvjHZ7u/pPuvqS7X5Hk\nnUnu7O5fT3JXkrcvnnZNktuXtkoAgC3YyufwXJ/kvVX1SNb29Ny0PUsCANheG3lJ6/919+eTfH7x\n/aNJXrv9SwIA2F4+aRkAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA\n4wkeAGA8wQMAjCd4AIDxzuqvpQN7x6F7v7nu+Mqd9y1lvtWjl5029tTl5y1lLoDt5g4PADCe4AEA\nxhM8AMB4ggcAGE/wAADjeZcW7FMv9G6sk9e/YSnzHbv67tPGbrvliqXMBbDd3OEBAMYTPADAeIIH\nABhP8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA\n8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYLyV3V4AsDmrRy9bd/zY1XcvZb73v/zB08ZuyxVL\nmQtgu7nDAwCMJ3gAgPEEDwAwnuABAMazaRn2qacuP2/d8dtuWc5GYhuUgf3MHR4AYDzBAwCMJ3gA\ngPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAY\nT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEE\nDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AA\nAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAw\n3oaCp6rOr6pbq+pfqurhqnp9VV1YVZ+tqi8vvl6w7MUCAGzGRu/wfCTJP3b3Tyf52SQPJ7khyee6\n+1VJPrc4BgDYc84YPFX10iRvSnJTknT3t7r72SRXJbl58bSbkxxb1iIBALZiI3d4Did5JsnfVNUD\nVfWxqjqY5FB3P7F4zpNJDi1rkQAAW7GR4FlJ8vNJPtrdr0nyXL7n5avu7iS93g9X1XVVdbyqjq8+\n/9xW1wsA
cNY2Ejwnkpzo7nsWx7dmLYCeqqqLkmTx9en1fri7b+zuI919ZOXAwe1YMwDAWTlj8HT3\nk0ker6qfWgxdmeShJHckuWYxdk2S25eyQgCALVrZ4PN+L8knqurcJI8m+a2sxdKnq+raJF9N8o7l\nLBEAYGs2FDzd/YUkR9Z56MrtXQ4AwPbzScsAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe\n4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wke\nAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEA\nxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8\nwQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8\nAMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMA\njLey2wsAYP87dO83TxtbufO+pcy1evSy08aeuvy8pczFHO7wAADjCR4AYDzBAwCMJ3gAgPFsWgZg\ny9bboHzy+jcsZa5jV9992thtt1yxlLmYwx0eAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAY\nT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gAgPEE\nDwAwnuABAMZb2e0FALD/rR697LSxY1ffvZS53v/yB08buy1XLGUu5nCHBwAYT/AAAOMJHgBgvA0F\nT1X9QVV9qaq+WFW3VNWLq+pwVd1TVY9U1aeq6txlLxYAYDPOuGm5qi5O8u4kr+7u/66qTyd5Z5K3\nJPlQd3+yqv46ybVJPrrU1QKwJz11+Xmnjd12y3I2EtugzGZs9CWtlSQvqaqVJAeSPJHkaJJbF4/f\nnOTY9i8PAGDrzhg83X0yyQeSPJa10Pl6kvuSPNvdq4unnUhy8bIWCQCwFWcMnqq6IMlVSQ4n+bEk\nB5O8eaMTVNV1VXW8qo6vPv/cphcKALBZG3lJ65eSfKW7n+nubyf5TJI3Jjl/8RJXklyS5OR6P9zd\nN3b3ke4+snLg4LYsGgDgbGwkeB5L8rqqOlBVleTKJA8luSvJ2xfPuSbJ7ctZIgDA1mxkD889Wduc\nfH+SBxc/c2OS65O8t6oeSfKyJDctcZ0AAJu2ob+l1d3vS/K+7xl+NMlrt31FAADbzCctAwDjCR4A\nYDzBAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDG\nEzwAwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzB\nAwCMJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwA\nwHiCBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCM\nJ3gAgPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiC\nBwAYT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR4AYDzBAwCMJ3gA\ngPEEDwAwnuABAMYTPADAeIIHABhP8AAA4wkeAGA8wQMAjCd4AIDxBA8AMJ7gAQDGEzwAwHiCBwAY\nT/AAAOMJHgBgPMEDAIwneACA8QQPADCe4AEAxhM8AMB4ggcAGE/wAADjCR
4AYDzBAwCMJ3gAgPEE\nDwAwXnX3zk1W9UySry4OfyTJf+zY5GyV87X/OGf7j3O2vzhfe89PdPePrvfAjgbPd01cdby7j+zK\n5Jw152v/cc72H+dsf3G+9hcvaQEA4wkeAGC83QyeG3dxbs6e87X/OGf7j3O2vzhf+8iu7eEBANgp\nXtICAMbb8eCpqjdX1b9W1SNVdcNOz8+ZVdWlVXVXVT1UVV+qqvcsxi+sqs9W1ZcXXy/Y7bXyHVV1\nTlU9UFV/vzg+XFX3LK61T1XVubu9Rr6jqs6vqlur6l+q6uGqer1rbG+rqj9Y/Jv4xaq6pape7Drb\nP3Y0eKrqnCR/leRXkrw6ydVV9eqdXAMbsprkD7v71Ulel+R3FufphiSf6+5XJfnc4pi94z1JHj7l\n+C+TfKi7X5nkP5Ncuyur4oV8JMk/dvdPJ/nZrJ0719geVVUXJ3l3kiPd/TNJzknyzrjO9o2dvsPz\n2iSPdPej3f2tJJ9MctUOr4Ez6O4nuvv+xff/lbV/iC/O2rm6efG0m5Mc250V8r2q6pIkb03yscVx\nJTma5NbFU5yvPaSqXprkTUluSpLu/lZ3PxvX2F63kuQlVbWS5ECSJ+I62zd2OnguTvL4KccnFmPs\nUVX1iiSvSXJPkkPd/cTioSeTHNqlZXG6Dyf54yT/szh+WZJnu3t1cexa21sOJ3kmyd8sXob8WFUd\njGtsz+ruk0k+kOSxrIXO15PcF9fZvmHTMi+oqn4wyd8l+f3u/sapj/Xa2/u8xW8PqKq3JXm6u+/b\n7bWwYStJfj7JR7v7NUmey/e8fOUa21sW+6muylqs/liSg0nevKuL4qzsdPCcTHLpKceXLMbYY6rq\nRVmLnU9092cWw09V1UWLxy9K8vRurY/v8sYkv1pV/561l4mPZm1/yPmLW++Ja22vOZHkRHffszi+\nNWsB5Brbu34pyVe6+5nu/naSz2Tt2nOd7RM7HTz3JnnVYlf7uVnb8HXHDq+BM1js/7gpycPd/cFT\nHrojyTWL769JcvtOr43TdfefdPcl3f2KrF1Td3b3rye5K8nbF09zvvaQ7n4yyeNV9VOLoSuTPBTX\n2F72WJLXVdWBxb+R/3fOXGf7xI5/8GBVvSVr+w3OSfLx7v7zHV0AZ1RVv5Dk7iQP5jt7Qv40a/t4\nPp3kx7P2V+/f0d1f25VFsq6q+sUkf9Tdb6uqn8zaHZ8LkzyQ5De6+5u7uT6+o6p+LmubzM9N8miS\n38raf0JdY3tUVf1Zkl/L2jtZH0jy21nbs+M62wd80jIAMJ5NywDAeIIHABhP8AAA4wkeAGA8wQMA\njCd4AIDxBA8AMJ7gAQDG+1+hhyBcEjzN/wAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img,     # lhs = image tensor\n",
"                               kernel,  # rhs = conv kernel tensor\n",
"                               (2,2),   # window strides\n",
"                               'SAME',  # padding mode\n",
"                               (1,1),   # lhs/image dilation\n",
"                               (1,1),   # rhs/kernel dilation\n",
"                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \" <-- half the size of above\")\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
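{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"With `'SAME'` padding, the input is zero-padded so that the output size depends only on the stride $s$. Each spatial dimension of size $n$ becomes $\\lceil n / s \\rceil$, independent of the kernel size. Assuming the $200 \\times 198$ input image used throughout this section, a $(2,2)$ stride gives\n",
"\n",
"$$H_{out} = \\left\\lceil \\frac{200}{2} \\right\\rceil = 100, \\qquad W_{out} = \\left\\lceil \\frac{198}{2} \\right\\rceil = 99,$$\n",
"\n",
"which matches the printed shape `(1, 100, 99, 3)`."
]
},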
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "gPxttaiaFVHE"
},
"source": [
"#### VALID padding, no stride, rhs (kernel) dilation ~ atrous convolution (dilation exaggerated for illustration)"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 628
},
"colab_type": "code",
"id": "_pGr0x6qFVHF",
"outputId": "5387205f-3c23-4203-ff1b-ae5115eed5f7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 176, 174, 3)\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkEAAAJBCAYAAABBBGGtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAbnUlEQVR4nO3df7Dld13f8de7uyZKMjWJ0RSyaROT\noAOMFGal61AdflgJyLD84UiorVEzk4mlChGLIDOl/YMZ/DFGnVaYVCKhZSAUUTIOWjFimc40gQUF\n8gNkA0I2EwgOghKngcR3/7gn7Z3Nrpu959ycu/f9eMzs7D2f8z33vL98d2+efM/5nq3uDgDANP9g\n3QMAAKyDCAIARhJBAMBIIggAGEkEAQAjiSAAYKRti6CquqyqPllVh6vq1dv1PAAAW1Hb8TlBVbUn\nyZ8n+RdJjiT5UJKXdvcdK38yAIAt2LtN3/cZSQ5396eTpKrekeRgkmNG0J4zz+i955yzTaMAAJN9\n7e4jf9nd33r0+nZF0PlJ7t50+0iSf3a8jfeec06e8MpXbNMoAMBkf/GKn/3ssdbX9sboqrqqqg5V\n1aGHvnr/usYAAIbargi6J8kFm27vW6z9P919XXfv7+79e848Y5vGAAA4tu2KoA8lubSqLqqq05Jc\nnuSmbXouAICTti3vCeruB6vq3yb5H0n2JLm+u2/fjucCANiK7XpjdLr7vUneu13fHwBgGT4xGgAY\nSQQBACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMtG3/gOrJ\nOP3u+3PJNbes5HsdvvbAMddX9f234ngzAQDr40wQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQR\nBACMJIIAgJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEA\nwEgiCAAYSQQBACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAj7V33AEnywAVn\n5PArD2zrcxy+dnu/PwBwanEmCAAYSQQBACOJIABgJBEEAIwkggCAkUQQADDSjrhE/vS7788l19yy\nku91Kl0Kf9dL3rS25774xqvX9twAsBM4EwQAjCSCAICRRBAAMNKWI6iqLqiq91fVHVV1e1W9fLF+\nTlW9r6o+tfj97NWNCwCwGsucCXowySu7+0lJDiR5WVU9Kcmrk9zc3ZcmuXlxGwBgR9lyBHX3vd39\nkcXXf5PkziTnJzmY5IbFZjckefGyQwIArNpK3hNUVRcmeVqSW5Oc1933Lu76fJLzVvEcAACrtHQE\nVdWZSX47ySu6+68339fdnaSP87irqupQVR36eh5YdgwAgJOyVARV1TdkI4De1t3vXix/oaoev7j/\n8UnuO9Zju/u67t7f3fu/IacvMwYAwElb5uqwSvLmJHd2969suuumJFcsvr4iyXu2Ph4AwPZY5p/N\neGaSf53k41X1Z4u1n0/yhiTvrKork3w2yQ8vNyIAwOptOYK6+38lqePc/dytfl8AgMeCT4wGAEYS\nQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQBACOJIABgJBEE\nAIwkggCAkUQQADCSCAIARhJBAMBIe9c9QJI8cMEZOfzKA+se4zF38Y1Xr3sEABjLmSAAYCQRBACM\nJIIAgJFEEAAwkggCAEbaEVeHPRbuesmb1vbcx7sKbCfOBABTOBMEAIwkggCAkUQQADCSCAIARhJB\nAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQA\njCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQB
ACOJIABgJBEEAIwkggCAkUQQADCSCAIARlo6gqpq\nT1X9aVX93uL2RVV1a1Udrqobq+q05ccEAFitVZwJenmSOzfd/oUk13b3JUn+KsmVK3gOAICV2rvM\ng6tqX5IfTPL6JD9TVZXkOUn+5WKTG5L8hyRvXOZ5VuHiG69e9wiPsBNnAoAplj0T9KtJXpXk7xa3\nvyXJl7v7wcXtI0nOP9YDq+qqqjpUVYce+ur9S44BAHBythxBVfXCJPd194e38vjuvq6793f3/j1n\nnrHVMQAAtmSZl8OemeRFVfWCJN+Y5B8m+bUkZ1XV3sXZoH1J7ll+TACA1drymaDufk137+vuC5Nc\nnuSPu/tHkrw/yQ8tNrsiyXuWnhIAYMW243OCfi4bb5I+nI33CL15G54DAGApS10d9rDu/pMkf7L4\n+tNJnrGK7wsAsF18YjQAMJIIAgBGEkEAwEgiCAAYSQQBACOJIABgJBEEAIwkggCAkUQQADCSCAIA\nRhJBAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAk\nEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQBACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJB\nAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQA\njCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQBACMtFUFVdVZVvauqPlFVd1bV91TVOVX1vqr61OL3\ns1c1LADAqix7JujXkvxBd39nkqcmuTPJq5Pc3N2XJrl5cRsAYEfZcgRV1Tcn+b4kb06S7v5ad385\nycEkNyw2uyHJi5cdEgBg1ZY5E3RRki8m+a2q+tOq+s2qOiPJed1972Kbzyc571gPrqqrqupQVR16\n6Kv3LzEGAMDJWyaC9iZ5epI3dvfTktyfo1766u5O0sd6cHdf1937u3v/njPPWGIMAICTt0wEHUly\npLtvXdx+Vzai6AtV9fgkWfx+33IjAgCs3pYjqLs/n+TuqvqOxdJzk9yR5KYkVyzWrkjynqUmBADY\nBnuXfPxPJXlbVZ2W5NNJfjwbYfXOqroyyWeT/PCSzwEAsHJLRVB3/1mS/ce467nLfF8AgO3mE6MB\ngJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGWvYfUIWRLrnm\nlrU99+FrDxxzfSfOBLCTORMEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACM\nJIIAgJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgi\nCAAYSQQBACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGGnvugeAU9Hhaw+se4RH2Ikz\nAexkzgQBACOJIABgJBEEAIwkggCAkUQQADCSq8NgCy655pa1PfepdBXYXS9509qe++Ibr17bcwOn\nBmeCAICRRBAAMJIIAgBGWiqCquqaqrq9qm6rqrdX1TdW1UVVdWtVHa6qG6vqtFUNCwCwKluOoKo6\nP8lPJ9nf3U9JsifJ5Ul+Icm13X1Jkr9KcuUqBgUAWKVlXw7bm+SbqmpvkscluTfJc5K8a3H/DUle\nvORzAACs3JYjqLvvSfLLST6Xjfj5SpIPJ/lydz+42OxIkvOXHRIAYNWWeTns7CQHk1yU5AlJzkhy\n2Uk8/qqqOlRVhx766v1bHQMAYEuWeTns+5N8pru/2N1fT/LuJM9Mctbi5bEk2ZfknmM9uLuv6+79\n3b1/z5lnLDEGAMDJWyaCPpfkQFU9rqoqyXOT3JHk/Ul+aLHNFUnes9yIAACrt8x7gm7NxhugP5Lk\n44vvdV2S
n0vyM1V1OMm3JHnzCuYEAFippf7tsO5+XZLXHbX86STPWOb7AgBsN58YDQCMJIIAgJFE\nEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQB\nACOJIABgJBEEAIwkggCAkfauewA4FR2+9sC6RzglXHzj1eseAeC4nAkCAEYSQQDASCIIABhJBAEA\nI4kgAGAkV4fBLnHXS960tuc+3lVgO3EmgIc5EwQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQB\nACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFEEAAw\nkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGOmEEVdX1VXVfVd22ae2c\nqnpfVX1q8fvZi/Wqql+vqsNV9bGqevp2Dg8AsFWP5kzQW5JcdtTaq5Pc3N2XJrl5cTtJnp/k0sWv\nq5K8cTVjAgCs1t4TbdDdH6iqC49aPpjkWYuvb0jyJ0l+brH+1u7uJLdU1VlV9fjuvndVAwPHdvGN\nV697hEfYiTMBPGyr7wk6b1PYfD7JeYuvz09y96btjizWAAB2lKXfGL0469Mn+7iquqqqDlXVoYe+\nev+yYwAAnJStRtAXqurxSbL4/b7F+j1JLti03b7F2iN093Xdvb+79+8584wtjgEAsDVbjaCbklyx\n+PqKJO/ZtP6ji6vEDiT5ivcDAQA70QnfGF1Vb8/Gm6DPraojSV6X5A1J3llVVyb5bJIfXmz+3iQv\nSHI4yd8m+fFtmBkAYGmP5uqwlx7nruceY9tO8rJlhwIA2G4+MRoAGEkEAQAjiSAAYCQRBACMJIIA\ngJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAY\nSQQBACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFE\nEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQB\nACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFOGEFV\ndX1V3VdVt21a+6Wq+kRVfayqfqeqztp032uq6nBVfbKqnrddgwMALOPRnAl6S5LLjlp7X5KndPd3\nJfnzJK9Jkqp6UpLLkzx58ZjfqKo9K5sWAGBFThhB3f2BJF86au0Pu/vBxc1bkuxbfH0wyTu6+4Hu\n/kySw0mescJ5AQBWYhXvCfqJJL+/+Pr8JHdvuu/IYu0RquqqqjpUVYce+ur9KxgDAODRWyqCquq1\nSR5M8raTfWx3X9fd+7t7/54zz1hmDACAk7Z3qw+sqh9L8sIkz+3uXizfk+SCTZvtW6wBAOwoWzoT\nVFWXJXlVkhd1999uuuumJJdX1elVdVGSS5N8cPkxAQBW64Rngqrq7UmeleTcqjqS5HXZuBrs9CTv\nq6okuaW7r+7u26vqnUnuyMbLZC/r7oe2a3gAgK06YQR190uPsfzmv2f71yd5/TJDAQBsN58YDQCM\nJIIAgJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMNIJ/+0wAGa65Jpb\n1vbch689cMz1nTgTpy5nggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFE\nEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQB\nACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAj7V33AA
DsTIevPbDuER5hJ87E\nqcuZIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBILpEH4JguueaWtT33qXQp/F0vedPanvviG69e\n23PvBs4EAQAjiSAAYCQRBACMdMIIqqrrq+q+qrrtGPe9sqq6qs5d3K6q+vWqOlxVH6uqp2/H0AAA\ny3o0Z4LekuSyoxer6oIkP5Dkc5uWn5/k0sWvq5K8cfkRAQBW74QR1N0fSPKlY9x1bZJXJelNaweT\nvLU33JLkrKp6/EomBQBYoS29J6iqDia5p7s/etRd5ye5e9PtI4u1Y32Pq6rqUFUdeuir929lDACA\nLTvpzwmqqscl+flsvBS2Zd19XZLrkuT0f3xBn2BzAICV2sqHJV6c5KIkH62qJNmX5CNV9Ywk9yS5\nYNO2+xZrAAA7ykm/HNbdH+/ub+vuC7v7wmy85PX07v58kpuS/OjiKrEDSb7S3feudmQAgOU9mkvk\n357kfyf5jqo6UlVX/j2bvzfJp5McTvJfkvyblUwJALBiJ3w5rLtfeoL7L9z0dSd52fJjAQBsL58Y\nDQCMJIIAgJFEEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEA\nwEgiCAAYSQQBACOJIABgJBEEAIy0d90DALAzHb72wLpHOCVcfOPV6x6BLXImCAAYSQQBACOJIABg\nJBEEAIwkggCAkUQQADCSS+QBOGXc9ZI3re25j3cp/E6ciUfHmSAAYCQRBACMJIIAgJFEEAAwkggC\nAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQBACOJIABg\nJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFOGEFVdX1V3VdV\ntx21/lNV9Ymqur2qfnHT+muq6nBVfbKqnrcdQwMALGvvo9jmLUn+U5K3PrxQVc9OcjDJU7v7gar6\ntsX6k5JcnuTJSZ6Q5I+q6ond/dCqBwcAWMYJzwR19weSfOmo5Z9M8obufmCxzX2L9YNJ3tHdD3T3\nZ5IcTvKMFc4LALASj+ZM0LE8Mcn3VtXrk/yfJD/b3R9Kcn6SWzZtd2SxBgBLu/jGq9c9wiPsxJl4\ndLYaQXuTnJPkQJLvTvLOqvr2k/kGVXVVkquSZM/ZZ29xDACArdnq1WFHkry7N3wwyd8lOTfJPUku\n2LTdvsXaI3T3dd29v7v37znzjC2OAQCwNVuNoN9N8uwkqaonJjktyV8muSnJ5VV1elVdlOTSJB9c\nxaAAAKt0wpfDqurtSZ6V5NyqOpLkdUmuT3L94rL5ryW5ors7ye1V9c4kdyR5MMnLXBkGAOxEJ4yg\n7n7pce76V8fZ/vVJXr/MUAAA280nRgMAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAY\nSQQBACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMJIIAgJFE\nEAAwkggCAEYSQQDASCIIABhJBAEAI4kgAGAkEQQAjCSCAICRRBAAMJIIAgBGEkEAwEgiCAAYSQQB\nACOJIABgJBEEAIwkggCAkUQQADCSCAIARhJBAMBIIggAGEkEAQAjiSAAYCQRBACMVN297hlSVV9M\n8tkk5yb5yzWPsw4T93viPif2e5qJ+z1xnxP7vdP9k+7+1qMXd0QEPayqDnX3/nXP8VibuN8T9zmx\n3+ue47E2cb8n7nNiv9c9x1Z5OQwAGEkEAQAj7bQIum7dA6zJxP2euM+J/Z5m4n5P3OfEfp+SdtR7\nggAAHis77UwQAMBjYkdEUFVdVlWfrKrDVfXqdc+zXarqgqp6f1XdUVW3V9XLF+vnVNX7qupTi9/P\nXvesq1ZVe6rqT6vq9xa3L6qqWxfH/M
aqOm3dM65aVZ1VVe+qqk9U1Z1V9T1DjvU1iz/ft1XV26vq\nG3fj8a6q66vqvqq6bdPaMY9vbfj1xf5/rKqevr7Jl3Oc/f6lxZ/zj1XV71TVWZvue81ivz9ZVc9b\nz9TLO9Z+b7rvlVXVVXXu4vauPt6L9Z9aHPPbq+oXN62fUsd77RFUVXuS/Ockz0/ypCQvraonrXeq\nbfNgkld295OSHEjyssW+vjrJzd19aZKbF7d3m5cnuXPT7V9Icm13X5Lkr5JcuZapttevJfmD7v7O\nJE/Nxv7v6mNdVecn+ekk+7v7KUn2JLk8u/N4vyXJZUetHe/4Pj/JpYtfVyV542M043Z4Sx653+9L\n8pTu/q4kf57kNUmy+Pl2eZInLx7zG4uf+aeit+SR+52quiDJDyT53KblXX28q+rZSQ4meWp3PznJ\nLy/WT7njvfYISvKMJIe7+9Pd/bUk78jG/7i7Tnff290fWXz9N9n4j+L52djfGxab3ZDkxeuZcHtU\n1b4kP5jkNxe3K8lzkrxrsclu3OdvTvJ9Sd6cJN39te7+cnb5sV7Ym+SbqmpvkscluTe78Hh39weS\nfOmo5eMd34NJ3tobbklyVlU9/rGZdLWOtd/d/Yfd/eDi5i1J9i2+PpjkHd39QHd/JsnhbPzMP+Uc\n53gnybVJXpVk8xtsd/XxTvKTSd7Q3Q8strlvsX7KHe+dEEHnJ7l70+0ji7VdraouTPK0JLcmOa+7\n713c9fkk561prO3yq9n4IfF3i9vfkuTLm35o7sZjflGSLyb5rcXLgL9ZVWdklx/r7r4nG/+v8HPZ\niJ+vJPlwdv/xftjxju+kn3M/keT3F1/v6v2uqoNJ7unujx51167e7yRPTPK9i5e4/2dVffdi/ZTb\n750QQeNU1ZlJfjvJK7r7rzff1xuX6+2aS/aq6oVJ7uvuD697lsfY3iRPT/LG7n5akvtz1Etfu+1Y\nJ8niPTAHsxGBT0hyRo7xEsIEu/H4nkhVvTYbL/u/bd2zbLeqelySn0/y79c9yxrsTXJONt7W8e+S\nvHNxhv+UsxMi6J4kF2y6vW+xtitV1TdkI4De1t3vXix/4eFTpYvf7zve409Bz0zyoqr6i2y81Pmc\nbLxX5qzFyyXJ7jzmR5Ic6e5bF7fflY0o2s3HOkm+P8lnuvuL3f31JO/Oxp+B3X68H3a847vrf85V\n1Y8leWGSH+n//9kru3m/L85G7H908fNtX5KPVNU/yu7e72Tj59u7Fy/3fTAbZ/nPzSm43zshgj6U\n5NLF1SOnZeNNVTeteaZtsSjlNye5s7t/ZdNdNyW5YvH1FUne81jPtl26+zXdva+7L8zGsf3j7v6R\nJO9P8kOLzXbVPidJd38+yd1V9R2LpecmuSO7+FgvfC7Jgap63OLP+8P7vauP9ybHO743JfnRxVVD\nB5J8ZdPLZqe8qrosGy95v6i7/3bTXTclubyqTq+qi7LxRuEPrmPGVevuj3f3t3X3hYufb0eSPH3x\nd39XH+8kv5vk2UlSVU9Mclo2/hHVU+94d/fafyV5QTauKLgryWvXPc827uc/z8bp8Y8l+bPFrxdk\n4z0yNyf5VJI/SnLOumfdpv1/VpLfW3z97dn4y3E4yX9Pcvq659uG/f2nSQ4tjvfvJjl7wrFO8h+T\nfCLJbUn+a5LTd+PxTvL2bLzv6evZ+A/glcc7vkkqG1fB3pXk49m4em7t+7DC/T6cjfeCPPxz7U2b\ntn/tYr8/meT5655/lft91P1/keTcIcf7tCT/bfF3/CNJnnOqHm+fGA0AjLQTXg4DAHjMiSAAYCQR\nBACMJIIAgJFEEAAwkggCAEYSQQDASCIIABjp/wLoMFWX9Nz8GQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img,     # lhs = image tensor\n",
"                               kernel,  # rhs = conv kernel tensor\n",
"                               (1,1),   # window strides\n",
"                               'VALID', # padding mode\n",
"                               (1,1),   # lhs/image dilation\n",
"                               (12,12), # rhs/kernel dilation\n",
"                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape)\n",
"print(\"First output channel:\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
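{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Rhs (kernel) dilation $d$ inserts $d - 1$ zeros between adjacent kernel taps. A kernel of size $k$ therefore covers an effective window of $(k - 1)d + 1$ pixels. Assuming a $3 \\times 3$ kernel and the $200 \\times 198$ input image, `'VALID'` padding with dilation $(12, 12)$ gives\n",
"\n",
"$$k_{eff} = (3 - 1) \\cdot 12 + 1 = 25, \\qquad H_{out} = 200 - 25 + 1 = 176, \\qquad W_{out} = 198 - 25 + 1 = 174,$$\n",
"\n",
"which matches the printed shape `(1, 176, 174, 3)`."
]
},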
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "v-RhEeUfFVHI"
},
"source": [
"#### VALID padding, no stride, lhs (input) dilation ~ transposed convolution"
]
},
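{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Lhs (input) dilation $d$ inserts $d - 1$ zeros between adjacent input pixels. This stretches a dimension of size $n$ to $(n - 1)d + 1$ before the kernel is applied, which is how a transposed convolution upsamples. The dilation value used in the next cell is assumed here to be $(2, 2)$. With a $3 \\times 3$ kernel and the $200 \\times 198$ input, that value reproduces the printed shape:\n",
"\n",
"$$H_{out} = (200 - 1) \\cdot 2 + 1 - 3 + 1 = 397, \\qquad W_{out} = (198 - 1) \\cdot 2 + 1 - 3 + 1 = 393.$$"
]
},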
{
"cell_type": "code",
"execution_count": 160,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "B9Ail8ppFVHJ",
"outputId": "3617a5d2-1eaa-46e8-d691-87b365ed1310"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 397, 393, 3) <-- larger than original!\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAJCCAYAAAAsp6gAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAaXklEQVR4nO3cb6ymdX3n8c93Z5A2YIpUl0yBXSmy\nMbZJR3dKp2nTuBpb5Ak2cQU3UdaYUHYxqU13U+wTbVKTdrOWbbNZCI1UbNoCsTUSQ/+w6sb0Aepo\nKfKntoNiYBZht/6pYJaG6XcfnIt6GGY4M3P+3FO+r1dycq77d133uX/3z+vge+7rvk91dwAAJvpn\nq54AAMCqCCEAYCwhBACMJYQAgLGEEAAwlhACAMbathCqqkuq6ktVdbCqrt2uxwEAOFm1HX9HqKp2\nJfnrJG9I8kiSzyV5a3ffv+UPBgBwkrbrFaGLkxzs7i93998nuSXJZdv0WAAAJ2X3Nv3cc5M8vO72\nI0l+7FgH7zrzjN599tnbNBUAYLKnv/71HH7iyTravu0KoQ1V1VVJrkqSXS95SX7gF9+9qqkAAC9g\n//sD/+2Y+7br0tihJOevu33eMvaPuvvG7t7X3ft2nXnGNk0DAODYtiuEPpfkoqq6oKpelOSKJLdv\n02MBAJyUbbk01t1PV9W7kvxpkl1Jburu+7bjsQAATta2vUeou+9Icsd2/XwAgM3yl6UBgLGEEAAw\nlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJ\nIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBg\nLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYS\nQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFi7N3PnqnooybeTHE7ydHfv\nq6qzk9ya5OVJHkrylu7+xuamCQCw9bbiFaF/0917u3vfcvvaJJ/o7ouSfGK5DQBwytmOS2OXJbl5\n2b45yZu24TEAADZtsyHUSf6sqj5fVVctY+d096PL9teSnHO0O1bVVVV1oKoOHH7iyU1OAwDgxG3q\nPUJJfrK7D1XVP09yZ1X91fqd3d1V1Ue7Y3ffmOTGJDn9X5x/1GMAALbTpl4R6u5Dy/fHk3w0ycVJ\nHquqPUmyfH98s5MEANgOJx1CVXVGVb34me0kP53k3iS3J7lyOezKJB/b7CQBALbDZi6NnZPko1X1\nzM/5/e7+k6r6XJLbquqdSb6a5C2bnyYAwNY76RDq7i8n+ZGjjP9tktdvZlIAADvBX5YGAMYSQgDA\nWEIIABhLCAEAY232DypyFK/4hbt29PEOXrd/Rx8PAF4ovCIEAIzlFaFtspOv0jx4+Q258Nard+zx\nAOCFwitCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYu1c9gReig9ftz4OX37DqaQAAGxBC2+TCW69e9RQAgA24NAYAjCWEAICxhBAAMJYQAgDG\nEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwh\nBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhL
CAEAYwkhAGAsIQQAjLVhCFXVTVX1eFXdu27s7Kq6\ns6r+Zvn+kmW8quq3qupgVd1TVa/ZzskDAGzG8bwi9KEklxwxdm2ST3T3RUk+sdxOkjcmuWj5uirJ\n9VszTQCArbdhCHX3p5N8/Yjhy5LcvGzfnORN68Y/3GvuSnJWVe3ZqskCAGylk32P0Dnd/eiy/bUk\n5yzb5yZ5eN1xjyxjAACnnE2/Wbq7O0mf6P2q6qqqOlBVBw4/8eRmpwEAcMJONoQee+aS1/L98WX8\nUJLz1x133jL2HN19Y3fv6+59u8484ySnAQBw8k42hG5PcuWyfWWSj60bf/vy6bH9Sb617hIaAMAp\nZfdGB1TVHyR5bZKXVtUjSd6b5NeS3FZV70zy1SRvWQ6/I8mlSQ4m+U6Sd2zDnAEAtsSGIdTdbz3G\nrtcf5dhOcs1mJwUAsBP8ZWkAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWE\nAICxhBAAMNaGIVRVN1XV41V177qx91XVoaq6e/m6dN2+91TVwar6UlX9zHZNHABgs47nFaEPJbnk\nKOPXdffe5euOJKmqVyW5IskPLff5H1W1a6smCwCwlTYMoe7+dJKvH+fPuyzJLd39VHd/JcnBJBdv\nYn4AANtmM+8ReldV3bNcOnvJMnZukofXHfPIMvYcVXVVVR2oqgOHn3hyE9MAADg5JxtC1ye5MMne\nJI8m+cCJ/oDuvrG793X3vl1nnnGS0wAAOHknFULd/Vh3H+7uf0jy2/nu5a9DSc5fd+h5yxgAwCnn\npEKoqvasu/mzSZ75RNntSa6oqtOr6oIkFyX57OamCACwPXZvdEBV/UGS1yZ5aVU9kuS9SV5bVXuT\ndJKHkvxcknT3fVV1W5L7kzyd5JruPrw9UwcA2JwNQ6i733qU4Q8+z/HvT/L+zUwKAGAn+MvSAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWE\nAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggA\nGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMbaveoJANvvFb9w14491sHr9u/YYwFslhCCIXYq\nUB68/IZceOvVO/JYAJvl0hgAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCW\nEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMNbujQ6oqvOTfDjJOUk6yY3d/ZtVdXaSW5O8PMlDSd7S3d+oqkrym0ku\nTfKdJP++u7+wPdMHjsfB6/bnwctvWPU0AE45x/OK0NNJfrG7X5Vkf5JrqupVSa5N8onuvijJJ5bb\nSfLGJBctX1cluX7LZw0AsAU2fEWoux9N8uiy/e2qeiDJuUkuS/La5bCbk/yvJL+0jH+4uzvJXVV1\nVlXtWX4OsCIX3nr1qqcAcMo5ofcIVdXLk7w6yWeSnLMubr6WtUtnyVokPbzubo8sYwAAp5TjDqGq\nOjPJHyZ5
d3f/3fp9y6s/fSIPXFVXVdWBqjpw+IknT+SuAABb4rhCqKpOy1oE/V53/9Ey/FhV7Vn2\n70ny+DJ+KMn56+5+3jL2LN19Y3fv6+59u84842TnDwBw0jYMoeVTYB9M8kB3/8a6XbcnuXLZvjLJ\nx9aNv73W7E/yLe8PAgBORRu+WTrJTyR5W5IvVtXdy9gvJ/m1JLdV1TuTfDXJW5Z9d2Tto/MHs/bx\n+Xds6YwBALbI8Xxq7M+T1DF2v/4ox3eSazY5LwCAbecvSwMAYwkhAGAsIQQAjCWEAICxhBAAMJYQ\nAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEA\nYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDG\nEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwh\nBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIA\nwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsTYMoao6v6o+VVX3V9V9VfXzy/j7qupQVd29fF267j7v\nqaqDVfWlqvqZ7XwCAAAna/dxHPN0kl/s7i9U1YuTfL6q7lz2Xdfd/3X9wVX1qiRXJPmhJD+Q5H9W\n1b/q7sNbOXEAgM3a8BWh7n60u7+wbH87yQNJzn2eu1yW5Jbufqq7v5LkYJKLt2KyAABb6YTeI1RV\nL0/y6iSfWYbeVVX3VNVNVfWSZezcJA+vu9sjOUo4VdVVVXWgqg4cfuLJE544AMBmHXcIVdWZSf4w\nybu7+++SXJ/kwiR7kzya5AMn8sDdfWN37+vufbvOPONE7goAsCWOK4Sq6rSsRdDvdfcfJUl3P9bd\nh7v7H5L8dr57+etQkvPX3f28ZQwA4JRyPJ8aqyQfTPJAd//GuvE96w772ST3Ltu3J7miqk6vqguS\nXJTks1s3ZQCArXE8nxr7iSRvS/LFqrp7GfvlJG+tqr1JOslDSX4uSbr7vqq6Lcn9WfvE2TU+MQYA\nnIo2DKHu/vMkdZRddzzPfd6f5P2bmBcAwLbzl6UBgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDG2r3RAVX1PUk+neT05fiPdPd7q+qCJLck+f4kn0/ytu7++6o6PcmHk/zr\nJH+b5PLufmib5g/APzGv+IW7dvTxDl63f0cfj39ajucVoaeSvK67fyTJ3iSXVNX+JL+e5LrufkWS\nbyR553L8O5N8Yxm/bjkOAOCUs+ErQt3dSZ5Ybp62fHWS1yX5d8v4zUnel+T6JJct20nykST/vapq\n+TkAkGTnXql58PIbcuGtV+/IY/FPz3G9R6iqdlXV3UkeT3JnkgeTfLO7n14OeSTJucv2uUkeTpJl\n/7eydvkMAOCUclwh1N2Hu3tvkvOSXJzklZt94Kq6qqoOVNWBw088udkfBwBwwk7oU2Pd/c0kn0ry\n40nOqqpnLq2dl+TQsn0oyflJsuz/vqy9afrIn3Vjd+/r7n27zjzjJKcPAH
DyNgyhqnpZVZ21bH9v\nkjckeSBrQfTm5bArk3xs2b59uZ1l/ye9PwgAOBVt+GbpJHuS3FxVu7IWTrd198er6v4kt1TVryb5\niyQfXI7/YJLfraqDSb6e5IptmDcAwKYdz6fG7kny6qOMfzlr7xc6cvz/Jfm3WzI7AIBt5C9LAwBj\nCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQ\nAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIy1e9UTAGCWg9ftT5I8ePkNK54JCCEAVuTCW69e9RTA\npTEAYC4hBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQ\nAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEA\nYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDG\nEkIAwFgbhlBVfU9Vfbaq/rKq7quqX1nGP1RVX6mqu5evvct4VdVvVdXBqrqnql6z3U8CAOBk7D6O\nY55K8rrufqKqTkvy51X1x8u+/9zdHzni+DcmuWj5+rEk1y/fAQBOKRu+ItRrnlhunrZ89fPc5bIk\nH17ud1eSs6pqz+anCgCwtY7rPUJVtauq7k7yeJI7u/szy673L5e/rquq05exc5M8vO7ujyxjAACn\nlOMKoe4+3N17k5yX5OKq+uEk70nyyiQ/muTsJL90Ig9cVVdV1YGqOnD4iSdPcNoAAJt3Qp8a6+5v\nJvlUkku6+9Hl8tdTSX4nycXLYYeSnL/ubuctY0f+rBu7e19379t15hknN3sAgE04nk+Nvayqzlq2\nvzfJG5L81TPv+6mqSvKmJPcud7k9yduXT4/tT/Kt7n50W2YPALAJx/OpsT1Jbq6qXVkLp9u6++NV\n9cmqelmSSnJ3kquX4+9IcmmSg0m+k+QdWz9tAIDN2zCEuvueJK8+yvjrjnF8J7lm81MDANhe/rI0\nADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEA\nYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCW\nEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWNXdq55Dqur/JHkyyf9d\n9VxOIS+N9TiSNXk26/Fc1uTZrMdzWZNnm7Ie/7K7X3a0HadECCVJVR3o7n2rnsepwno8lzV5Nuvx\nXNbk2azHc1mTZ7MeLo0BAIMJIQBgrFMphG5c9QROMdbjuazJs1mP57Imz2Y9nsuaPNv49Thl3iME\nALDTTqVXhAAAdtTKQ6iqLqmqL1XVwaq6dtXzWZWqeqiqvlhVd1fVgWXs7Kq6s6r+Zvn+klXPc7tU\n1U1V9XhV3btu7KjPv9b81nLO3FNVr1ndzLfPMdbkfVV1aDlP7q6qS9fte8+yJl+qqp9Zzay3T1Wd\nX1Wfqqr7q+q+qvr5ZXzkefI86zH5HPmeqvpsVf3lsia/soxfUFWfWZ77rVX1omX89OX2wWX/y1c5\n/632POvxoar6yrpzZO8y/oL+nTmm7l7ZV5JdSR5M8oNJXpTkL5O8apVzWuFaPJTkpUeM/Zck1y7b\n1yb59VXPcxuf/08leU2Sezd6/kkuTf
LHSSrJ/iSfWfX8d3BN3pfkPx3l2Fctvz+nJ7lg+b3atern\nsMXrsSfJa5btFyf56+V5jzxPnmc9Jp8jleTMZfu0JJ9Z/re/LckVy/gNSf7Dsv0fk9ywbF+R5NZV\nP4cdWo8PJXnzUY5/Qf/OHOtr1a8IXZzkYHd/ubv/PsktSS5b8ZxOJZcluXnZvjnJm1Y4l23V3Z9O\n8vUjho/1/C9L8uFec1eSs6pqz87MdOccY02O5bIkt3T3U939lSQHs/b79YLR3Y929xeW7W8neSDJ\nuRl6njzPehzLhHOku/uJ5eZpy1cneV2SjyzjR54jz5w7H0ny+qqqHZrutnue9TiWF/TvzLGsOoTO\nTfLwutuP5Pl/kV/IOsmfVdXnq+qqZeyc7n502f5aknNWM7WVOdbzn37evGt52fqmdZdLR63Jcgnj\n1Vn7F+748+SI9UgGnyNVtauq7k7yeJI7s/bK1ze7++nlkPXP+x/XZNn/rSTfv7Mz3l5Hrkd3P3OO\nvH85R66rqtOXsRHnyJFWHUJ8109292uSvDHJNVX1U+t39trrlmM/4jf9+a9zfZILk+xN8miSD6x2\nOjuvqs5M8odJ3t3df7d+38Tz5CjrMfoc6e7D3b03yXlZe8XrlSue0koduR5V9cNJ3pO1dfnRJGcn\n+aUVTnHlVh1Ch5Kcv+72ecvYON19aPn+eJKPZu0X+LFnXpZcvj++uhmuxLGe/9jzprsfW/7D9g9J\nfjvfvbQxYk2q6rSs/Z/+73X3Hy3DY8+To63H9HPkGd39zSSfSvLjWbvEs3vZtf55/+OaLPu/L8nf\n7vBUd8S69bhkuaza3f1Ukt/J0HPkGasOoc8luWh5R/+LsvZmtdtXPKcdV1VnVNWLn9lO8tNJ7s3a\nWly5HHZlko+tZoYrc6znf3uSty+fcNif5FvrLo28oB1xvf5ns3aeJGtrcsXyKZgLklyU5LM7Pb/t\ntLx344NJHuju31i3a+R5cqz1GH6OvKyqzlq2vzfJG7L23qlPJXnzctiR58gz586bk3xyeVXxBeEY\n6/FX6/7hUFl7v9T6c+QF+ztzLLs3PmT7dPfTVfWuJH+atU+Q3dTd961yTityTpKPLu/R253k97v7\nT6rqc0luq6p3JvlqkrescI7bqqr+IMlrk7y0qh5J8t4kv5ajP/87svbphoNJvpPkHTs+4R1wjDV5\n7fJR187aJw1/Lkm6+76qui3J/UmeTnJNdx9exby30U8keVuSLy7veUiSX87c8+RY6/HWwefIniQ3\nV9WurP1D/7bu/nhV3Z/klqr61SR/kbWAzPL9d6vqYNY+mHDFKia9jY61Hp+sqpdl7dNhdye5ejn+\nhf47c1T+sjQAMNaqL40BAKyMEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLH+P2a9mLpL\n+QBTAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
" ((0, 0), (0, 0)), # padding mode\n",
" (2,2), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \"<-- larger than original!\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "A-9OagtrVDyV"
},
"source": [
"We can use the last of these, LHS dilation, to implement _transposed convolutions_, for instance:"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 629
},
"colab_type": "code",
"id": "5EYIj77-NdHE",
"outputId": "f325e6cb-3079-4250-898f-ca4fb081c6c7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out shape: (1, 400, 396, 3) <-- transposed_conv\n",
"First output channel:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAJCCAYAAAAsp6gAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAaVElEQVR4nO3dbaytdXnn8d/Vc5A2YApUh1AgIwNM\nGtukR3JKado0jsYWeYNNHIsvKjEmp8xgUk1nUuyb2qQm7WQsE5MZDY1UbDoFYmskDX1glKTpC1C0\niDzUdvsUOIMwrdWKZmg4vebFvrH7MOdw9tlP63iuzyfZ2ff63/fa67/+3Bu/rHutbXV3AAAm+p5V\nTwAAYFWEEAAwlhACAMYSQgDAWEIIABhLCAEAY+1aCFXV1VX1+apaq6qbdutxAAC2qnbj7whV1b4k\nf5PkdUmeSPKpJG/u7kd3/MEAALZot14RujLJWnd/sbv/KcntSa7dpccCANiS/bv0cy9M8viG208k\n+fHjHbzv7LN6/3nn7dJUAIDJnvva13LkmW/VsfbtVgidUFUdSnIoSfade25+8JffsaqpAACnsf/9\n3v923H27dWnscJKLN9y+aBn7ju6+pbsPdvfBfWeftUvTAAA4vt0KoU8lubyqLqmqlyS5Lsldu/RY\nAABbsiuXxrr7uap6e5I/S7Ivya3d/chuPBYAwFbt2nuEuvvuJHfv1s8HANguf1kaABhLCAEAYwkh\nAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIA\nxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAs\nIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJC\nAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQA\njCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgrP3buXNVfTnJN5McSfJcdx+sqvOS3JHk\nFUm+nORN3f0P25smAMDO24lXhP5ddx/o7oPL7ZuSfLy7L0/y8eU2AMApZzcujV2b5LZl+7Ykb9iF\nxwAA2LbthlAn+fOq+nRVHVrGzu/uJ5ftryY5f5uPAQCwK7b1HqEkP9Xdh6vqXyW5p6r+euPO7u6q\n6mPdcQmnQ0my79xztzkNAICTt61XhLr78PL96SQfTXJlkqeq6oIkWb4/fZz73tLdB7v74L6zz9rO\nNAAAtmTLIVRVZ1XVS5/fTvIzSR5OcleS65fDrk/yse1OEgBgN2zn0tj5ST5aVc//nP/Z3X9aVZ9K\ncmdVvS3JV5K8afvTBADYeVsOoe7+YpIfPcb43yd57XYmBQCwF/xlaQBgLCEEAIwlhACAsYQQADCW\nEAIAxhJCAMBYQggAGEsIAQBjbff/dJUX8YWf/8CePt6ld9ywp48HAN/thNAe2KtAueyd92Xt5qv2\n5LEA4HTg0hgAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBY\nQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWE\nAICxhBAAMNb+VU/gdHbpHTckSS57530rngkAcCxCaA+s3XzVqqcAAByDS2MAwFhCCAAYSwgBAGMJ\nIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWE
AICxThhCVXVrVT1dVQ9vGDuvqu6pqr9dvp+7jFdV\nva+q1qrqoaq6YjcnDwCwHZt5RehDSa5+wdhNST7e3Zcn+fhyO0len+Ty5etQkvfvzDQBAHbeCUOo\nu/8iyddeMHxtktuW7duSvGHD+Id73X1JzqmqC3ZqsgAAO2mr7xE6v7ufXLa/muT8ZfvCJI9vOO6J\nZQwA4JSz7TdLd3cn6ZO9X1UdqqoHquqBI898a7vTAAA4aVsNoaeev+S1fH96GT+c5OINx120jP1/\nuvuW7j7Y3Qf3nX3WFqcBALB1Ww2hu5Jcv2xfn+RjG8bfsnx67Kok39hwCQ0A4JSy/0QHVNUfJHl1\nkpdV1RNJfi3Jbya5s6reluQrSd60HH53kmuSrCX5dpK37sKcAQB2xAlDqLvffJxdrz3GsZ3kxu1O\nCgBgL/jL0gDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJ\nIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhAC\nAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBg\nLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYS\nQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGOdMISq6taq\nerqqHt4w9u6qOlxVDy5f12zY966qWquqz1fVz+7WxAEAtmszrwh9KMnVxxi/ubsPLF93J0lVvTLJ\ndUl+eLnP/6iqfTs1WQCAnXTCEOruv0jytU3+vGuT3N7dz3b3l5KsJblyG/MDANg123mP0Nur6qHl\n0tm5y9iFSR7fcMwTyxgAwClnqyH0/iSXJjmQ5Mkk7z3ZH1BVh6rqgap64Mgz39riNAAAtm5LIdTd\nT3X3ke7+5yS/k3+5/HU4ycUbDr1oGTvWz7iluw9298F9Z5+1lWkAAGzLlkKoqi7YcPPnkjz/ibK7\nklxXVWdW1SVJLk/yye1NEQBgd+w/0QFV9QdJXp3kZVX1RJJfS/LqqjqQpJN8OckvJkl3P1JVdyZ5\nNMlzSW7s7iO7M3UAgO05YQh195uPMfzBFzn+PUnes51JAQDsBX9ZGgAYSwgBAGMJIQBgLCEEAIwl\nhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEII\nABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACA\nsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYKz9q54AsDe+8PMf2LPHuvSOG/bssQC2QwjBIHsV\nKJe9876s3XzVnjwWwHa4NAYAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwl\nhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEII\nABjrhCFUVRdX1b1V9WhVPVJVv7SMn1dV91TV3y7fz13Gq6reV1VrVfVQVV2x208CAGArNvOK0HNJ\nfrm7X5nkqiQ3VtUrk9yU5OPdfXmSjy+3k+T1SS5fvg4lef+OzxoAYAecMIS6+8nu/syy/c0kjyW5\nMMm1SW5bDrstyRuW7WuTfLjX3ZfknKq6YMdnDgCwTftP5uCqekWSVyW5P8n53f3ksuurSc5fti9M\n8viGuz2xjD0ZYGUuveOGXPbO+1Y9DYBTyqZDqKrOTvKHSd7R3f9YVd/Z191dVX0yD1xVh7J+6Sz7\nzj33ZO4K
bNHazVetegoAp5RNfWqsqs7IegT9fnf/0TL81POXvJbvTy/jh5NcvOHuFy1jR+nuW7r7\nYHcf3Hf2WVudPwDAlm3mU2OV5INJHuvu396w664k1y/b1yf52IbxtyyfHrsqyTc2XEIDADhlbObS\n2E8m+YUkn6uqB5exX03ym0nurKq3JflKkjct++5Ock2StSTfTvLWHZ0xAMAOOWEIdfdfJqnj7H7t\nMY7vJDduc14AALvOX5YGAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIA\nwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACM\nJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhC\nCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQA\ngLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAY\n64QhVFUXV9W9VfVoVT1SVb+0jL+7qg5X1YPL1zUb7vOuqlqrqs9X1c/u5hMAANiq/Zs45rkkv9zd\nn6mqlyb5dFXds+y7ubv/68aDq+qVSa5L8sNJfjDJ/6qqf9vdR3Zy4gAA23XCV4S6+8nu/syy/c0k\njyW58EXucm2S27v72e7+UpK1JFfuxGQBAHbSSb1HqKpekeRVSe5fht5eVQ9V1a1Vde4ydmGSxzfc\n7Ym8eDgBAKzEpkOoqs5O8odJ3tHd/5jk/UkuTXIgyZNJ3nsyD1xVh6rqgap64Mgz3zqZuwIA7IhN\nhVBVnZH1CPr97v6jJOnup7r7SHf/c5Lfyb9c/jqc5OINd79oGTtKd9/S3Qe7++C+s8/aznMAANiS\nzXxqrJJ8MMlj3f3bG8Yv2HDYzyV5eNm+K8l1VXVmVV2S5PIkn9y5KQMA7IzNfGrsJ5P8QpLPVdWD\ny9ivJnlzVR1I0km+nOQXk6S7H6mqO5M8mvVPnN3oE2MAwKnohCHU3X+ZpI6x6+4Xuc97krxnG/MC\nANh1/rI0ADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEII\nABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACA\nsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhL\nCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQ\nADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWCcMoar63qr6\nZFV9tqoeqapfX8Yvqar7q2qtqu6oqpcs42cut9eW/a/Y3acAALA1m3lF6Nkkr+nuH01yIMnVVXVV\nkt9KcnN3X5bkH5K8bTn+bUn+YRm/eTkOAOCUc8IQ6nXPLDfPWL46yWuSfGQZvy3JG5bta5fbWfa/\ntqpqx2YMALBD9m/moKral+TTSS5L8t+TfCHJ17v7ueWQJ5JcuGxfmOTxJOnu56rqG0l+IMnf7eC8\nAfgu9YWf/8CePdald9ywZ4/Fd6dNhVB3H0lyoKrOSfLRJD+03QeuqkNJDiXJvnPP3e6PA+C7yF4F\nymXvvC9JsnbzVXvyeHz3OalPjXX315Pcm+QnkpxTVc+H1EVJDi/bh5NcnCTL/u9P8vfH+Fm3dPfB\n7j647+yztjh9AICt28ynxl6+vBKUqvq+JK9L8ljWg+iNy2HXJ/nYsn3Xcj
vL/k90d+/kpAEAdsJm\nLo1dkOS25X1C35Pkzu7+46p6NMntVfUbSf4qyQeX4z+Y5Peqai3J15JctwvzBgDYthOGUHc/lORV\nxxj/YpIrjzH+f5P8+x2ZHQDALvKXpQGAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACM\nJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhC\nCAAYa/+qJwDALJfecUMue+d9q54GJPGKEAAwmFeEANhzazdfteopQBKvCAEAgwkhAGAsIQQAjCWE\nAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggA\nGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICx\nhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIx1whCqqu+tqk9W1Wer6pGq+vVl/ENV9aWq\nenD5OrCMV1W9r6rWquqhqrpit58EAMBW7N/EMc8meU13P1NVZyT5y6r6k2Xff+7uj7zg+NcnuXz5\n+vEk71++AwCcUk74ilCve2a5ecby1S9yl2uTfHi5331JzqmqC7Y/VQCAnbWp9whV1b6qejDJ00nu\n6e77l13vWS5/3VxVZy5jFyZ5fMPdn1jGAABOKZsKoe4+0t0HklyU5Mqq+pEk70ryQ0l+LMl5SX7l\nZB64qg5V1QNV9cCRZ751ktMGANi+k/rUWHd/Pcm9Sa7u7ieXy1/PJvndJFcuhx1OcvGGu120jL3w\nZ93S3Qe7++C+s8/a2uwBALZhM58ae3lVnbNsf1+S1yX56+ff91NVleQNSR5e7nJXkrcsnx67Ksk3\nuvvJXZk9AMA2bOZTYxckua2q9mU9nO7s7j+uqk9U1cuTVJIHk9ywHH93kmuSrCX5dpK37vy0AQC2\n74Qh1N0PJXnVMcZfc5zjO8mN258aAMDu8pelAYCxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBg\nLCEEAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYS\nQgDAWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEE\nAIwlhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjCSEAYCwhBACMJYQAgLGEEAAwlhACAMYSQgDA\nWEIIABhLCAEAYwkhAGAsIQQAjCWEAICxhBAAMJYQAgDGEkIAwFhCCAAYSwgBAGMJIQBgLCEEAIwl\nhACAsYQQADCWEAIAxhJCAMBYQggAGEsIAQBjVXeveg6pqv+T5FtJ/m7VczmFvCzWYyPrcTTrcTTr\ncTTrcTTrcbSJ6/Gvu/vlx9pxSoRQklTVA919cNXzOFVYj6NZj6NZj6NZj6NZj6NZj6NZj6O5NAYA\njCWEAICxTqUQumXVEzjFWI+jWY+jWY+jWY+jWY+jWY+jWY8NTpn3CAEA7LVT6RUhAIA9tfIQqqqr\nq+rzVbVWVTetej6rUFVfrqrPVdWDVfXAMnZeVd1TVX+7fD931fPcLVV1a1U9XVUPbxg75vOvde9b\nzpeHquqK1c18dxxnPd5dVYeXc+TBqrpmw753Levx+ar62dXMevdU1cVVdW9VPVpVj1TVLy3jI8+R\nF1mPkedIVX1vVX2yqj67rMevL+OXVNX9y/O+o6pesoyfudxeW/a/YpXz32kvsh4fqqovbTg/Dizj\np/Xvy6Z098q+kuxL8oUk/ybJS5J8NskrVzmnFa3Dl5O87AVj/yXJTcv2TUl+a9Xz3MXn/9NJrkjy\n8Imef5JrkvxJkkpyVZL7Vz3/PVqPdy
f5T8c49pXL782ZSS5Zfp/2rfo57PB6XJDkimX7pUn+Znne\nI8+RF1mPkefI8s/57GX7jCT3L//c70xy3TL+gST/Ydn+j0k+sGxfl+SOVT+HPVqPDyV54zGOP61/\nXzbztepXhK5MstbdX+zuf0pye5JrVzynU8W1SW5btm9L8oYVzmVXdfdfJPnaC4aP9/yvTfLhXndf\nknOq6oK9meneOM56HM+1SW7v7me7+0tJ1rL+e3Xa6O4nu/szy/Y3kzyW5MIMPUdeZD2O57Q+R5Z/\nzs8sN89YvjrJa5J8ZBl/4fnx/HnzkSSvrarao+nuuhdZj+M5rX9fNmPVIXRhksc33H4iL/4Lfbrq\nJH9eVZ+uqkPL2Pnd/eSy/dUk569maitzvOc/+Zx5+/LS9a0bLpWOWo/lMsarsv5fuePPkResRzL0\nHKmqfVX1YJKnk9yT9Ve9vt7dzy2HbHzO31mPZf83kvzA3s54d71wPbr7+fPjPcv5cXNVnbmMnfbn\nx4msOoRY91PdfUWS1ye5sap+euPOXn/9cuzH+6Y//8X7k1ya5ECSJ5O8d7XT2XtVdXaSP0zyju7+\nx437Jp4jx1iPsedIdx/p7gNJLsr6q10/tOIprdQL16OqfiTJu7K+Lj+W5Lwkv7LCKZ5SVh1Ch5Nc\nvOH2RcvYKN19ePn+dJKPZv0X+annX55cvj+9uhmuxPGe/8hzprufWv7l9s9Jfif/cmljxHpU1RlZ\n/x/93+/uP1qGx54jx1qP6edIknT315Pcm+Qnsn6JZ/+ya+Nz/s56LPu/P8nf7/FU98SG9bh6uaTa\n3f1skt/NwPPjeFYdQp9Kcvny7v6XZP2Na3eteE57qqrOqqqXPr+d5GeSPJz1dbh+Oez6JB9bzQxX\n5njP/64kb1k+6XBVkm9suDxy2nrBNfufy/o5kqyvx3XLJ2EuSXJ5kk/u9fx20/L+jQ8meay7f3vD\nrpHnyPHWY+o5UlUvr6pzlu3vS/K6rL9v6t4kb1wOe+H58fx588Ykn1heUTwtHGc9/nrDfzRU1t8v\ntfH8OG1/XzZj/4kP2T3d/VxVvT3Jn2X9E2S3dvcjq5zTCpyf5KPLe/X2J/mf3f2nVfWpJHdW1duS\nfCXJm1Y4x11VVX+Q5NVJXlZVTyT5tSS/mWM//7uz/imHtSTfTvLWPZ/wLjvOerx6+bhrZ/1Thr+Y\nJN39SFXdmeTRJM8lubG7j6xi3rvoJ5P8QpLPLe97SJJfzdxz5Hjr8eah58gFSW6rqn1Z/4/7O7v7\nj6vq0SS3V9VvJPmrrMdjlu+/V1VrWf9QwnWrmPQuOt56fKKqXp71T4c9mOSG5fjT/fflhPxlaQBg\nrFVfGgMAWBkhBACMJYQAgLGEEAAwlhACAMYSQgDAWEIIABhLCAEAY/0/+/Z6Go3XJGcAAAAASUVO\nRK5CYII=\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# The following is equivalent to this TensorFlow code:\n",
"# N,H,W,C = img.shape\n",
"# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))\n",
"\n",
"# transposed conv = 180deg kernel rotation plus LHS dilation\n",
"# rotate kernel 180deg:\n",
"kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))\n",
"# need a custom output padding:\n",
"padding = ((2, 1), (2, 1))\n",
"out = lax.conv_general_dilated(img, # lhs = image tensor\n",
" kernel_rot, # rhs = conv kernel tensor\n",
" (1,1), # window strides\n",
" padding, # padding mode\n",
" (2,2), # lhs/image dilation\n",
" (1,1), # rhs/kernel dilation\n",
" dn) # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape, \"<-- transposed_conv\")\n",
"plt.figure(figsize=(10,10))\n",
"print(\"First output channel:\")\n",
"plt.imshow(np.array(out)[0,:,:,0]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "v8HsE-NCmUxx"
},
"source": [
"### 1D Convolutions"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WeP0rw0tm7HK"
},
"source": [
"You aren't limited to 2D convolutions; a simple 1D demo is below:"
]
},
{
"cell_type": "code",
"execution_count": 162,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 674
},
"colab_type": "code",
"id": "jJ-jcAn3cig-",
"outputId": "64e578be-92c5-4aef-9d5d-ae93939f9b31"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"in shapes: (1, 200, 2) (3, 2, 2)\n",
"ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))\n",
"out shape: (1, 200, 2)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlMAAAEvCAYAAABhSUTPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de7AlV3Xf8d/qM0iIhwR4BoI1M4yw\nBY5iO4aaACmw44ecSEosOXHikiopP0KsSsUkdvmRkosEUzj5A7tip0gU23KZsk1hZPArk7JcODYY\nKCfCEiBADwRjIVsjC0lIIOGAkHR65Y8+3ffo+t45e+/efc/cs76fKtXM3Dmj6dN3+va6+7fXanN3\nAQAAoEyz7gMAAADYzyimAAAARqCYAgAAGIFiCgAAYASKKQAAgBEopgAAAEY4sK6/+ODBg37s2LF1\n/fUAAADJPvShD33W3Q/t9HtrK6aOHTumm2++eV1/PQAAQDIz+/Pdfo+YDwAAYASKKQAAgBEopgAA\nAEagmAIAABiBYgoAAGAEiikAAIARKKYAAABGWFlMmdlbzewBM7t1l983M3uLmZ00s4+Z2cvrHyYA\nAMCZKWVl6lckXXKa379U0oWL/66W9PPjDwsAAGB/WDkB3d3fb2bHTvOSKyT9mru7pBvN7Dlm9kJ3\nv6/SMWKP3HrvI/r4vY+s+zDOeCbpW77m+XrBuU/P/rN/+fkv6X2ffLD+Qe0TX/H5j+k5j34y6bVH\nnvcMvfC8/HO8EayRXnqZ9MyD+X/24U9Ln35f/WPaREf/rnTopes+CmyAGo+TOV/SPUu/PrX42F8r\npszsanWrVzp69GiFvxo1/di7PqpPfOYL6z6MfeH7X31MP/kdfyv7z/2393xK7/jTe1a/cEN94Kwf\n1pEmbjGZ5Zvukb719fl/7j0/Jd36W/WPZxO9+Fuk7/nddR8FNsCePpvP3a+TdJ0kHT9+3Pfy78Zq\njz0x17df9AL91BVfu+5DOaP9w7d8QI890Rb92ceeaPWV5z1dv/1vXl35qPaHQ78ofenod+kL3/gf\nT/u6//x7d+izf/Vlvf1fvXKPjuwM85aXSU9+qezPPvGYdPClFAmrvOv7pCcfW/dRYEPUKKbulXRk\n6deHFx/DPtO69KyzD+hvRI1WEh2Ymdq27HuB1l1nHWgCn+NW5zzrOTrn/AtO+6ovn/OwHvx/fyWd\n+5V7dFxnmOaA1JYV7PK5dODsuOcu1dPOkR7/4rqPAhuixmiEE5K+Z9HV9ypJj7Bfan+aty6zdR/F\nmW9mptbLiql562oin2RvpWa28mWzxlRYr24Gm3XnqkTiOQ5vzDkGtlm5MmVm75D0zZIOmtkpST8p\n6WmS5O6/IOkGSZdJOinpi5K+f6qDxbRad80i3+gTmZnmhcVU666mCXyO27bbXL2CmYpX/zaCWbfC\nVKKdJ53j8KwpP8fANindfFet+H2X9IPVjghr07prFvlGn2jWmAprKbWtYhes3nYrAit0K1OBi6lm\n5MpUwjkOb8w5Brbh2xcM5m236oLTa6yL60rMPXiU6nOlnIBmxOrfRrCmW2Eq4axMJbGmfF8asA1X\nHAbdytS6j+LM1zQjYr42+OpfO0/az9OYxb7P2WxEzMeeqSTEfKiIWycG7JlKMzOTj9gzFbqYSo75\nRMw3KubjS/tKxHyoiCsOg66bL/CNPlFjNiLmCx6lJkZQY87xRhgTQRHzpRkTpQLbcMVhED6CStQ0\npnnhfa5tXbOop9g9uW2/ib4BfUwElRilhjcmSgW2oZjCoHVRTCWYNSLmK9Gfs5SYz4LPmaKbb3rE\nfKiIYgqD8J1micZ0moWOUvtVgKSYr7xjciPQzTc9uvlQEVccBl0EFfRGn2HMfp7Qm/z74qBJKKaa\n8kf2bAS6+aZHzIeKKKYwCB1BZRg1tDNylNpH
KskxX+Biim6+6TUNMR+q4YqDpG4PUBu90yzRqKGd\nkZ9/mBPzjZjltRHo5pse3XyoiCsOkrb2BoeNoDKM2TPlkVf/hpiPoZ0r0c03PWI+VEQxBUkaioOo\n9/kcXcxX/jiZJmrBmhPzMbSTbr6p0c2HiiimIGkrtmqoplYaNbSzFcVU6tDOyMUU3XzTo5sPFXHF\nQdJSzEcxtVI3ULLsz3rk5x/2xVRKN591m/xLVwD3PRu5MkXMt9qYcwxsE/XLOrYh5kvXWHkENW8D\nx3xtzpyp7hyFnY4was8U3XxJzNgzhWq44iBpKeaLeqPPMBv1bD6PG6UO3Xxpe6akwIM7mxndfFNr\nZnTzoRquOEjailOI+VYbF/MF7pgcYr60Z/NJgTeh24gZSMR8aYj5UBHFFCSxMpWjMRVP5+5ivsoH\ntF8UxXyRi6kRoxFYmVptzDkGtuGKg6SlPVNh7/TpZiMGSs7byDFf3gR0KXrMN6abj5WplfrRCFEL\ndlRFMQVJDO3M0Yx41IlHfjZfUcw35QGdwejmm15fcFJMoQKKKUhajvnWfCD7QDedm6Gd2YaYb/X7\n7/8dhn3YMd180+vPEVEfKuCKg6StvSlhI6gMY2K+1gOf46xuvkXMF3XVgG6+6fXzzujoQwVccZC0\n9XU77KpJhjHPjWsjb0DPifnYgE7MN7Uh5qOjD+NRTEHS1gpA2OncGUYN7eRBx3ndfFHvc3TzTY+Y\nDxVxxUHSUszHytRKs6Z8aGcbeQJ6ydDOqCtTdPNNr1+9I+ZDBRRTkLS10TfsjT6DWfnQztYDn+O+\nMEpYNbFhZSpoMTU25mNlarVhZSrq8idq4oqDpOWYL+iNPsOsGfdsvrBRar8CkPCg41n4PVOzsvip\nTd+XFh57plBR1C/r2IYN6OlmI+ZMtaGfzZcxtDP6nKmmcM5UxjkOr2FlCvVQTEHS8p6pNR/IPmAj\nHnTcRp4z5ekb0PtTFHYCujVlu+89fZZXeMZoBNRDMQVJWzctYr7VZs2IoZ1t4AnoQ8yXszIVuJgq\nivnSz3F4Q8xHMYXxKKYgiaGdOWbNyA3oUc9xwbP5whZTxHzTa9gzhXoopiCJ0Qg5zMpa9tvoj+zJ\nivmCP+jYmrL4KeMch0fMh4q44iBpa6Nv2Agqw6zw2Xx9wRr2HGd0mg0xX9RFg+JuPmK+ZHTzoSKK\nKUjiQcc5upgvv5iaR49ShwgqYTRCv2hAzJdnmOVFMbUSMR8qopiCpKUIKuqNPkM/tNMzb/Thx0+U\nxHxRiym6+abXnyNiPlRAMQVJSzEfxdRKW5uj8/5cG/35hzndfExAJ+abGt18qCjql3VsM2fOVLLS\nCGoefZM/QzvT0c03PWI+VEQxBUk8my9HaadZ+HPM0M50dPNNj24+VMQVB0nLEVTQG32G/hzlbucJ\nH6X2Jywj5svdl7YxbOTKFDHfanTzoSKKKUha7uYLeqPP0NdCuZujw3dMtumrJn0jROgN6GP2TLEy\ntVp/jiimUEHSFWdml5jZnWZ20syu2eH3j5rZe83sI2b2MTO7rP6hYkoM7UzXlMZ84UcjZBRT0Yd2\n9itLuR197JlK1xDzoZ6VX9XMbCbpWkmXSrpI0lVmdtG2l/0HSe9095dJulLS/6h9oJhW+Agqw1bM\nV1ZMhR3amRFBlUapG6M0ghrOMStTKxHzoaKUK+4Vkk66+13u/rik6yVdse01Luncxc/Pk/SX9Q4R\neyF8BJWhdNUkfJSaE/OF34C+OAG5UR8xX7oh5mNlCuMdSHjN+ZLuWfr1KUmv3PaaN0r6AzP7t5Ke\nKeniKkeHPRM+gspQup9nGNoZ9RwPMd/qlakm+tDOIebLvNFnnOPwSs8xsINa375cJelX3P2wpMsk\nvc3sr39rZGZXm9nNZnbzgw8+WOmvRg3hI6gMW51meX8u/NDOopgvaDE1OuajmFqJmA8VpXxZv1fS\nkaVfH158
bNlrJb1Tktz9/0p6uqSD2/9H7n6dux939+OHDh0qO2JMYh79UScZSiOo8EM72/Rn821F\nqVMe0BmsNIIi5ktHzIeKUq64myRdaGYXmNlZ6jaYn9j2mr+Q9G2SZGZ/U10xxdLTPrIV8635QPaB\nIebLLKY8ejGV0c3Xr94R8+XGfHTzJSvtmAR2sPKrmrs/Kel1kt4t6Q51XXu3mdmbzOzyxct+VNIP\nmNlHJb1D0vd52PX5/Sn8dO4MpTFf+NW/jAiqYWhn92Pu+6ebLx1zplBRygZ0ufsNkm7Y9rE3LP38\ndkmvrnto2EtzJqAnawpXTfqVrLB7prK6+YLPmaKbb3rEfKiIKw6StuZMhV01yTB6aGfUc5zRaTYr\njFI3Bt1806ObDxVRTEHScsy35gPZB8YO7YxbTGXEfAzt7H6km286dPOhIoopSFqOoILe6DOUzkAK\nf46zuvm6H8NuQKebb3rEfKiIKw6SGNqZoy+mcpuAhig16jnO6DTrN/m3UYupZuTKFDHfak3hJn9g\nBxRTkEQElaOvhXJv9FvnuPYR7RPDfp7VJ8CGgjXoja5fNSkejcCX9pVKzzGwA644SNpq22cC+mql\nm6OHmC/qOW7n3Q0s4f2H34Beup+nLwwYjbAaMR8q4oqDJIZ25uhjuuKVqahLU94mx09bMd+UB3QG\nawpnIBHzpSuNUoEdcOuEJIZ25mgK9/O04Yd2zpPjpyGBibqfpTjmYwN6MmI+VMQVB0lbKwBhI6gM\ns8LnxoV/0HE7T27Zn4Uf2tmvmhR28zEaYbXScwzsIOqXdWzTt6BTS63WFK6abJ3joCfZPT3ma6LH\nfHTzTY6YDxVRTEFSF/M1FvhGn6Ep7DRro29Az4n5CjsmNwbdfNMbzjHFFMbjioOk7qYVdphkpqHT\nLHsD+lP/fDjtPLnDgZhvbDcfK1Mr0c2HiiimIKkrDFiVStMUdpr1hUHY05zTzVfYMbkxRnfz8aV9\nJWI+VMQVB0ldBBU2fso0DO0sfNBx2JWprJiPoZ2S6OabEt18qIgrDpK6VZawN/lMpQMlh2IqatGa\n0c0ndec57rP56OabHN18qIhiCpK6wiDqPT5X6ZyprZgv6InO6OaTuqIz6sIU3Xx7gJgPFVFMQRIb\n0HMUD+0k5suKn8yI+Yj5JkQ3HyriioOkRTEVdcUkU+kMpDb68w+9zXpe0ayxuBvQS7v5+vNFzLda\n6TkGdkAxBUndNO+w8VOmfmEp+0HH0QejtnkrU41Z9pT5jWGF3XwtK1PJ+guRPVOogCsOkiR3j/uY\nk0ylDzp2Yr6svTyNMRqBmG9CZt15opsPFXDFQVK3yhL2AbyZZsUb0Lsfw55nb7O7+cIWU8UxX/+P\njJgvic2I+VAFxRQkdRFU2Jt8pqbwQcd9zJexbWizFMV8UYupwuncxHx5rCHmQxVccZDU7VsNGz9l\nGhKYzBu9R58zlTEBXeri1LArU/3KUnHMx8pUkmZGzIcqKKYgqY/51n0U+0Ppo076VZawK4C53Xxm\ncbvWifn2hs22OiCBESimIGkR81FNJRlivtJiKup5zo758s/xxiju5uPZfFmI+VAJVxwkLbr5oq6Y\nZGoKnxvX1wVh49Tcbr7G4g7tHB3z8aU9SUM3H+rgioMkuvlylA7tHDagRz3NdPOlK12Z8rb7s1zL\naejmQyUUU5DUdaaFjZ8yFQ/tjL5nqqSbL2gtNaqbj1WpdMR8qISrDpIY2pmDoZ2Fcrv5Ij+bb0zM\nRydfOrr5UAm3T0hizlQOhnYW4tl86cZ089HJl45uPlRCMQVJ3f6fsDf5TKVDO9voe6YY2pmuOOZr\niflyEPOhEq46SOrilLA3+UzD0M7M72hb9+5xYFGL1uyYz7I3+W+MfnUpd9Uk8xyH1zRsQEcVFFOQ\n1G2ODruXJ1PpaIR5G3z8hGeuTDWBH3Tcn6eiPVOB/43l4kHHqIRiCpK6mx
YxX5pZ4dDO1oN3TLbz\nvNEIxHxl3XzsmUpnM2I+VEExBUkUUzmawjlT3Tme4ID2C3eezZeqGbEBnZgvXcOcKdRBMQVJxHy5\nStr2ifnyIqhuz1TQYmpUzMeX9WTEfKiEqw6SiKByzRoriPmCP/+QmC/dMBqhoJuPmC8dE9BRCcUU\nJBFB5SpZNWmjP7Int5uvyY9SNwYx396gmw+VUExBEhFUrsbyH8I79+BRam43X8E53hhDzJdbTNHN\nl4WYD5VQTEESMV+uWWMFQzuDD0bNjfkKotSNQTff3qCbD5UkFVNmdomZ3WlmJ83sml1e891mdruZ\n3WZmv173MDE1hnbmaaxgaGf0c5zbzRd5aKfZYjo3Md+k6OZDJQdWvcDMZpKulfTtkk5JusnMTrj7\n7UuvuVDST0h6tbt/zsyeP9UBYxrhI6hMJW374Tsms2O+wA86lsoiKLr58hDzoZKUq+4Vkk66+13u\n/rik6yVdse01PyDpWnf/nCS5+wN1DxNTY85UnpJOM2K+efaDjsN280llERQxXx66+VBJyle28yXd\ns/TrU4uPLXuJpJeY2Z+Y2Y1mdkmtA8TeCN9plqlbmcr7M91ohGmOZ18oejZf4GKqJILKjFLDo5sP\nlayM+TL+PxdK+mZJhyW938y+zt0/v/wiM7ta0tWSdPTo0Up/NWpoXbEjqEwlEVTrwTsmS7r5IhdT\n1tDNNzViPlSS8pXtXklHln59ePGxZacknXD3J9z905I+qa64egp3v87dj7v78UOHDpUeMyYwb52v\nwRlmBTf6efTVP88bKDkrWP3bKCURVOY5Do+YD5WkFFM3SbrQzC4ws7MkXSnpxLbX/K66VSmZ2UF1\nsd9dFY8TEwu/apLJjAno2do2a2XKwm9At7I9U2xAT2cNoxFQxcqrzt2flPQ6Se+WdIekd7r7bWb2\nJjO7fPGyd0t6yMxul/ReST/u7g9NddCor6WbL8usyR8o2baKXbD6PGs/T+g5U1K3wlTUzcfKVLKS\ncwzsIGnPlLvfIOmGbR97w9LPXdKPLP7DPjRvu9UWpCmJoOYePEr1Nq+bL/yeKWK+ydms27QPjMR6\nMCT1K1PrPor9w0z5MV/0OVOZEZSZZe+/3iglEVRmlBpeSZQK7ICrDpLYM5VrVvDcuPBRanbMp9hz\npppZYTcfX9aTEfOhEq46SOq7+QLf6DPNSiage/AotaibL3AxRcw3Pbr5UAnFFCQRQeUyK3jQceua\nRT3F7ouhnZkxX+hiim6+ydHNh0q46iCJoZ25Zk3Bg44jx3z9d/85MV/BI3s2Ct180yPmQyUUU5BE\np1mu0qGdYWO+vpjKfDZf5FqKmG8P0M2HSiimIElyNqBnsYJVE/fAc6b67/4Z2pmObr7pEfOhEq46\nSOJRJ7nKNqAHftBxf8PKjfkirxoUx3xR/5EVaHg2H+rgqoPcXa0r9qNOMs0KZiCFLliHmI9uvmQl\nERQxXx66+VAJxRSGr9dhI6gCJUM7PfIG9KKYL/rQTrr5JkfMh0q46jAUBVHv8yVmjclLYr6oBWtJ\nN19Bx+RGaQo3oNPNl67kHAM7oJjCsJGamC9dU7ABfd6KYipj1aSJvmfKCvbzsGcqjzX5U+aBHXDV\nYSvmo5hK1jSmee52lsjPP+yLgowd+I1ZN+szakFls7JuPvZMpSs5x8AOon5pxxJivnwzy7/JswFd\nmTFfd67CTkcg5pte0xDzoQqKKWzFfFFv9AWKYj73uFGq529A709V2CnoJRGUz8X03QwlUSqwA4op\nDCssxHzpmoahnVmGmC991aQZVqYCF1Ml3XzEfOmI+VAJxRRYmSowW+znydHFfNMczxmv8Nl8UuBi\niphvenTzoRKKKQx7UsJGUAWaJn/OVBs65ivr5pOix3x0803KFnumohbsqIarDsN3/lHv8yUas+zn\nxrWRN6CPifmiLhzQzTe9fhWP1SmMRD
GF4Tv/sPt5CpQ+my/sOS5YmZotThUxXwZivjz9qA6KKYxE\nMYWtlSmWppKVDJQM/fzDkm6+xbkKO7iTbr7p9f8e6ejDSBRTGL5eh42gCjQFz41rI29AL4n5+g3o\nkfdM0c03rSHmo5jCOBRTGL7zDzudu0DJc+PmkR90zNDOfMR802vYM4U6uH1iaQN60Bt9gZKhnaE3\noI8Z2hk65qObb1LEfKiEqw5DjBL2Rl+gaSx7xaT1wOe4L4iI+dJZ4coUMV86uvlQCcUUlmK+oDf6\nAo0VxHwtDzrO2RzdRB/ambtnqs3vmAzP6OZDHVx1YAN6gVlJzBd6aGdfTOXvmQo7tLOZ5XXzFZzj\n8BpiPtRBMQWGdhZoCuZMtR55z1RfsZc8m2+KA9oHcmO+4RzzZT0ZMR8q4arD1tBOqqlkJRPQ523g\noZ1t+Qb0uDGfZcZ8+ec4vCHmY2UK43DVgaGdBWYNQzuzlMR80Z/N18zy4idivnz9SikxH0aimAKj\nEQo0ltfNt9UxOdEBnelGxXxBi6nimI9iKhkxHyqhmMJQFISNoAo0ltey3xcEYc9xQafZ1miEKQ5o\nH8ju5iPmy0Y3HyrhqsMQo4RdNSmQG/PNo0epBUM7+zESYYd2Znfz5U+ZD4+YD5VQTGErgop6oy/Q\nmMld8sQbffjxEyUxX/g5U3TzTY6VKVTCVYetmI9iKtnWjT7t9W305x8WdfNFn4BON9/k6OZDJVx1\n2IqgqKWSDRFU4o1+Hn2TP0M789HNNz1iPlRCMQWezVcgt9Ms/DkeFfNNcUD7AN1806ObD5VQTGEp\nggp6oy+Qu58nfJRa1M23+KNh90zRzTc5Yj5UwlWHpW6+oDf6ArPMVZPwHZOeX0zNos+ZagpXpoj5\n0vWreFH/jaEaiikwtLNAf6pS9/OEnzJfMBrBok9A789V6niEgoI1vP5CZs8URuKqAxFUgWHVJLOY\niju0c3GzytjPE35latjPk3ijH84xX9aT5Z5jYBdJV52ZXWJmd5rZSTO75jSv+y4zczM7Xu8QMbXw\nEVSB3Bt9+Ci1IIKaRZ+A3mTOQCLmy9ewAR11rCymzGwm6VpJl0q6SNJVZnbRDq97tqQfkvTB2geJ\naYWPoAoMEVRiMdW/LOw5Lor5uh/DTkAfYr7EVZOCcxxe7jkGdpFy1b1C0kl3v8vdH5d0vaQrdnjd\nT0l6s6THKh4f9gB7pvLlrpqEX/0bE/OF3TNVGvOxMpWMmA+VpBRT50u6Z+nXpxYfG5jZyyUdcfff\nq3hs2CPzRUEQdj9PgX5oZ3LMF338RH+eCoZ2Rq2lsiMoYr58xHyoZPR6sJk1kn5W0o8mvPZqM7vZ\nzG5+8MEHx/7VqGQr5lvzgewjuZ1mHn31b4ig0t9/Q8zX/Zgc89HNly23YxLYRcpVd6+kI0u/Prz4\nWO/Zkr5W0h+b2d2SXiXpxE6b0N39Onc/7u7HDx06VH7UqCr8dO4Cs8yhnfPoDzouiKB4Nl/mqgnd\nfPkY2olKUq66myRdaGYXmNlZkq6UdKL/TXd/xN0Puvsxdz8m6UZJl7v7zZMcMaoLH0EVyI2g+hWs\nsA86Lunmiz4agW6+6RHzoZKVX9rd/UlJr5P0bkl3SHqnu99mZm8ys8unPkBMry8Iwq6aFCge2hn1\nHBd0mjUM7ex+pJtvOnTzoZIDKS9y9xsk3bDtY2/Y5bXfPP6wsJfa6J1mBXJXTcIXUyUxX/SVKbr5\npkc3HyrhWxgsRVBBb/QF8vdMBT/HJd18mc8/3Dh0802PmA+VUEyBoZ0Fcrv5hig16jkuivm6H4n5\niPkmQzcfKuGqAxFUgX6FKTWB2jrHUx3RGa7/zj+j06wZznHUYip3Zaqv2FmZSmaZm/yBXVBMYetB\nxxRTyXJXTfp9aWHPcTvPXjFhA3rmjb7Nn+UVHqMRUAnFFIabFV+D0/WrJqkDJfvXWdST7PPsvTyz\n4f
mHUxzQPtCUxnysTCXrV/Ho5sNIFFPYWjUJm0Hl62/0qRFUvyUj7Dn2Njt+GsYsEfOlvX6IUimm\nkuWeY2AXFFMg5iuwFUGlvb4dBqNOdURnOGK+fLkRVMsG9GzEfKiEqw5LEdSaD2Qf6VdNUm/0xHxt\nfsyXGaVunNwIipgvHzEfKqGYgtrW1VjgG32B/Jgv+AZ0b7OfGddYXsfkxqGbb3rDOY76jwy1UExB\nrXvcvTyFcjegD1Fq1PNcFPN1PxLz5Xbz8WU9Wf/NDTEfRuKqg+burEplyt3PE75jsqSbrwm+Z6q4\nm48v68mI+VAJVx3Uth43fipUOrQz7MpUQTefmcmMbj66+SZENx8qoZiCWg98ky+UPbTTg++ZKoj5\npG4FMOwGdLr5pkc3HyrhqoPmrceNnwoNMV/mg47DxqkF3XxSV3ymjp/YOHTzTY+YD5VQTEHOBvRs\ns8znxnn0DegF3XxS90eI+Yj5JkM3HyqhmILm7jzkOFPu0M5+ZSpqLTUq5ou6AT075muf+uewGjEf\nKuGqg+atKKYy9ZPM28xn84U9zyNivqi11LDClLpq4hRT2ZrM8RPALrjqsIj51n0U+0u/9ym1mPLw\n3XxlK1Nm6ed44/SFN6MRpmUNe6YwGlcdNG+J+XLNsudMdT+GPc/tvGgvz6yJHPP1K1OZ3Xzsmcpj\nM2I+jEYxBfZMFehXmFLv80PMF/WKK435Gou7MtUUbkCnmy9PMyPmw2hRv7RjiTNnKtuQwCRWUx59\nzpS3hTFf4GLKmIC+J4j5UAFXHRYx37qPYn+ZZT6bb6ubL+iJbudFy3Kz0N18uTEfoxGKGCtTGI9i\nCl3MRzWVZZa5AX0opqKe51Ex3wTHsx8Ud/NRTGVpGoopjEYxha6bL+qKSaGhmy855ut+DBunjunm\ni1pNFcd8Qf+NlSLmQwUUU6Cbr8AQ86V28zlDO4u7+aLvmcrp5rOGYioX3XyogGIK3dDOsHf5Mlsx\nX9rrw++ZYmhnvpJuPiK+fHTzoQKKKTC0s4BlTkBnaGdpNx8xX1bMRydfPmI+VMCVB+ZMFWBoZ6bS\nbj6GdubFfHTy5aObDxVQTEGtB77JF8od2tlG3zNVGEE1kedMZcd8TsxXgm4+VEAxBbXMmco2DO1M\nvNG37jLb6gIMpzCCCl1MDTFfajFFzFeEmA8VcOVB89bj7uUpNMscjTBvg4+f8La4my9qyrfVzZex\nAT3s84pGIOZDBVx5UMueqWz9+Upt2289eMdkW7oylb4vbeOUjkZAHmsYjYDRuPJAMVWgL4xSV6a6\nczzlEZ3hSvdM8aDjzG4+9kxla2bEfBiNYgrEfIVyIihivrIIahZ6z1TBnCm6+fIR86ECiikQQRVq\nLCfmC/78w+KYL/JohNyYrwUg3m4AABFYSURBVGyWV3hGNx/G48oDEVShxiw95ov+yJ7CCKpp0pvZ\nNs4Q8+V087Eyla2hmw/jUUyBCKrQLGM/z9yDR6mjuvmirkyVxHx8Sc9GzIcKuPJAzFeoi6DSXht+\nMOqYmC9sMbX490I337To5kMFXHlgaGehxjKGdkY/x2MmoIfdM2V5AyWJ+crQzYcKKKagNnoEVSgn\nggp/jku7+SIP7ZTyIii6+coQ86GCpK9uZnaJmd1pZifN7Jodfv9HzOx2M/uYmf2Rmb2o/qFiKnP3\nuI85GSGn02zeEvMxtLNATgRFzFeGbj5UsPLKM7OZpGslXSrpIklXmdlF2172EUnH3f3rJf2mpJ+u\nfaCYTssG9CI5AyW70QgTH9CZrLSbL/KcKSkvgiqMUsOjmw8VpHx5f4Wkk+5+l7s/Lul6SVcsv8Dd\n3+vuX1z88kZJh+seJqbUumJHUIVmZsld660HL1jp5itjMyn1/dPNV4aYDxWkXHnnS7pn6denFh/b\nzWsl/f6Yg8LemreuyPf5UjlDO+fR50wxtLMMMd/06OZDBQdq/s/M
7F9IOi7p7+3y+1dLulqSjh49\nWvOvxgjhV00K5cd8gc+xe/Gz+SIvTKnJ2M9DzFemYWUK46V8G3OvpCNLvz68+NhTmNnFkl4v6XJ3\n//JO/yN3v87dj7v78UOHDpUcLyYQvtOsUN4EdMUuWH3EBvTI1VT2aARWprLlnGNgFylX3k2SLjSz\nC8zsLElXSjqx/AIze5mkX1RXSD1Q/zAxpXkruvkKzBrTPPVBxx48Sm3nxQ86jh3zzfJiPkYj5GPP\nFCpY+dXN3Z+U9DpJ75Z0h6R3uvttZvYmM7t88bKfkfQsSe8ys1vM7MQu/zucgbqVqXUfxf6TO7Qz\n9Opf6dDO8DFfzpypsig1vJwoFdhF0p4pd79B0g3bPvaGpZ9fXPm4sIfYM1UmK+aLHqWOiflCr0xl\nPOnZ51JTdRtsDMR8qID1CCy6+QLf6AvNmoyhnR48Si2MoLooNXIxRcw3uZxzDOyCYgpy5kwV6QZK\npr3W3TWLeordJRV285nJIxdTdPNNj24+VEAxhcUMpHUfxf7TNOl7pkLPmepvVMyZykc33/SI+VAB\nVx40jz4DqVBOp9m8DXyO+xtV4YOOYxdTxHyTo5sPFVBMYRFBBb3Rj5AztNM98JypYWWqNOarfDz7\nCd1806ObDxVQTCF2BDVCzkN455EfdNyvrDC0M192zMd1nI2YDxVE/fKOBXdX64obQY2QHfNFvdEN\nMV9hN1/4mC/1adrEfEXo5kMFFFPB9d/0h42gRug2oKe91iPPmRoT84Uf2kk33+To5kMFFFPB9RFK\n1Pv8GDlDO+ceeGVqVDcfMR/dfBMj5kMFXHnB9REKMV++WcYG9HkriqmSmC9jX9pGyon5vCXmK2Ez\nxV7+RA0UU8ENMR/FVDaz9Acde+TnH/bf9RcUk7bo5gs7uNOajNEILStTJXLOMbALrrzgiPnKzUzp\nMV/kDehDN1/ZBnQp8PP5mllmzMfKVLaGmA/jUUwFN8R8UW/0I2TFfJEHo46J+RbnLGotlRVBeVs0\nGDU8hnaiAq684Pr4hJgvn2WMRgg9tLMtnzPVn7Kw+6bM8iagE/PlI+ZDBVx5wbEyVS5nc3To5x+O\nifmMmI+Yb2I55xjYBcVUcP09KmwENUIX86W9tg0d8/X/yMbEfEGLKbr5pmczSU5HH0ahmAquZQN6\nMcvYgN5G3oA+KuZbFFNRt7TQzTe9/pyxbwojcOUF18cnYffzjDBrLHmg5Dzyw6RHPJtvtjhlYQd3\nEvNNr9+0T9SHESimghtWpliaypazZyr08w+rdPMFLabo5pteX4CyMoURuPKC6+OTsBHUCGaWHD+1\nkTegV4n5ohZTdPNNboj5WJlCOa684ObDaIQ1H8g+NGvSu8zmoR90XGFoZ9SVKWK+6fUrpsR8GIFb\naHBbG9CD3uhHyBnaGXoD+shn80nRh3bSzTcpYj5UQDEVXMucqWKWu2cq6jnus9AxQzujVlOp3Xzu\nXTFAzJePbj5UwJUX3JwJ6MW6Dehpr523gR907OXFVPgN6E3iylR/foj58jWsTGG8qF/escAG9HKN\npe+Zij20s3wDehN9Aro1aUO2Rpzj8IblT/ZMoRxXXnAM7SzXNJY+tNMD75nqb1IF+3ma6CtTNkuL\n+YZzzJf0bMOeKYoplOPKC64l5iuWO2cq7tDOPuZjA3q2pkmM+crPcXjEfKiAYio4HnRcrsmZgB55\nztSomK/7MXbMl7IBnZivmDEBHeNx5QXHBPRyTeLQzqFjMuo5bstHI/TnLG4xlRvzsTKVjZgPFVBM\nBdffo8JGUCPMmrS9PEOUGvUcV4j5om6ZSu/mI+YrNsR8Uf+RoQaKqeC2Yr41H8g+1FhazDePvvo3\nRFD577/fTx12AnpyN1/5+InwiPlQAVdecOEjqBEas25W4oobffjxE2O6+cKPRqCbb3I8mw8VcOUF\nN8R8FFPZtgZKnv51bfTnH46J
+Zo+5gtaTNHNNz26+VBB1C/vWJgzZ6pYaqfZPPrzDxnaWY5uvukR\n86ECrrzgeDZfudSBkuHPcY2YL+rKFN1806ObDxVQTAXH0M5yWwMlVxRT0aPUEc+N24r5ah7QPkI3\n3/SI+VABxVRwDO0slxpBhe+YHNPNx9DO7ia/qpqkm6/cEPNRTKEcV15w/T2KYipfk7gB3cOPRhg/\ntDP0s/mk9GKKmC/f0M1HMYVyFFPBbU1AX/OB7EN9bbTqYcfhN6C34zegxy2mEtv22/LVv/AYjYAK\nuIUG18cnYadzj9Dv51m1OTr8OR5ivvIJ6POoiwZNYqfZiHMcXr+aRzcfRqCYCo5n85VLXTXpfzvs\nOR4V83U/xl2ZStwcTcxXLvUcA6eRVEyZ2SVmdqeZnTSza3b4/bPN7DcWv/9BMztW+0AxjTZ6BDXC\nUEyt+BocfgN6jZgv8gZ0KSPm4/vjbMR8qGDllWdmM0nXSrpU0kWSrjKzi7a97LWSPufuXy3p5yS9\nufaBYhp9fBI2ghqhn2i+MuaLPn6iwgT0sHOmUiMoYr5ywzlmZQrlDiS85hWSTrr7XZJkZtdLukLS\n7UuvuULSGxc//01J/93MzNf4DIjP/uWf69Rtf7Kuv37feOK+R3Vx84Cecffj0oNnr/tw9pXz7/+s\nLm7u1p994LN68OlP2/V1D3/xcV3c/IVe+JnPSZ943h4e4Rni/lu7Hwu6HPqVqY/e83mdfSBeoXDk\n/i/qayR9/I/fpfmBZ+76umd84W69RNKHTz2qh+b379nxbYJzH/q8Xinpro+8R4/e/cC6DweFDl7w\n9Tr81V+7tr8/pZg6X9I9S78+JemVu73G3Z80s0ckfYWkzy6/yMyulnS1JB09erTwkNPcc+sH9LL/\n84OT/h2b4BskXXWWpP+17iPZf14j6TVnSfrQ6td+61mSblz8F9GBp3f/ZTr3nAMyk37pA5/WL33g\n0xMc2JntO5uH9V/Pkr7ugz+W9Po3/eF9usVvnvioNsuL7DN639nSi2+/dt2HghFuvO/fnfHFVDXu\nfp2k6yTp+PHjk65avfjvXKKTL/i9Kf+KjXHuOU/T85/FqlQul+svHv6SnkhoNTv7QKPDzz1HpqBR\n3zMPSU87J/uPPf/ZT9f7f/xb9MiXnpjgoPYBf7U+9bnvlLWr33974Bn6T8/5qj04qM3zyUePq3n8\nC+s+DIzw1S+YdoFmlZRi6l5JR5Z+fXjxsZ1ec8rMDkg6T9JDVY6w0HnPPajznvuadR4CNpxJetFX\nrvsoNt+R5z3jKV+Awjm8PQhAdef/7XUfAfa5lE0MN0m60MwuMLOzJF0p6cS215yQ9L2Ln/9TSe9Z\n534pAACAvbJyZWqxB+p1kt4taSbpre5+m5m9SdLN7n5C0i9LepuZnZT0sLqCCwAAYOMl7Zly9xsk\n3bDtY29Y+vljkv5Z3UMDAAA48zHhDQAAYASKKQAAgBEopgAAAEagmAIAABiBYgoAAGAEiikAAIAR\nKKYAAABGsHUNKjezByX9+cR/zUFte9hyMLx/3n/U9x/5vUu8f95/3Pc/5Xt/kbsf2uk31lZM7QUz\nu9ndj6/7ONaF98/7j/r+I793iffP+4/7/tf13on5AAAARqCYAgAAGGHTi6nr1n0Aa8b7jy3y+4/8\n3iXeP+8/rrW8943eMwUAADC1TV+ZAgAAmNTGFlNmdomZ3WlmJ83smnUfz9TM7IiZvdfMbjez28zs\nhxYff6OZ3Wtmtyz+u2zdxzoFM7vbzD6+eI83Lz72PDP732b2qcWPz133cU7BzF669Pm9xcweNbMf\n3uTPvZm91cweMLNblz624+fbOm9ZfC34mJm9fH1HXscu7/9nzOwTi/f4O2b2nMXHj5nZl5b+HfzC\n+o58vF3e+67/1s3sJxaf+zvN7B+s56jr2eX9/8bSe7/bzG5ZfHyjPvfSae91673+3X3j/pM0k/
Rn\nkl4s6SxJH5V00bqPa+L3/EJJL1/8/NmSPinpIklvlPRj6z6+PXj/d0s6uO1jPy3pmsXPr5H05nUf\n5x6ch5mkz0h60SZ/7iV9k6SXS7p11edb0mWSfl+SSXqVpA+u+/gnev9/X9KBxc/fvPT+jy2/br//\nt8t73/Hf+uJr4EclnS3pgsV9Ybbu91D7/W/7/f8i6Q2b+LlfvKfd7nVrvf43dWXqFZJOuvtd7v64\npOslXbHmY5qUu9/n7h9e/PwLku6QdP56j2rtrpD0q4uf/6qk71zjseyVb5P0Z+4+9UDctXL390t6\neNuHd/t8XyHp17xzo6TnmNkL9+ZIp7HT+3f3P3D3Jxe/vFHS4T0/sD2wy+d+N1dIut7dv+zun5Z0\nUt39Yd863fs3M5P03ZLesacHtYdOc69b6/W/qcXU+ZLuWfr1KQUqLMzsmKSXSfrg4kOvWyxvvnVT\noy5JLukPzOxDZnb14mMvcPf7Fj//jKQXrOfQ9tSVeuoX0gif+95un++IXw/+pbrvxnsXmNlHzOx9\nZvaN6zqoie30bz3a5/4bJd3v7p9a+tjGfu633evWev1vajEVlpk9S9JvSfphd39U0s9L+ipJ3yDp\nPnVLwJvoNe7+ckmXSvpBM/um5d/0br13o1tXzewsSZdLetfiQ1E+939NhM/3bszs9ZKelPT2xYfu\nk3TU3V8m6Uck/bqZnbuu45tI2H/r21ylp34ztbGf+x3udYN1XP+bWkzdK+nI0q8PLz620czsaer+\ncb3d3X9bktz9fnefu3sr6Ze0z5e4d+Pu9y5+fEDS76h7n/f3y7mLHx9Y3xHuiUslfdjd75fifO6X\n7Pb5DvP1wMy+T9I/kvTPFzcULSKuhxY//5C6fUMvWdtBTuA0/9Yjfe4PSPonkn6j/9imfu53utdp\nzdf/phZTN0m60MwuWHy3fqWkE2s+pkktsvJflnSHu//s0seXs+F/LOnW7X92vzOzZ5rZs/ufq9uI\ne6u6z/n3Ll72vZL+53qOcM885bvSCJ/7bXb7fJ+Q9D2Lrp5XSXpkKQ7YGGZ2iaR/L+lyd//i0scP\nmdls8fMXS7pQ0l3rOcppnObf+glJV5rZ2WZ2gbr3/qd7fXx75GJJn3D3U/0HNvFzv9u9Tuu+/te9\nM3+q/9Tt4P+kukr89es+nj14v69Rt6z5MUm3LP67TNLbJH188fETkl647mOd4L2/WF3Hzkcl3dZ/\nviV9haQ/kvQpSX8o6XnrPtYJz8EzJT0k6bylj23s515d0XifpCfU7YF47W6fb3VdPNcuvhZ8XNLx\ndR//RO//pLq9If31/wuL137X4rq4RdKHJX3Huo9/gve+6791Sa9ffO7vlHTpuo9/ive/+PivSPrX\n2167UZ/7xXva7V631uufCegAAAAjbGrMBwAAsCcopgAAAEagmAIAABiBYgoAAGAEiikAAIARKKYA\nAABGoJgCAAAYgWIKAABghP8PV48jZt7U4foAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 720x360 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlIAAAEvCAYAAACOiy/xAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3deZBsZ33e8efX65mt5yJ00X4lgVgs\nCCBywXgBG0yMwDaKcdmBeIHglMplnDJeQnCRciApV+LYcaVikzhKQZkYbDCxFSgDNuDCcRGzSVgI\nLSxiMYgIISTudM/S+5s/Ti8zPed0zzl9epl+v5+qW/fenumet9+e6fPM793MOScAAAAkl1t0AwAA\nAE4rghQAAEBKBCkAAICUCFIAAAApEaQAAABSIkgBAACkVFjEF7344ovdNddcs4gvDQAAkMjtt9/+\nLefc2aiPLSRIXXPNNbrtttsW8aUBAAASMbO/j/sYQ3sAAAApEaQAAABSIkgBAACkRJACAABIiSAF\nAACQEkEKAAAgJYIUAABASlMHKTMLzOwTZvZpM7vbzN6YRcMAAACWXRYbcjYkPd85t2tmRUkfMbP3\nO+c+lsFjAwAALK2pg5Rzzkna7f232Pvjpn1cLEijJt3zHqnbWnRLlstjrpeuelY2j+WcdM//luo7\n2TzeMsqXpOtvkkobi24JAMxUJkfEmFle0u2SrpP0JufcxyM+52ZJN0vSuXPnsviymIU73ym991cW\n3Yrls3mJ9Kufz+axvnGn9K5XZvNYy8w56YafXHQrAGCmMglSzrmOpKeb2RlJt5rZU5xzd418zi2S\nbpGk8+fPU7FaVvuPhH+/5jNSbiFHMS6fj/xn6bY3h8HAbPrH2384/PufvE264h9O/3jLpnUg/e4z\nhs8TAFZYpldK59wFM/uwpBsl3TXp87GE6jtScUM6Q9VwYPsKqduWWvvZDFX1h/QuepxUuXz6x1s2\nzkmWX+2hSwDoyWLV3tleJUpmtibpH0n67LSPiwWpX5CCyqJbsVzKvf7IKhj0H2dV+9ksfG4EKQAe\nyKIidZmkt/bmSeUk/Ylz7s8zeFwsQn1HCrYX3Yrl0u+P+k42FaRBkFrhfg62CVIAvJDFqr07Jd2Q\nQVuwDAhSxx0OUlmo70iWk0qb2TzeMiJIAfAEO5vjKILUccGZ8O8sg1Swnc3E9WVFkALgCYIUjiJI\nHTeLitSq9zFBCoAnCFI4ql5d/Yt8UpkHKQ/6ONiWGtVFtwIAZo4ghSHnwrBQXtHVZGkFM1i1t+p9\nXKYiBcAPBCkMNfck11n9aklShbJUCBjaSyLYlpq7Uqe96JYAwEwRpDDkw7L8tLKc81PfGU5gX1X9\n7yGG9wCsOIIUhghS8TIPUivex4N5ZRcW2w4AmDGCFIYIUvGyClKdltTaW/0+znqCPgAsKYIUhvrD\nMKs+7JRGVkGqURs+3iojSAHwBEEKQ1Sk4mW1nL8/1LXqfTwIUsyRArDaCFIYWvXDdKdRzugQXl/6\nOOstIwBgSRGkMNSvlqz6Hkdp9If2nJvucXyp+jG0B8ATBCkM1XfC/ZKKwaJbsnyCbanTlNr16R7H\nlyBV2pJkBCkAK48ghSEfluWnlVWFxZcglcuFw3sEKQArjiCFIR/OgEsrsyDVXxnpQT9zcDEADxCk\nMERFKl5/S4hpV6HVdyRZb+hrxXFwMQAPEKQwRJCKl+XQXlAJh75WXXCGihSAlefBuzlOrL7Dir04\ng+X8Ux55Ut+Ryp6E1ay2jACAJUaQwhAVqXiZVqQ86WPmSAHwAEEKIef8usgnRZBKjiAFwAMEKYTa\ndanb8ucin1QhkPKl6YNBw6OVkf3J5t3OolsCADNDkELIl/2N0jLLpsLiW0VKYuUegJVGkEKIIDVZ\nFsv5fQxSHFwMYIURpBAiSE027Sq0bqc3tOfJ
ykgOLgbgAYIUQgSpyaYd2mt4tKu5xMHFALxAkEKI\nIDXZtEHKtz4mSAHwAEEKof5Gk75c5NMgSCVDkALgAYIUQj4dppvW1EHKsz4mSAHwAEEKofpOuE9S\nIVh0S5ZXsB3ut9VupLu/bxWpMpPNAaw+ghRC/WX5ZotuyfKadjm/b0Eqlw/DFPtIAVhhBCmEOLB4\nsmmHqvr386mfObgYwIojSCHk00aRaRGkkuO8PQArbuogZWZXmdmHzeweM7vbzH4xi4ZhzghSkw2C\n1IV096/vSKUtKV/Irk3LjiAFYMVlUZFqS/oV59z1kp4t6dVmdn0Gj4t58ukw3bSmrUj52MfBdvrg\nCQCnwNRByjn3gHPuU71/1yTdK+mKaR8Xc0ZFarIshvZ862MqUgBWXKZzpMzsGkk3SPp4lo+LOfDx\nIp9Uv3/SrkLzsY+DbQ4tBrDSMgtSZrYp6U8lvcY5d+yd08xuNrPbzOy2hx56KKsviyy06uH+SL5d\n5JMqrku5whQVqQv+9XGwHQbPbnfRLQGAmcgkSJlZUWGIertz7s+iPsc5d4tz7rxz7vzZs2ez+LLI\nim+H6aZlNt1y/vqOFHi0Yk8Kn6/rSs3dRbcEAGYii1V7JunNku51zv3O9E3C3Pm2UeQ0ppnz4+vQ\nnsQ8KQArK4uK1PdI+mlJzzezO3p/XpzB42JefDsDbhppg1S3KzVq/vUxQQrAipt6Qxvn3Eckca7I\nadZfnu7bRT6NtEGquRsOcfnWxwQpACuOnc3B0F4SaYOUr31MkAKw4ghS8Pcin0ba5fy+9vG0W0YA\nwJIjSMHPM+DSmrYi5Vsfl6lIAVhtBCmEFznLS6WNRbdk+QXbUmtP6rSS3c/bilQvOBKkAKwoghSG\ny/KNNQMTDeb8JByq8jVI5YtScYMgBWBlEaTg52G6aQ2CVMKDeAebnp7Jtj2nAQcXA1hhBCn4uVFk\nWmlXoQ0qUp7NkZI4uBjASiNIgSCVxDRBqrgRDnX5hiAFYIURpECQSiLtcn4fDyzuS7tlBACcAgQp\n+HmYblrllKvQfO7jYIqDngFgyRGk0LvIezgJOo1phva8rkgRpACsJoKU7zotqbXv70U+qdKmZLkU\nQcrjlZH9IOXcolsCAJkjSPmuP3fF14t8UrlcOLxHRerkgm3JdaTm3qJbAgCZI0j5rr+/j68X+TTS\nDFX5HqQkhvcArCSClO983XF7GklXoTlHkJI4uBjASiJI+Y4glVzSilRzLxza8rWPqUgBWGEEKd/1\nL25lT5fmp5E0SPnex2WCFIDVRZDyXYPJ5oklDVK+9zEVKQArjCDlO4b2kktbkfK1jwlSAFYYQcp3\n9Z1wX6TS5qJbcnoE21KzJnXaJ/v8QZDydNPT/o7u/RWiALBCCFK+q++Ec3dyfCucWNJVaL5XpApl\nqbBGRQrASuLq6Tufl+WnRZBKjoOLAawogpTvfD5MN62kBxcPNj31uJ85uBjAiiJI+Y4Di5NLOnm6\nviMVgnCIy1ccXAxgRRGkfOfzYbppJQ5S9DFBCsCqIkj5jjlSyaWpSPnexwQpACuKIOU7LvLJEaSS\nI0gBWFEEKZ912uF+SL5f5JMqVyQZQSqJfpBybtEtAYBMEaR85vvRJWnlcmGYOulyfoJU+Py7Lald\nX3RLACBTBCmf+X6Y7jSSLOfvb3rqs6RbRgDAKUGQ8hkVqfROOufHubCffe9jztsDsKIIUj5jx+30\nThqk2nWp06SP+3uVEaQArJhMgpSZvcXMvmlmd2XxeJgTglR6Jw1S9HGIihSAFZVVReoPJN2Y0WNh\nXrjIp0eQSoYgBWBFFbJ4EOfc35jZNVk8FuaIi3xina5Tq9NVobSlfGNHNukOgz729xieRrsjV9hU\nIJ0oSDnn1Gh3JUn5nKmYZwbCJM12V93e1hLlQk5mE78zAWQkkyCFU6q+I8lYUXZCrU5Xz/nND+sb\n1bp+qfAt
/YtCVa7TUS6fj7+T52H1LR/5sv7tn9+jspr6XCDd/aWv6cnPHH+fm//wdn3wngclhaHg\ng7/0fTr36PU5tPZ0+uRXHtHLb/mY2t0wSL3yu6/RG17y5AW3CvDH3H7VM7Obzew2M7vtoYcemteX\nxTj1qlTeCvdFwkTf3m/qG9W6fvD6S3Tp2ccoJ6fd2oXxdxoEKT/D6r0PVFUJCvrFG/+Bmq6g6oWH\nJ97nnv9X1VOuqOinn321Gu2uvvzw3hxaenp94cFdtbtOr37e43TuonXd+8AJ9zcDkIm5XUGdc7c4\n5847586fPXt2Xl8W47BRZCLVg7Yk6Yeeepkee9XlkqS9nW+Nv5PnFalqvaXLttf0899/nWq2oVxz\n8tBetd7S+asv0s9819Xh/w9as27mqVath/3z6uddpydcsqVqvb3gFgF+oRThM4JUIrXeBasSFFXY\nCOc8HdS+Pf5OngepWr2tylo4g2DfNlRojq+WdLtOu422KkFBlbXi4DEQr1ZvqZAzrRXzqqwVBt+n\nAOYjq+0P/ljSRyU90czuN7OfzeJxMWMEqUT6v+lX1goqbjxKklSvPjL+TvUdKV+SCsGsm7eUqvWW\nKkEYiA7ymyq2amM/f7fZlnNSZa04uF+VYDBW9aCtylpRZqZKUKSCB8xZVqv2Xp7F42DO6jvSmasW\n3YpTo3+B2gqKss2LJEnNvRNUpIJtydNVVNWDtq47G77N1AtbKrfHB6lhHxcUFHMq5IxgMEG13tJW\nEPZxJSio1mir23XK5fz8ngPmjaE9n1GRSqR6aGhvvRIGqdbeCSabe9zH1XprMETXKmwp6I6fON6f\nh1YJehWWtSIVqQmqB8OqX2WtKOfCyh6A+SBI+YzDdBMZXOTXCtroBanOwQmClKd97Jw7cpHvlLa0\n0d0de59BWO2Fr0pQGPQ7olUPzUMbDIdSxQPmhiDlq26Xw3QTqh6a1Lu5HQYpdzBhFZrHfbzX7Kjr\nNLjId0sVbbpJFalh1U8SFakTOFqRKvRuI3wC80KQ8lWzJsl5e5FPo9YbpjIzFYol7blA1pgQpDwe\n2hsNRQrOKLCWGvX4MFU7NKG/f1+qK+MdntDPBH1g/ghSvvJ8WX4a1YNwWX7frm0o35iw+aHHQWoY\nisKLu62F/bC7E7/S8fA8tPC+BbY/mODwFhNsGQHMH0HKVwSpxA5PnJak/dymCi2CVJzRUJRfD/fe\n2t+J3928PyQ1XIXG0N44rU5X+83O8YoUVTxgbghSviJIJVY9GC4zl6R6flOlcfsitepSu+5tHx/e\nykCSiifYxLRab2m9lFehd1DxFpPNx+pXnvp93P+b8AnMD0HKVwSpxKr19nC+j6RGYUvlzphVaP1h\nP0/7eHQFXqm391Zjd0yQOjRxWgorLAetjprt7gxbenoN5qH1+ngQpAifwNwQpHxV71/k/Vyan8bo\nRb5d3NTauOX8dc+D1GBPqPDivrYV7gY/bhPTcPh0WPUbzvmhwhJldPi0kM9po5SnIgXMEUHKV4OK\n1JnFtuMUOTypV5I6pYo2xi3n97zqd3gneElarzxaktQes4lpbaTqN1jOz+TpSMO9zQ73GSsdgXki\nSPmqf5H3dLPIpJrtrg5anSMX+W55W1tuT64bM+xU7wUGX4NUvaW1Yl6lQvg2M9x7Kz5IjU7oZ/L0\neMPh00NVPCboA3NFkPJVfUcqbUr5TI5bXHm1kfk+UricP29O+7sxe0l5XpEareCtrW+p5fLSmL23\nRreYYDn/eLWRoT2JLSOAeSNI+crjZflpVEc2ipSk3Fo4LLpbjVnO73mQOrxRpCRZLqdd25DV47eM\niK1IUWGJFDm0R0UKmCuClK/qF7y9wKcxmO9THl6wBvsiVWMmT/sepA7aR7aLkKQ921A+ZsuI/tl8\nh+8zXIVGMIhSrbeUM2mjlB/cxpYRwHwRpHzl8RlwaYwu5Zek0ka4Cq1ei9
mpu1GVcgWpuD7z9i2j\n0eqSJB3kNlWM2cR0cDZfcHTidP+xcFwYPMNji/o4nxCYL4KUr+o7TDRPYPQMOEkqb4ZBKnZfpH4f\nH7rI+WR0uwhJahQ2VG5HbxkRNQ9to5RXztgXKU51ZB6aNDyf0Dm3oFYBfiFI+Yo5UokcO4BX0lol\nXIXWjtsXyfM+jrrINwsVBZ3oob3hvlOH5lWZUWEZIyqsVtYK6rqwwgdg9ghSvvL8Ip9U1NDeRi9I\ndfZjlvN73Mf9+U6jF/l2aUvr3ei9t6KW8kvDCguOG53QL7FlBDBvBCkfOef1RT6N6kH72KTeze1w\ng8lu3L5IHvfxQaujdtcdmyPVLVW0GbOJaVTVTwqDFRtyRqseRAztMa8MmCuClI+au5LrenuRT6M/\ncfrwpN5SOdC+K8saMcv5PQ5SgzllI6HIBdtat4Zazcax+0RV/fqPwREx0WpjKlLsJQXMB0HKR54v\ny09jdFl+365tKEeQOmZ4PMzRPrPekUS7O8dXOvbnSI3eh+X88ar19uAInj62jADmiyDlIw4sTmz0\nDLi+g9yGCjHL+VX3d4uJuOpSfj3sj72d45uY9qtOo0GKDSajtTtd7TYY2gMWjSDlIypSiUVN6pX6\n+yJFrELrtKTWnrd9PFyBd/QiX+ztvXUQsfdWtd5WUMypXMgfuZ1DeKPtNqKHTyuDihRVPGAeCFI+\nIkglFjWpV5IahU0FnYh9kQZVPz/7OK4iVdrsb2J6fMuIqFV+UhgU9podtTsxh0N7Kup4GEmDoT7C\nJzAfBCkfDYLUmcW24xSJq0i1ihWtRQap3ko+X4NUzAq8ta0wSLX2oipSx3dCl4bbITB5+qhBWB2p\n+pUKOa0V8wztAXNCkPIRFanEqgfRF/lOaUsbUcv5Pe/j/nYFo/Od1irhlhHtiL23qgftY6FA4uDi\nOIOwGhM+GdoD5oMg5aP+RZ4jYk6k3elqr9mJrEh1y9vacHty3ZFhJ++DVEvlQk5B8eh8p+HeWzuR\n94kOBSznj1KN2WKif1utQfAE5oEg5aPGTniQbqG06JacCv1JvVHbHyjYVsk6qh+MVKUans+ROji+\nLF+SNja31XEmVz8epGoRS/kllvPHqcascuzfRkUKmA+ClI84sDiRuEm9kpRbC4PS7uhyfs+rfmF1\n6fgF3nI57dq6chFBKpxsztDeSY0f2mPLCGBeCFI+8nijyDTiJvVKUmE9nLC/HxekPO3nuBV4krRn\nm8o3j+695ZybONmcCstR1XpbZtJWOTp8UsED5oMg5SOCVCLjfvOP3RepviNZTiptzrx9y6hab0f2\nlyTt5zZVGNl7q97qqtVx0fN92GAyUvWgpc1yQbmcHfsY5xMC80OQ8hFBKpFhRer4Rb7c2xepsTuy\nL1J/+DTn549YLWaYTpIa+U2V20eD1HDfqeP32SwVZMYcqVFxW3JIw4qUc27OrQL8k8m7vJndaGaf\nM7P7zOx1WTwmZogglchwjtTxi/xa5SJJUmsvIkh53Mdxw3SS1CxuHdvENG7fKUnK5UxbZSoso8JN\nYmOC1FpR7a7TQasz51YB/pk6SJlZXtKbJL1I0vWSXm5m10/7uJghj8+ASyNul25JWu/ti9QZXc7v\neR9XY84mlHqbmHb3jn2+FN3H/dsZ2juqVo+v+vX7ni0jgNnLoiL1LEn3Oee+5JxrSnqHpJsyeFzM\ngnPeV0uS6k/q3SxFDDtthxWp7sHIBpMe93G91VGz3Y2s4ElSt1zRphsNUvET+sPbi0w2HzFuHtpw\ngj7hE5i16HetZK6Q9LVD/79f0ndm8LipfeLW39Vld/7XRTZhaZmcrnQtfe2gqKsiPv53X/22Xvu/\n7lS7y9yKvod3G7GTeoO1DdVdUdd/+Q/0tTfeOrj9EvdNPXjJcyP7+MJ+U694yydWdqiq0/veidoT\nSpJceVubdqCvvfE7Brc91jm9q7SlSv
59kfeprBX0kfse0vN++68zb++y+L4nnNUbXvLkY7e7bld3\n/PYP6eKDrxy5/U1dp4+uvUrS+WP36VekXvGWT6g8simqzy7bDvTWVz1LxfzxGsJH3/yruvL+9y6g\nVZjWI895o572/J9Y2NfPIkidiJndLOlmSTp37txMv1b5zKV6cPM7Jn+ipz5+4Tod5J6tn4r42Ce+\n/Ii+8M1d/dBTL1PejgcHXz39qvhzCe94/KtVePAzR267q3at7rUX6pcjPv/zD+7q0/fv6Lsf92hd\nvFnOuKXL4ZnXXKTnP+kxkR+7/Ltfpk/ufFnWHc7feVT7m3rmwWfUyT8k6eyx+/yz77lW773zgVk1\nd+HuvP+C3n/XA5FBqlHf1w37f6sv5h+rb69fO7j9KXsf1QuDuyIf74ZzZ/TyZ12lvQZzpPq++si+\n/vaLD+uhWkOXn1k79vFLvv5BlVxDX9t6+gJah2ls9qZYLEoWQerr0pFfvK/s3XaEc+4WSbdI0vnz\n52da7nja835cet6Pz/JLnGo/+a/fr1faZZEfq9Zbypn0ey+/QUaQOpFn/9Qbj932sls+qtFTY/r6\nwy3/6sYn6WljAtqquvpJz9DVT3rX0Rvv+yvpbS89tr9U3wuffKle+ORL59C6xfiN996jt33sq5Ef\n273wsAJJ33riP9V3/sS/HH7gvz9XQb4eeZ+toKh//9KnzqClp9f7PvOAfv7tn1K13tLlOh6k1rt7\n+ur2M/XMX3rnAlqH0yyLOVKflPR4M7vWzEqSXibpPRk8LmYknLgbPaxU6827IERNpxLET47un4EW\nN7/FS0EvUNajg9SqqwRFHbQ6anWOp++93h5l+Y2R0B1sD48iwkSDHfJj5tqtuz11PD2JANOZOkg5\n59qSfkHSX0q6V9KfOOfunvZxMTuVoBB7kR+3IzVOrrJWjF0xNdhOIWZitZf6E/Mjjo7xwbiDmfub\nvZbWH3X0A8G2t/2VRn8Cfi3iva/b6WhTB3JlPxeIYDqZvJM7594nKXqWKJbO1pjjI8KVQFzgpxUe\nGhsfVsPPIbAOBL1KQP3C+M9bUYdX2V20cfQw8WYt3KOstDlSkSoTpJIYd2ZjrfptbZuTBVSkkJyf\n2y57btzQXvWgpa0yF/hpVYKiao32YAXbYdV6S0Exp1KBH7+B/pCKp8Gg/zMXdZFv7odBam3roqMf\noCKVyFYQf2bjXu+szNy6f3MWMT3eyT1UCQqqxVakWlSkMtAfqtmNCKzVg/jNKr1VDKRC4G0wGJwn\nGHGR7+yFVbqN0ZVJwbbU3JU6q7mNRta2BnOkjr/3HVTDIFUYHT4FToAg5aFxu0Rzkc9Gf/5TVD+P\nOz7Fax5XWAZDe1Hzd3qbvW6eiQhSEhPOT6hUyGmtmI/s43rc8ClwAgQpD4UryuJW7XGRz8KgwhA1\nH6PeZqJ5FJ+D1JhqiepVNV1B5WD96O2DCfp+zitLo7JWiKz6NXtnZQabFx37GDAJQcpDlbWCmu2u\n6iMHmrY7Xe01O1SkMjBuqTUVqRgeL+cfF7xzzapqtiHLjbxdD4KUn32WRty2JK3e8On6gjd2xOlE\nkPLQVszqldrg4FiqJdPaGje0xxYT0TyuSG2U8spZ9PYHhWZVe7Zx/E6ebxmRRty2JP3h041tKlJI\njiDloUrM6pXhwbFc5Ke1vRY/VMMWEzHKFW9DgZmFcxcjvl+KrZrq+c3jdwr8XumYRtweeq7Xhxtb\nTDZHcgQpD8UNI/SD1Rbzd6Y23LPmaFh1zoVbTBBWj/O4IiX19h6LqJaU2zU1IoMUFamk4vbQs/qO\ndt2aCkV+LpEcQcpDcRNbBxUp5u9MbTMYbrB42EGro3bXUfWL0g9SbqZHcS6tSsxFPujuqlXcOn4H\nglRilbXosJpvVrUbNXwKnABBykPbg6MSjr6h1Bjay0w+Z9oqFyL6mHlosYJtqdOU2tEH8a66uInQ\nG9
1dtUsRO26XtiQZQSqBflh1I2G90KzqIBdR9QNOgCDlobijEgZnwHGRz0TUfl39igNhNYLnFZa4\npfmbbk/dqCCVy4XzpDztrzQqa0W1u04HIyuWS+2aDqKGT4ETIEh5aCtmaT5De9mKOm+PPh7D8+X8\nURWpRn1fgbXkgpjDdD3eMiKNuG1Jyp09NQsRw6fACRCkPBQUcyrmLbJaYiZtlqhIZSHqwjio+jGh\n/7igt6u0pxWWqKX5uzuPSJJyazE7bns+QT+pymBaw9Gfy/XurlolghTSIUh5yMwiJ7ZW621tlQvK\n5WxBLVstUUM1VKTG8Hw5fyUoarfRVrvTHdy23zsDLr8WU5EqE6SSiJvWsOF21Y2a0A+cAEHKU+H8\nnZGLPMvyMxVdkQr/zxYTETw/8qT/PbHbGP5cHvTOgCvGnQFHRSqRrYg99Lqdjjbdvrpxw6fABAQp\nT1Vi5u9QKclO1AaL/fDKZPMI3k82Pz5/p1ELh/bKGzE7bhOkEonaQ29vd0d5czKCFFIiSHkqnI9x\n/CLP3J3sVIKCdhttdbvDpdbVekulQk5BMb/Ali0p34NUxLFCg8N0KwSpLETtobe3Ew6fxs5DAyYg\nSHkqahfl6gEVqSxtBUV1nbTXHPZz9aBNNSpOIZDyJW9XoVUijhVq74chaW1ckGrUpG43+uM4YngG\n5vBncr8aVv0KGwQppEOQ8lTUZPNanYt8lvorhA6/aYfDp1T9Ipl5XWGJmgjdP0x3s/Lo6DsF25Kc\nt+EzqaCYV7mQO/LeV+8NnxY3OGcP6RCkPBW3WSQX+exEDSNUD1qE1XE8Prg4KnirvqO2y2l9I2ZF\nmecrHdMYXWjT3AvDarBJkEI6BClPVYKC6q2uGu1wh99O16nWoCKVpaihmmq9zfDpOD5XpCK+X3KN\nqmq2IcvFvFV7Pq8sjUpQOPJLZKs3D21tK2b4FJiAIOWp/pt2fwPA3d7fLMvPznCoZvjbb+2gRR+P\n43GQ2iwVZHb0+yXfrGpv3GG6BKnEtkamNXT2w4rUxnbM8CkwAUHKU/2LfD9IsVFk9qJ2Ua4yD208\nj4NULmfaLB/dlqTYmnCYLkEqsdGhPdfru81tKlJIhyDlqeHGdOGb9iBIcZHPzFbUHCkmm4/ncZCS\njm/iWm7vqlEgSGWpEhRUO/QzafUd7bmyiqXyAluF04wg5anRjekGZ8Bxkc/M6FLrequjZrtLWB0n\n2Pb20GKpv4nrsFpS7uyqWajE36EfpFi1d2KjC21yk4ZPgQkIUp4aPQWdilT2ivmc1kv541U/hk/j\nBdtS+0BqNxbdkoUYnQi90d1Ve9xhumVW7SUVbv0yDKuFZk1744ZPgQkIUp4aLrXuV6TCv7e5yGfq\n8FDNoOrHZPN4g6EqPyss4ZbYWk0AABJhSURBVIkDw4v8pttTtzSmIpXLS6UtglQClbWCmp2u6q1w\nxXKpXVUjT5BCegQpT43uccQZcLNRWSscr/oRVuN5Pufn8Ea5rWZD69aQm3QGnOfzypIa3fh04jw0\nYAKClKfWS3nlczb47be/smyTakmmKkFRtUbYt7U6FamJPA9SW4eG9nZ3wh23Jx6mS5BKZLjQJvx5\nXOvuqlUcU/UDJiBIecrMjrxpVw/a2iwXlM/Zglu2WraCQxWpA+ahTTQIUhcW244FqawVBwdd9w/T\nzU86TJcglcjoQpsNtzd+HhowAUHKY4eHEar1FpWSGTi8QoihvRPwvCJVCQpyTqo12jronQFXmHR0\nSbDtbfBM4/C0BtftasvtyZUmVP2AMQhSHqusFQZzo8Jz9rjAZ+1IWD1gHtpEni/nP3xMTH03PLqk\nvHGSipSf/ZXG9qEzDff3aipYV1ojSCG9qYKUmf24md1tZl0zO59VozAfxytSXOCz1g+rzjlV6y0V\n86agyO8vsTxfzn94InT/DLhg0hlwgb8HPadxuCK1Ww2HT3OT5qEB
Y0z7jn6XpJdK+psM2oI5G12a\nz2ac2asERXW6TvvNTlj1C4oyYx5arNKGZHlvg8FgW5KDttp74XDdemVSkNoOK3jd7qybtxIOz5Ha\n703oL0yq+gFjTBWknHP3Ouc+l1VjMF+VtcJw1V6DitQsHD4culZvM3w6iZnXk6eHZ2C21D044WG6\nwbbkulJzd9bNWwnlQk6lfE61eluN3jy00saEeWjAGIwxeGwrKOqhWkM///bb9eBOY7AsGNnp9+lr\n//ROfexLD9PHJ0GQ0pv++ou676tfV8eZNjZPsP2B5G2fJdVfsfy+zzygWz96jySpNGlCPzDGxCBl\nZh8ys7si/tyU5AuZ2c1mdpuZ3fbQQw+lbzEy85zHX6zHnd3UFx7c1bUXb+i5Tzi76CatnKddeUZP\nvXJbD1w40PZaUS988qWLbtLy8zhIXbod6Huvu1j7jbY23Z7q+U3l8vnxdyJIJfYjT7tcpXxOaoR9\ndsVl/FwivYm/HjvnXpDFF3LO3SLpFkk6f/68y+IxMZ3vf+Jj9P1PfMyim7HSrrpoXe/5he9ddDNO\nF49XoZUKOb3tn39n+J8/+yPpqxPmR0ner3RM4w0veXL4j098XnqftFm5eLENwqnG0B6A5eJxReqI\n+s4wJI1DRSq9fp8F7GyO9Kbd/uBHzex+Sd8l6b1m9pfZNAuAt1jOHzppkPJ8y4ip1HekQiAVyotu\nCU6xqWa+OudulXRrRm0BACk4QyiQwj646NrJnxecGX4+kjlpWAXGYGgPwHIJtqXWntRpLboli3Xi\noT0qUqkRpJABghSA5TKY8+P55OmTXuTzRam4QZBKgyCFDBCkACyXQZDy+CDebkdq1k5+kefg4nQI\nUsgAQQrAcmE5//C5JwpSHvdXWo0qQQpTI0gBWC4s5z+0LD9JkPK4v9KiIoUMEKQALBeW8w+fe/mE\n+xuxZURyzoV9dtI+BmIQpAAsFypSVKTmoV2XOk0qUpgaQQrAciFIEaTmIWkfAzEIUgCWS2lTspzf\nwSBtkHIcY3piBClkhCAFYLnkcuG8FYJUsiDlOlJzb3ZtWjWDPj6z2Hbg1CNIAVg+vi/nr1clWYLJ\n5mwZkVg94RYTQAyCFIDl4/sqtPqOVN4Kq3MnwUrH5PobmAas2sN0CFIAlo/vBxcn3d+ICfrJMUcK\nGSFIAVg+vq9CSxykzgzvh5MhSCEjBCkAy4cgRUVq1uo7Ur4kFYJFtwSnHEEKwPIhSBGkZq3fx2aL\nbglOOYIUgOUTbEvNmtTtLLoli9FIGqSYbJ4YBxYjIwQpAMvH9+X8SStShbJUWCNIJcGBxcgIQQrA\n8vF5OX+3G+5xlPQwXd+3jEiKA4uREYIUgOXj85yfZk2SS14t8X1eWVJUpJARghSA5eNzkEq7LJ8g\nlQxBChkhSAFYPgQpgtSsEaSQEYIUgOVDkCJIzVK7IbXrBClkgiAFYPkMgpSHq/bSHqYbbPu7yjEp\nDixGhghSAJZPuSLJ/KywTFuRci77Nq2aQR+fWWw7sBIIUgCWTy4nlbcIUkmUK1KnGQ5ZYbxBH7P9\nAaZHkAKwnHyd89N/zon3kfJ4XllS9Qvh3wztIQMEKQDLyecgVdqU8oVk9yNInVzaqh8QgSAFYDn5\nHKTSXOD783187LOkCFLIEEEKwHLyNkhdSBmkqEidGEEKGSJIAVhOwbbU8DAUNKoEqVlrVKVcQSqu\nL7olWAEEKQDLqezpIbxpD9MNPD7oOal+H5stuiVYAVMFKTP7LTP7rJndaWa3mhmbcgDIRrAdbpzY\n7S66JfOVeo4UFakT43gYZGjaitQHJT3FOfdUSZ+X9GvTNwkA1LvQOalZW3RL5ivtRb4QSPkSQeok\nCFLI0FRByjn3Aedcu/ffj0m6cvomAYD8rLA4l/4ib+bvBP2kCFLIUJZzpF4l6f0ZPh4An/kYpJq7\nkuumv8gTpE6GIIUMTdzxzcw+
JOnSiA+93jn37t7nvF5SW9LbxzzOzZJulqRz586laiwAj/h4cPG0\nh+lycPHJ1FOujAQiTAxSzrkXjPu4mb1S0g9L+gHn4k/LdM7dIukWSTp//jynagIYz8eK1LT7G1GR\nOhkqUsjQtKv2bpT0Wkkvcc7tZ9MkAJCfy/mnPUzX1y0jkui0pNYeQQqZmXaO1O9J2pL0QTO7w8x+\nP4M2AYCfR55QkZq9aYdPgREJT8U8yjl3XVYNAYAjyj5XpFJuyUeQmqx+IfybIIWMsLM5gOWUL0il\nTb+CQRYVqXZdatWza9Oq4Zw9ZIwgBWB5+VZh6T/XNEfESMNwwMq9eAQpZIwgBWB5+XZwcWMnPEi3\nUEp3/8G8MoJUrAZzpJAtghSA5eVjRWqaC7yPW0YkRUUKGSNIAVhevi3nr++kH9aTDm0ZcSGb9qyi\naYdPgREEKQDLi4pUMlSkJqvvSJYLFzIAGSBIAVheBKlkCFKT9at+OS5/yAbfSQCWVz9IxZ8+tVoI\nUrPH8TDIGEEKwPIKtiXXlZq7i27JfEx7kS+uS7kCQWocghQyRpACsLwGFRYPlvM7Fz7PaS7yZr0t\nIzzor7Sm7WNgBEEKwPLy6eDi1oHUbaU/sLjPt5WOSVGRQsYIUgCWl09zfrLa38i3CfpJEaSQMYIU\ngOVFkEqOIDUeQQoZI0gBWF6DI088CAYEqdnrtKVmjSCFTBGkACwvLytSZ6Z7HIJUPM7ZwwwQpAAs\nr/4xHj4cXJzVRT7Y9mOVYxoEKcwAQQrA8iqUwr2RfKiw9M/HmzpInZFae1KnNX2bVg0HFmMGCFIA\nlpsvy/mzOkx3sGUEValjOLAYM0CQArDcfJnzU9+R8mWpGEz3OIN5ZRemb9OqoSKFGSBIAVhuPgWp\nLC7wPk3QT4oghRkgSAFYbgSpZAhS8QhSmAGCFIDlRpBKhiAVr74jyZgjhUwRpAAsN1+W82d1mG7/\nMTi4+Lh6NQxROS59yA7fTQCWW78i5dyiWzJbVKRmj+NhMAMEKQDLLahI3ZbUOlh0S2arvjPcumAa\npU3JcgSpKFn1MXAIQQrAcvOlwpJVtcTMn723kqIihRkgSAFYbj4EqVZd6jSyu8j7MkE/KYIUZoAg\nBWC5+RCksl6WT5CKRpDCDBCkACy34Ez49yoHg0GQOpPN4xGkohGkMAMEKQDLzYfl/P3nlmlFaoX7\nK41uN+xnghQyRpACsNz6myeu8tlx/eeW1UaRVKSOa9YkOTbjROYIUgCWG3OkkiNIHcfxMJgRghSA\n5VYMpHx5tYPBLIJUsyZ12tk83iogSGFGpgpSZvbvzOxOM7vDzD5gZpdn1TAAGFj1CsssgpS02vPK\nkiJIYUamrUj9lnPuqc65p0v6c0m/nkGbAOAoH4JUrigV17J5PB+GQ5MiSGFGCtPc2Tl3+NedDUkr\nfhgWgIUItqVHviR99n2LbslsPHh3+BzNsnm8flj4/F9IZ67O5jFPu6/83/BvghQyNlWQkiQz+w1J\nPyNpR9LzxnzezZJulqRz585N+2UB+GT7Cumed0vvePmiWzI7lz41u8eqXBH+/Revy+4xV0G+JG1c\nvOhWYMWYm3Ciupl9SNKlER96vXPu3Yc+79ckBc65fzPpi54/f97ddtttSdsKwFfNPelbX1h0K2br\nzDlp/aLsHu9b90nN3ewebxVsXCxtX7noVuAUMrPbnXPnoz42sSLlnHvBCb/O2yW9T9LEIAUAiZQ2\npMufvuhWnC4XX7foFgBemHbV3uMP/fcmSZ+drjkAAACnx7RzpP6DmT1RUlfS30v6uembBAAAcDpM\nu2rvx7JqCAAAwGnDzuYAAAApEaQAAABSIkgBAACkRJACAABIiSAFAACQEkEKAAAgJYIUAABAShPP\n2pvJFzV7SOEGnrN0saRvzfhrLDOev7/P3+fnLvH8ef7+Pn+fn7s02+d/tXPubNQHFhKk5sHMbo
s7\nYNAHPH9/n7/Pz13i+fP8/X3+Pj93aXHPn6E9AACAlAhSAAAAKa1ykLpl0Q1YMJ6/v3x+7hLPn+fv\nL5+fu7Sg57+yc6QAAABmbZUrUgAAADO1kkHKzG40s8+Z2X1m9rpFt2eWzOwqM/uwmd1jZneb2S/2\nbn+DmX3dzO7o/Xnxots6K2b2FTP7TO953ta77SIz+6CZfaH396MW3c5ZMLMnHnqN7zCzqpm9ZpVf\nfzN7i5l908zuOnRb5Ottof/Sey+408yesbiWZyPm+f+WmX229xxvNbMzvduvMbODQ98Hv7+4lk8v\n5rnHfq+b2a/1XvvPmdkLF9Pq7MQ8/3ceeu5fMbM7erev2msfd61b/M++c26l/kjKS/qipMdKKkn6\ntKTrF92uGT7fyyQ9o/fvLUmfl3S9pDdI+tVFt29OffAVSReP3PYfJb2u9+/XSfrNRbdzDv2Ql/QN\nSVev8usv6bmSniHprkmvt6QXS3q/JJP0bEkfX3T7Z/T8f1BSoffv3zz0/K85/Hmn/U/Mc4/8Xu+9\nD35aUlnStb3rQn7RzyHr5z/y8f8k6ddX9LWPu9Yt/Gd/FStSz5J0n3PuS865pqR3SLppwW2aGefc\nA865T/X+XZN0r6QrFtuqpXCTpLf2/v1WSf94gW2Zlx+Q9EXn3Kw3u10o59zfSHpk5Oa41/smSf/T\nhT4m6YyZXTafls5G1PN3zn3AOdfu/fdjkq6ce8PmIOa1j3OTpHc45xrOuS9Luk/h9eHUGvf8zcwk\n/YSkP55ro+ZkzLVu4T/7qxikrpD0tUP/v1+eBAszu0bSDZI+3rvpF3olzbes6tBWj5P0ATO73cxu\n7t12iXPugd6/vyHpksU0ba5epqNvor68/lL86+3j+8GrFP4m3netmf2dmf0fM3vOoho1Y1Hf6769\n9s+R9KBz7guHblvJ137kWrfwn/1VDFJeMrNNSX8q6TXOuaqk/ybpcZKeLukBhSXfVfW9zrlnSHqR\npFeb2XMPf9CFdd6VXp5qZiVJL5H0rt5NPr3+R/jwescxs9dLakt6e++mBySdc87dIOmXJf2RmVUW\n1b4Z8fZ7fcTLdfQXqZV87SOudQOL+tlfxSD1dUlXHfr/lb3bVpaZFRV+Y73dOfdnkuSce9A513HO\ndSX9D53ykvY4zrmv9/7+pqRbFT7XB/tl3N7f31xcC+fiRZI+5Zx7UPLr9e+Je729eT8ws1dK+mFJ\nP9m7oKg3rPVw79+3K5wn9ISFNXIGxnyv+/TaFyS9VNI7+7et4msfda3TEvzsr2KQ+qSkx5vZtb3f\n0l8m6T0LbtPM9MbF3yzpXufc7xy6/fBY8I9Kumv0vqvAzDbMbKv/b4WTbu9S+Jq/ovdpr5D07sW0\ncG6O/Dbqy+t/SNzr/R5JP9NbwfNsSTuHhgFWhpndKOm1kl7inNs/dPtZM8v3/v1YSY+X9KXFtHI2\nxnyvv0fSy8ysbGbXKnzun5h3++bkBZI+65y7v3/Dqr32cdc6LcPP/qJn4s/ij8LZ+p9XmMBfv+j2\nzPi5fq/CUuadku7o/XmxpD+U9Jne7e+RdNmi2zqj5/9YhStzPi3p7v7rLenRkv5K0hckfUjSRYtu\n6wz7YEPSw5K2D922sq+/wsD4gKSWwnkPPxv3eitcsfOm3nvBZySdX3T7Z/T871M4H6T/HvD7vc/9\nsd7PxR2SPiXpRxbd/hk899jvdUmv7732n5P0okW3fxbPv3f7H0j6uZHPXbXXPu5at/CffXY2BwAA\nSGkVh/YAAADmgiAFAACQEkEKAAAgJYIUAABASgQpAACAlAhSAAAAKRGkAAAAUiJIAQAApPT/AdRH\nmJiJapP8AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 720x360 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# 1D kernel - WIO layout (width, input channels, output channels)\n",
"kernel = np.array([[[1, 0, -1], [-1,  0,  1]],\n",
"                   [[1, 1,  1], [-1, -1, -1]]],\n",
"                  dtype=jnp.float32).transpose([2,1,0])\n",
"# 1D data - NWC layout (batch, width, channels)\n",
"data = np.zeros((1, 200, 2), dtype=jnp.float32)\n",
"for i in range(2):\n",
"  for k in range(2):\n",
"    x = 35*i + 30 + 60*k\n",
"    data[0, x:x+30, k] = 1.0\n",
"\n",
"print(\"in shapes:\", data.shape, kernel.shape)\n",
"\n",
"plt.figure(figsize=(10,5))\n",
"plt.plot(data[0]);\n",
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n",
"                                ('NWC', 'WIO', 'NWC'))\n",
"print(dn)\n",
"\n",
"out = lax.conv_general_dilated(data,   # lhs = image tensor\n",
"                               kernel, # rhs = conv kernel tensor\n",
"                               (1,),   # window strides\n",
"                               'SAME', # padding mode\n",
"                               (1,),   # lhs/image dilation\n",
"                               (1,),   # rhs/kernel dilation\n",
"                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation\n",
"print(\"out shape: \", out.shape)\n",
"plt.figure(figsize=(10,5))\n",
"plt.plot(out[0]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7XOgXqCTmaPa"
},
"source": [
"### 3D Convolutions"
]
},
{
"cell_type": "code",
"execution_count": 163,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 530
},
"colab_type": "code",
"id": "QNvSiq5-mcLd",
"outputId": "1c278db7-e2a0-4f53-d7d4-57472f2a794e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)\n",
"ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))\n",
"out shape: (1, 30, 30, 30, 1)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nOy9ycttaV7v+Xna1e3mbU8XERmRkU2J\nWReLgrpYkLeoQWJRguKdqChaiuAox04UyYEiOHciDpwkTmqkI0Eo/wAp9FKmqGk2EXHat9nt6p62\nBus9kY1XTM00TkTG+sDh7HPWXpu19mZ91299f80jcs7MzMzMzHwwyFd9ADMzMzMfJ2bRnZmZmfkA\nmUV3ZmZm5gNkFt2ZmZmZD5BZdGdmZmY+QPS/sn0ubZiZmZn5tyP+pQ1zpDszMzPzATKL7szMzMwH\nyCy6M6+Uz33uc/zlX/7lqz6MmZkPDPGvdKTNnu7MR55f+ZVf4fXXX+d3fud3XvWhzHx8mD3dmZmZ\nmQ8Ds+jOvFLeeust/uIv/oIvfelL/OzP/iy//Mu/zHK55HOf+xx/9Vd/9R3v+73f+z1+9Ed/lNPT\nU371V3+VYRgA+OM//mM+//nPf8fnCiH46le/yh/+4R/y5S9/md///d9nsVjwUz/1Ux/o+c3MfDez\n6M58aPjTP/1Tfv7nf57tdstP//RP88UvfvE7tn/5y1/mz//8z/mnf/on/uEf/uF7sgt+/dd/nV/8\nxV/kN37jNzgej/zZn/3Zf9Thz8x8T8yiO/Oh4fOf/zw/+ZM/iVKKX/qlX+Jv/uZvvmP7F7/4Rd54\n4w3Ozs74zd/8Tf7kT/7kFR3pzMy/n1l0Zz40PHjw4P3XdV0zDAMhhPf/74033nj/9ZtvvsmTJ08+\n0OObmflBMIvuzEeGd9999/3X77zzDo8ePQKgaRq6rnt/27Nnz75jPyH+xUTyzMwHziy6Mx8Z/uAP\n/oD33nuP29tbfvd3f5ef+7mfA+DHfuzH+Nu//Vv++q//mmEY+NKXvvQd+92/f5+vfe1rr+CIZ2b+\nObPoznxk+IVf+AV+4id+grfffptPfepT/NZv/RYAn/3sZ/nt3/5tvvCFL/CZz3zmn1Uy/Nqv/Rpf\n+cpXODk54Wd+5mdexaHPzLzP3Bwx85Hgrbfe4o/+6I/4whe+8KoPZWbme2FujpiZmZn5MDCL7szM\nzMwHyGwvzMzMzPzgme2FmZmZmQ8Ds+jOzMzMfIDMojszMzPzATKL7szMzMwHyCy6MzMzMx8gs+jO\nzMzMfID8a0uwz8z8mxhDoPMeIQSNMRilAMg5z4NnZmaY63RnfgDknHEx4mKgdR6rJ6GNKbEsSm7G\nliE6lJA8qFeUyuBTIOaIkRol1Cs+g5mZHzj/YoQxi+7Mv5ucMyln9uOAj4m9c5AS502NS5HdOHAI\nA0trWdqKkCI+RZZW83S4JueMEoq3F48olKENW1JO1HqJleWrPr2Zme+HWXRnfrCMIXAYRzrvcClx\nUdccx5Ht0ONSZMweheA29pyXNffKBRu341l/ixcjn16+RqksY3T4OFDqnmPYIjIoafj08n9C5J6D\n+xpCKFb2M1i1ftWnPTPzvTKL7sz3T86Z3nuGEDi6kWVRMMTArh8olWYTer52e802DJwWNQ+qBV54\nbsYDAz2FFFitGVLLebHi3FZcuW9w3T9jaSSfXv4nlDC04QUy9xTiCiUKMgkhJK83/yfe/784/3dI\nuWZR/R9o9eDu2AagmH3jmQ8Ls+jO/PuJKRFS4uhGQkrEnNn0PSdlBSLzletnXHcdUkCtLUrBkANW\nKa79Dc/7PVF0nBVrHpZndPkZz4fnFKpjZS5ZaIngQK1WLNWBMT7Bhaes7Cn3q/+MINL7v6emoxAR\npR+RUwcic1L/V2L3f5PTLUKeYZr/C6lfJ8dr
yB3IC4SsX/VXOPPxYxbdmX8fYwhsh4GQIpth4LKu\nkULw5LBnOw700RFCZkyOQxo51TWj6PnHwwu61FJIy8P6FKUHet8hRMeQt7g0YGVHrc94WJxyDP8N\nF5+yVKDlAy6tBg6UsqQWz8ipx+bnWP0mi/LHyfGKFL5CJRJafxap7pHTFnLGFP8LjP8PIEGUiObX\nQd0D/3dAAP1JhDx/xd/szA85s+jOfO/knDmMI0MI7Iaek2oS2uv2iE+RSOa9w45N3xGJNLrirC75\n6vGKb+yvyQRqbRHSI6RkXRbswjNu/TWalkLVvN7cQ4kbhrBDi2sUGSsHKtlj1CVnSuDDX1OIA1Ya\nlHzAmRLEtKOUCotDIChxSP0Ibf5nCF+F8A2UrJD2x0E25LglIxByCfHx3aVQIhdfJAHe/RUgMMV/\nRqmHr/aLn/lhYhbdmX+dkBIhRQ7DSCRhhOJF12KVptKav7+95ll7IOTIUlvWZcHGt7zoOvZjzzEN\n7HyLVIJH9QolPd/sn9CmLVYpFqpgbRxJBAyRLJ9BPlLJHqsaHpX3yfGrhLShFi1GFaykQ4uIFisq\nsYf0jEYkEAVK3qeRkpBusSiksCgsUkqEuCDrt4nhK4j4AqFOUfa/gKwI4T2SKPHZI4QmkxFIqsUX\n6cJj2vH/Q8qSk+p/p9CvAZByRM6lbTPfO7PozvzLpJzxcbIPck7ctB2LomBlLY8Pex4fD/gYAdBS\nMODxMZKBb7QbHu9vQWZWtmRlDYd85GY4EEVPTJHCjCgluLQneF6wC88wsqWUioWuOFMDnhYjDpSy\no6SjUh4tFqxVTUrvYBkohcPIirWSpAwCixRHZB6ohAahUeINUj4S8gYtBJoGqdZAArFiFEtcfEJO\nB4Q6pyz+N7KwdP5rOFYECqx+RM4jKY9cLH6BZ8M/cfDPMKLkjebHWdr7hBTw2aOFxkjzSn+/mQ8l\ns+jO/HNyzuzGgT4ENn3HwhSsyoKrtuUwDmQyt/3IYRwZs8cgOW8qHncH/nFzxRg9RiqMVhgNLnmu\n+j23YYcTPUbBqW1YG8EuXtPnA4UcMVJwYjxaKiolSekFUuyoxUCpYCEKat3j055SDlREFmpEAYYS\nJTU+H6mEowBKVVJR4AnElMjCo8gYAUJYpHiTNj3F5yMSSSEahLpPwpFZsMsVIbW4NKL165xW/ys+\nS27Hf6TPl5TmNdb2ET71uNjyWvNf+Eb7HjFHhBB8evEZTu0pfXD4HLFSU6pZiD/mzKI78y1cjAze\nc3SOlBOLYhJaFyONNjxvW97db/EhYJXmYtEwxsCTw46rviWTOfgRoaHSipzhyXDDk2FPzoHCaNbW\noJSnSz19OoDoMdJTKMGZLhGqw6UNWgwslKNSglMdSTmgREKxx8qBUgQaFamlQaQenz2VjNQi0ahA\nQqHRDFlD9hjhsFKwkEsykj46HB6FRItEEgklapL4BFf+MS55pDCc6IqkPksXj4RcccwNkOhyRaPf\n4FH9P3KIgafde3hO+ETzaVZmhU+eIfY8LN/mxXB8/0p7s7nkxNa03hFyplSKUs9C/DFiFt2POzln\nQkq4GNkNA0ZJbvuelDMXdcNN1/L1zS0uRKQUFFpjpWIzDoQYuR16nnQ7QkjU1nBalRx8z7vtgVt3\nJOGRSqCVYGE1fR65dRuG3FGqSGMNJ1oSGRnyHiUGSuWxKtAoQS0h5ZaURgrtOZMtjYZCQogeIQKF\nCFjhUSKzlhElLG0KRBJLEVipiCYTkJAFh1wgiAgClRQs5An75GmTo0+KhRQgDC4rjKjoxSOu3DPG\nLFHCcmlPiOpH2bodPhf4vKBSBZ4zFuaMN+tPsPWOd48vKOQpP7J+g6UpJ198HLhfnTHGiBKCmDKX\ndUNjLIMPpJwwSlHoefzJDymz6H6cyTmzHQZ679mPAwK4XCzYDwMvjkdCSgzO41IiMLX2Lo3h6Bz/\nuLlm1/cU
RlMaQ2k1g3c8745cjy1t9CAS69JSa8OzccP1cCApj7VQCkFjJEpGDqkjcwQCS+NZG42V\nkjEeSIysdUelA1pAJTJaeEIOQKKQgXv6SKUEIUtcSgiRWUiHIpFRnMpInxW7KElCcCYGFjqSs6bP\n4LNgyBpJxqFpBFTqjOc+cIgZnzVnWhLlGT4bjNTs4z1uxhv6VFBIyxv1I6T8JO92z8mppFANl+Ua\nKy8olOG16pR2DNz0LUvT8MnVOZUxhJRoneO8qJBCIoUgpMTCGozSk0eeM0pKjJ4Tdj8EzKL7cWQM\ngYMb6ZwjpMRpVbMbBvaup1SGzgXe222IMZNEZl0VlNryeL/j2X5HyoJAwipFbQyjdzztW666A3s/\nUpaak7KABDe+5XY80KWRSKQ2krXVCAW37sAQOwrjqW2i1IZCJDI9MQ9o4RAyc2YcSw0+wxADWjgu\nzJ5SZwSagkAgkJIgCkElAvftESMUh6AYksCKyFoNBGHJaFbScR0MYxKM2XCuWpZa0OUFx+DpkkWK\nTAb6vGSlNKW64GvDyNFBlgUPbI02rzMEyDnRhzWH2BFiRa0NP7J6E8kJX9tfU+aCRVny5uKcpWxI\nAi5tgw+BLjrOTMNJXWG1JsaEi4HGWLSSiDsh1lKilSSl6fITQiDl3Gn3EWMW3Y8L35r4FdkOPYXW\nHMaR4+i4bCrGkPj7qxeMPmCUxGjNSVmxGwYO/UAfAzddR4oZUxoWSuFT4unxwFV3YIyBotAUUmGN\nYh8c7x027F1PlomqUJyUBWNwbFzLKAaUdkgUSw1NIelDz5AHyJFlMbI0CS01KQSScFg1omVECTjV\njkZ7dsEyRKiE57LYU8hMzAZEIAQQMhOyphKBe7bDZ8O1t4xJs9A9a+UYaUhZYoXjha8IGVyynOqB\nU2u5cmfc+h4XDYWOiFwwcs5SW7RY843e0Y0JqxreWl5wYh6ydQGfHCYtiTkhMSxtxWfW98he8fSw\np1YFi9LwqdU5hdSkHKm1JQMxJ2pjKO+EN6VEyhktFUrJ93/Tl6L78noVQswtzx9uZtH9OJBz5rbv\ncDFydI4hBB4slgze8+5uy+imZgWXAyUGnyJIgUiZZ7sDt0OPMRojJbXSRJl5uttx0484IjlFVlXB\norI82R94fNjTp6mmtykNpVT45LkOR7rkiTFgTWZdKWot2TrHEI9I7SlNxAjJQgmUGhhCJOaIFInT\nauBER8ZscN6DyDR6wMiIQLHWI0p4bn2Fi4paOS7LI4WALhaEHAlZYlVkSIaSxKXp2YaSq1ASkmJt\nOhY64vMZQ0yIFLiNFRGBS5YTkzgzCx6PK26GIzkbGiNo1CmZM7QUGBY87UcIgqVZ8SNnD7hvTrjq\nevrgWKuGhTWUsqDUintVDUj27fSkUZaKB6s1GoHzEasVUkiEBKMm0RVCTNPcUkJKgZTfue7ALLwf\nWmbR/WFmCH7ybO982cu6pveeF22LlkAWPNnvCTGihMQUkhNV8fzY8my/J4oMGSqrWZUlg/c8b1u2\nx4GjH7FWcVqXpJi5HTpe9C2HOBJzopSaRWmJ2fGiO3I1tESRKArJqoBCWNrU06WOmANZJGorOSmm\nm8R+dAgc1ngWhUdLg8kZKQaGlMiAkokTM3JqPTtn6YJAkVgVA1ZNnm0lPSkntqnCB0mlPZdFh0Fy\n40t8yiQEpfL4bNEolrrn2lXsQolPgoUOnFnJmC/Zjp6YPUMsQGhSrjgrKs5NxTtHzc3QUmBZlIZP\n1q8hcsMYHZaS3gcWsuTMNHzq7JyVqdh1HSElzoqaRWUppEEClbVIEt4ntFJorVhWJUJASlOEKwTv\nR7rfHeHOovuhZRbdHzZSzowhMIbAbhyojWHwnpuXg2hy5qs31+xHRyU1WQleWy7px8Dzw54uBFIM\nuJg5rSqEkgQfuG07XnRHQorYQlNKQ20NRzfyznbHZujIGkplOKktMQTePe
646Vu8imgFq6Kk1Jo2\ntdz0jphGkkqsS2iMJaXMMR3IKaB1RMlMYyULE+hdpA0ZTaAoA2dmBAwhQMyOBAgp0ALW1rEyA0/7\nmiFKrEwsy5FKBEIu0Hj6JGhjQUySUicuioGYFc/HipwSCUGjPTkXICpU7rn1FW3QxCRprOJR3eDC\nCc/ajpQzmZJGWip9wtqWLIzhxV6wHY4sdMN5U/CfTl4jOEEXBhaqJsXIvaahUSUXq4pCGLpxRJBY\nlCW1LTBGkWJGa4mW3+npFoUmpTQl25S6sxzkLLofXmbR/WEh50zMmdu+x8dI70cOo+fhcknKia9v\nNuz7kUZr2hg4LytCzIxxqhdtu4FN32OlwhhDZTQqC160B66OPVkkBJJlaadGie7Iu5sDB9eBFDSF\n5qRsOLiBF2PLbmw5jg6tFU1RYDUcx4GNa3E5oSxUWrC0iiwye9/RhxFBQGvBWSUwWnDoHTEHkIHK\nBAoJhdQYOXCMMHqJVImFjVwUA4PTtEGRRUCJjJYJJSVL7dEi8MQtcGGyGNbWUatI50scCRcFMRsC\ngkpqTotAHwTPhpKQEkrC2iZKThhSgcsdx6HAxYyVloVd8MbiBDdanhyPCASVrHi9OeXENBil0RJi\nL2jHwElVcd7UfOb0lH4IUyRsLUrAZd2gtJxEVyv63qO1oLQGazU5Z3LOaK2+w88FUGqucvgQM4vu\nDwOdc+zGge5upu39xRIXA08PO8gCCVwfO2JKWKVBwLos2B8Hnh72uJxojAapeNDUuJh493bL0Y24\nEDBKcW+9JKbEvut40XYcvCOTKHRBo6firGftgRftnp5EoQ11oWhMwabf87w/0KWpRbi0cFbVIEZu\n+5HOjwiZsUVmVSis0rRhJMQelyMFicLCupBEHIceco4YG6hNopAAipg9oxd39bWRpkicFz0HV7D1\nCkgUOmFFQgmNlolI4LpriEytzCeVo8qw8SVDgJghCUEhDIWpKFWiGzMvektOCaMVF7XhRJyw8ZL9\n2ENUIATn+oTTesllVREGuDn2FFJxXq/51OkpK11Oa8TljE6anAOVLVnVJfeXC8bR4ULAGoPRmlVV\nAOLO04UQPNZOjRVa69le+Ggwi+5HlZjSNDg8RvbDwKoscDHw7NCyKiwSwddvN+z6nqW1uJx5Y7Ui\nxMx7+y3tMHVoHf3Iw2aNtYrj6NkdO3rvGOOUBLPKYISg9Z6b7sim68hiEtWTqkEo+ObNhuu+xeVM\nVtOyPI21PN7uuBoO9K4nklg0FcuyQojIi8OBPgUCUKhEXUkqLRhipB87YowIk1jUkoWSjCHShojM\nAYSnMlAYfRepOjqniGRKnVgVDoNiSJKQIy6puyWAYGEzS+O4Hiydn4baGBNZ6AS5JIqMi4GdKycx\nlJrTGnTK3LiC3qe7G5lgXTQUckEi0o+JwyjRSVJZw2urFaey4aqLHF1HKSyFMXx6cUGtCppCE0dB\n5xyrouSirnnjbIlK0zpxEiilRGmJUYa6tJTW4H1AiMnTVUphjX7fz30ptB/GCobvjsY/xsyi+1Ej\npkRMiduhn4TXeXbjyGurFZnMP764ZjsMLIuSLjgeNEtyjBzcyBA9wSVu+4FlYSmsQgmFStC7wJPj\nniygkorCGE6qknZ0fOPmls6PyLuW1YtVw+A9T7d7dq7j6D1KSc7qBq0U26HlSbtl3w9kBYU1LAtD\nSlNt7tF7EnF6XG8mb/jQ9+zGQEwjUksW1fSIj0gcfEcKAXJC28yq0JQqs3ceFzJCBbTI1BKU0cjk\n6UJiDIqcBWWZWRcOgmDnLSlHUhZTdC01hRYYM3BzLOi9QmqBlXBaZkIqaCO46Bm8QWZBoUvWpSER\n2B4NXfBYUbAsNA+LFVpUHIPDuUh0iqUpOC0bPnG2xmbN7X4ghkhjC5aF4VOnF6gkUAoQEFzitKg5\naSyrZQVZ3CXUJFoLrJm61bSWGG1IaR
o69NLL/bDZCy+tkJcVF1LKyQ5LCWBqCvn41BvPovtRYj+O\n7MepGqH3nofrFc4Hnh72hJjREm67ASUESkgQmVpbBu94spvesygLYk7cb2pIgsfbPZv+iNWGlBOf\nOD1FZMFVv2fbjYwxMMbIed1QaoWLidvhyIvdgSEkjJU0ZcnCFhz9wDc3G/ZhJGcwAs6aBSlH3j1s\nad2RgEBbyaosKQvJ0Y9s+0BOEakiVaGpjZzmOIwdLkQgoC2clFNmv4sjLkSyD2iVKa2ktoqUHEeX\n8Vmic8QW0MhMFhqXPW6UjFEhSJSV5KSK9GPmtjcIkUEKjASLQWuFzwP73uCCwihJUxpWhcANik0A\nHzwCSSEsp+WC0hha1zMcJUFE1qrhpFrwerMkJ03vHWFMlFJxUddc1kvOF9Pv0O9HsoJaG85OGh7U\nDSknck6kJIHIoqioK0tVWlKKpJQwWqOURCn1zyLel/xHRpcvxfRf2vbyz8uStpf/DjHfHStkpvrj\nTGb0gZQnq6cw37JMUs7IH44oeRbdDzs+Rlo/1da2o2NdlvTe87w7srQWmQVf29xy7B3romAg8fpy\nBTHz7mHLoXfUStNGz2vNksJqdsPA7eGAkYpNP3LalDTGEnPERUE7jNwOLUoISmtpCoMWgufHI9f7\nI533aKM5X9TURcGmPfLe7sBmaElkmqLgvKnpnOPp8cDRTc0TxhYs65LGKnZDx6bvcBmUyBgrOK9K\nPJ7d4OlHh1AZq2BZapTW9GFgGEZ8ClgFhVGsK8XoPceQCCEhiRQ6URmLNpJhHDk6QUwKqxJVKag1\nuKjoY8R7gY9glKAuDKVJ7PrEYZAILdFkCiMppSVmTR8dwyiJUdCYgnVdstCSXR9px0QOGasU67Lm\nfrEiS8W+6wgByPBGveb+csl53RDGaXl6mQSVtLx+uqA2JYvCEGNkdAEjoSoqlouSRWkZXUKpzMtL\nsC7s5PneRb85gxDfSqa9FEUhBOkusvy32A/frgPfvU9K6f3o9eVnvnz/t0e2324t+JBIORFipjAa\nKQUpZVwI+JQw6q4VOqapPllJbu/K6rSSXNQN+rtqkj9izKL7YSXcrT9223fkDL337IaeR6sVKcHf\nXT/n0I2syoI2Bt5olnezFHo670kxsx861qamLg0JiDERcmKzP+ITLGuDlIrzumZwgSfbW/ZjoNQK\nozWPliuShK9f37Ibe4JPZJW5v1ygBWzGgatjx67rSSpTqoKmlKSYuRqO3PQ9g3cIo1hVBY3WHIJn\nM3YMbkBIQW01J01DylPHW+cDmYzVicZICqs4Dh4fPUOISA2VUqxrTe/DFD26QBKZ0mRKW1DqzLHz\nHENGkFACmgKU1GQp6MfI4AQZMEawLBRCQDcIBtI0kzdx1+ZckFSiHSPtOCUlC21ZlxaNZAiZPmZi\nmPa5KBsu6gUyZ7aDx43TAJuVKXhQLji5qxppx5EYp4Tep1en03dqNKMfySlTi+lmt15UlKagKCTj\nEMkxo6SgLi1NU2CNJoZ4J7QaIUFJ9X6ybbqOxV1J2T8Xq/9eJ9t3C+e3v/72poz//j7T577crpTE\n+0CIcXoaUxKtp5kSKU3R6+AdPmVCTDTWUBYGF+O0SkkM1NZQKI2L03nebxYfZW94Ft0PGzlnNkPP\n0Tla54gp82C5YPCBd7fbqXQJwe0w0iiJEBLuvMlhmGptE4lFUZJE5rKoUELzbL/lquuotSaJyJun\nFxipeLrfshl7VBQMMXG+XNJoxRgDm27gEAbawVEpS1FqKqkZUuD5fs/N0KGFususW7TSvHN7w9XY\nk3IikbloGmpb8KzbcXXcT91uIrMoa04rS0LwotvS9h4hE8ZolqWmKBT7Q0vrEy46jJYUhaG2Eh8E\nQxiI3pFIqEKxMgalYO8DowvkMA2KWVhJVVQgPIfR03sQSIyA2gIYBpEYfSaMoKSgtJqmNjjnaR0M\nKW
PuvNJalVPCj8yu88SoptpgU7IsNSDoB3AhorKgVIYHzYLTYoEQidtdT0qZ06LmXr3gYlljUcTo\nGVwmZ6iM5sFqwf2qJiEZRodU0JiKqtQsmhJBxkgNTHMatNJTsq2wQJ4EXUp4+QifQUmJupvdEO98\nYH1nS3w7KX0rOp2aMF5GrADf8mS/Fc1OuYacAJHRd1F2jJEQEjEntBKEkMlisg26YeQ4egbnKaxm\nVU9PcD5EtJTs/UhMmT4FGmu4qJtJoEPgwWIxndtHk1l0PyyMIdB6R+c8LkZWZcnROa6OB1ZlRQqR\nf9rcMDpPbQoCmddWK4KPPN7tOLiRylh673hrtaK0ltu247o7YrNk50buLdYsS4tLkd71jCGzb3vq\nQmO0xVhFJaeGh/c2exAZJRVlYTgpDJ3zfPNmS5dGyJqiVNyrG1xIvLO9Zt87+uxQUnN/sSCR2AwD\n1+2B1o8oDVVZcGorEvCs2zMOHT5ltNGcLGsKPZVdtf3AmBJWZOrSUJcFMXm2g2ccE4KAKTUro5BK\nsRk80Y9kH0CCLfRU0REc+xjwPpEjlApqW6CNoYsjXR8JCRSCqtRUVuNCpEuZGDIhgFSClalpasFh\n8NPAGwQWgdWSpS3RUtGNnjY6UhTUyvKoOqE2it4H/BBxMU2WRGE5qxouqpoUE7e7DqHgvKp5WK9Y\nLyqIiRQSPiZEkthacF4vOa2qSZTaHmslhVKsFw1STyI0/SVIGRKZ0ijKogAghPT+svXqrslCSoHR\nihAjMWaE4H3R/HZhnRoypqj5ZRTtQyDGqTvQaAVZkHJCIPAhEmKcbr4p01TlNBRo9LiYGEKgMpYx\neDJQG83tMHB9PNKlaV7zW6enbFzP0XlOy4KFLbBK8XC5nCPdmX8fL2fZ+hS56TqUVGz6nsF7Hq1W\njN7zt1cvGEbHorD0cfJrY4jsnWM39IgIm+7IRbNiWRW4HBm6EaMUt/sWJCyKAqkVi7KEGHh8e0sX\nA7UpMVpyuV5jpeSb1zdsxwGZMkEIXluvMdZwtd1x2w/TCgjRc1I3VGbK3j8/dGyOHV0cKIu7SMwU\nHIeRd3dbDkOP1JJSa84WC3KKXLcHdn4k5URhDMuqoNSCwxDYDntCykgtKJXirFkypoFN2zP4AJGp\nXbnRKCHZjYl+nLropEhIWXK2FAwBduNIcpGUIkbDqmyQRjG4jt2QyCIjkqDShrrSxARdcozD1JFm\ntGRRlBgFBxdxORFTBhSFgrUqMNqw9x1jjJAFC1WxshVWK/ARHxNt8EgJC13y9vICKzTbbpqFoYXg\npChYlhWV0ZzpGovked+RcuKyrnjr5JxaW1yM5JgQMhMTlNawKDWndY0Sik13RCvNorY0RYEQ8u47\nnm4kKU2+bmE0ZWHId8sxCXr3g68AACAASURBVKYKAvX+BLPp3EfnCHEatFMW06jJlCbvXMg82TA5\n42NCa4VWkq4fGX0gBkFTGbRRjOP0xBZTohs9/m5+89miRijB1bGlHUcOwXG/aQg504dAZQ2JzE3X\nclo1KCn45MkpF03ziq/c74tZdF8VMSWu+w4XApth6gS7aBYcx5GvbzbTWMEI23Gg1hqBuJu3CofW\ncXs8koVgVRVkCWtbYoXkyX7PpmtZaENUgjdPT6mV4clmy/XQUmRFnyKPTlcsyoLWj2wOAylFbvqe\ni2ZFYxVZwOAj267jtu+QQlGVilUxPdpe9S3PDwdiSCAlp3VFYwzbfuC93Z4uuimpZi3ndcFh8Ny4\nnm7s8SlSVxUntqCyhtu+5dD3jMkhlGBZ1Jw1U5fY1e2WYXQkAdoKTuoapQT7sWcYwKWMvkscnTWT\n2G7bSEgOckIYzcqWWAOH0bEbI4QACpaypKo0UUyrHLsU75aiNCwqhUIQcqKNiUBCArVULE2JE4F+\n9EQySDBSUWjFUhboqDnkAZ8dKitO9IqLoiEQ8TFjgqAXA8jMwlR8
sjynEsU0AjM4Km25tAvWVUlW\nsBSWhbQ8Hg9A5rSpeHt9Sa3vKiJSQEsFZKw2aC05qyustlwdWoTKlMpyvmrIaRpDqZXCxTjZAnHq\nbGvKgpQSgwvklEBMwp5yBgFaSpyPhJjw0VNZQ2kto/N03pNTImew2uDCJOaVVdwcew7DiEuB07Jh\nuSi4PrQ4H/A5sfcjEsl2GLi3WrAqDN/cH6bxloVhXVS8cbJGi6mq5eFy9Wov3u+PWXQ/aDrn6Lyn\ndSMIQWMLNn3PtutYlyWHYeRr21t8DCxtRSbzoJ7E+NmxZd8fWRYlo4+8eXpCZQzPjy1Pb3dUQjPE\nwMOTFcuioPeBg+8RKbFpBxqraGxBUZUIEm3nuOqmRF1hNHWlWJsGnwLfuNnQhREhFFWhebBYklLi\nnd20xLoPASUEl8sTjJW8OB64ObZ0bgAhWFU1VggicDu03PYjPjnKwnJWVlS2ZDu2bNsBHwaSkJwu\na05sRcyBF+2RtjuSFBRKsSoXWAPb1nHse2KcHvvLhWSpC7rRcQgZ3wayFKA0SyuoKtg7wWFwhCGh\nCsGq1BRC4WSmD5EuBUQEoRW1FFgjcCnTxbuEkcwURlAZPS3W6RNJeYSOSAllqlnIhi4NjNkhIyib\nsGiM0VRYzFjTywNBDWihORWnXBZrujDiY8JGidMjUkms0bxlLlmrklvXsnEdpbJ8oj5hVZX0KVEJ\nxYkueeoOgKAqSz61WrMuK3Z9j4sRI6e6amsNkcRJWVJoM/n+YbJ07i1rtFSMLiAlhJynaDcBAlZ1\nCcC268lpSoS9LDtMKaKEou1HxjhF9YXRLMvp99h3PT5MUrGsDZ2f9qkKw5P9gW3Xc/Cei0XNGyen\nfH17y34cqe6qFhZlgZaSUmkerFbYu+lqDxbLV3b9/gCYRfeDIN0VgnfOsR0HtFK8aI8o4H6zZDsM\n/Ldnzwg5UilNykxDaHxgPzg2XUuImX3X8ehkzWlVMnjHtnUo4NAPWCOprUFqQ2MN5MSz65Y2dKxt\nhSw198qaQhne2225Hg5oNNooHi1XVKXl+WHP1bGFJBij46xesKo0Y0pctR3Hrqf1icpI6tJSqYJD\nGHhxaNl3LVpbjILLekEUkXe2BzbuSPIZaQ2vNQusluyc43q3x6WIMJKTouSsbogp8Lg90Hc9MQaq\nquBkWWMSbEfHvjtOCSIFtjCcFjWOyHXX4zoIeRLii0ZgbcFN5zgMkRQTCEVTSioF0sJ2zLQ+TG3J\nUlAbhSkErQ+4kEl6sh+szhhtkFkSkifKqWVZIqhFibUSnxO+F5T1AWMCUoHwK6q4ZkgDThzRGYye\nZjEIBQaNbU/ozC3ZHJHSssz3eWhO6WKg9T21tAQ1jd3UUvLInvJQnfJ43HDrW2qteVSd86BZ0kUP\nKXOiS1644+TbasOnTs44Kys2Q8vg4zQ/wxbUZcEYPKXRNEbzdHsk5oxSgnuLBVVhOXYDIU5VCMi7\nWrQMi7tk3ebQ0XuPkIIH6zU+Bg5tj1Ga4+gYY7wTccHDVcXRBR7vt4QQ6Ym8dXbGpu/pnacpLHs3\n0jmPtYpaG/6Hy0tuh57NMHBe12gp+czZOauyfLUX9PfHLLr/0bgYue6OhJh53racVxXLsuC26/jm\ndosWkuFuZoKVEqkkCSBmrvZHNn1Poaa+eyGg0QUKeL7ds21HSmMxBbxxdkYtFY9vNly1A1ZAFIIH\nJw2XyxXH45HnbUtOmf3Ycd4sWTcVIDg6Rzc6dn2PFmpKXBUWqzQ37ZGbtqcPgcZqqqJkWSh2o+Pp\n7kAXIjGONFXBw2ZFGzzv7DYcup7A9Ph8VlcIDZtjz+3QTbNwpeTesuGibhhd4EXXcXAtRCgrzVnd\nYKXk9tizGY94NyKSpF5YzhZLQsg8O+7wQyIkkAWcFwWVqdi5A7d9xt2Vha1KWJaSJCU3vacPgSQE\nVoAtwUhDItJlT5RTvZiRglqW
IDN9HkgpU5bDFNkaAVnjhgIpRor6QKEiKRtUXKF1IuAZOsPlYos2\nDqkk0a8Qwz1cGsBs7lZ+UDSsmZbSBNNd4O0Nye5JGFY84pF+wDE4jmFPJSqiDlhpERJOzYJPmgue\njDuejTussbxZnPNwsWbMgb73PFiseD4cJg9bSF5vVjxarnh26NgOR4zQnBYli7qkHwNIwUprrrqe\nEKebz6P1iqaw3BzaO58X1MtKBOemZKxUvLvdMPqID5nXT9bYQvNscyT4SB89XkZyntqfH60WjDnz\nznaLUlNJ3ydO16yKkuuuu7M3LJXWXC4atJCcVOU0Le+jyyy6/xHknDk6xxgCN33H0lqM1jw57Bl9\nZGktL7ojX7/dTOVDVYmSinVhuel6Hu92DMHfLVbo+eTJGZUxPN7ueW+zo9KSnAVvnJxwUpZsuyOH\nYUQk2PUDtSm4WJQoa4BE3wduDkeETiyqmsoUlNZACDw7HNgOI0ZBYQsenKzRGb5xe8vB+2lItlbc\nq2sqrbnqOp5ut4x3CcDL5YJFWdL7ketDy2FwjDFgK83aVCysYdN3PN8f6cYOYwvOqpqLZYOPkav9\nns57HIlaGS7rhroouRn2XG2PODeAEpyuFlyUC3ofeN7tOQ4j2QMSHq5qyqJmOx652nuGkFEIVhbW\ny4Y2etph5BCn5I/UibqQ1LpgSCP7fLcgZJGRSVBaCToz9glhRhbNgNWRlAuyb5AykERPSol7iwNC\nSEoDOQs2hzOsPHJ/dYORkT5VOHeKEoJET9fXPFxdY4wnofChYmzfJOaAKp4jECRRUMULohTAiOwu\nwW4Jek8Qmipe8GbxFr13XPkNlSzIWrBWzVSxIC2frR7weNjxuL/FKMMnFpe8vTql84Fd13OvWXDd\ndUgliBkuyopPLNc8b4/cHAekgrOy5mJR0Y6emBNLU/C8PUwWS8hcLhsum5pnuz27fiTkTKU1q6bk\n9jigpaKQkq/vtwwx4ELgzfUJi8by1Re3dD5gjcQaxfmioXUeJSTLwjCERGk1WinOq5LXVmsQoITk\nvK5f8RX+fTGL7g+SlzbCdhhovUMLyTv7LadVxUlZ8e5uy9++eI5CTXWnSrEupvrEm65lNzrawdGG\nkbeXpyyrksPguD4cpnKbGGnuakWNVSy0YnCep9sWFzwX1ZKy0qzrApMyV4eeq+5IJTVSKx6sTjhd\nlrzYbHh3u0PLyXO9WCy4WC44uJ7tOHDcD3QxsigNJ3WNsYZjP/Bit2UzjBS2oNKG06Yg58zzY8fV\n4UBMUyPCxWrJqqp5cdjxdLfDhUwWmZOy5rKuEQKeHo8ch8knrKzhjfNTjNBc9QduDweGFEgJ7i2W\nrPRU8vSk37I9dtPFjeLB6QmlNVy1R66PPUOYiu0XGpbVgkzimDpu+qkpRBtYaYstFcccpgSdjWSd\nsHoabiOjwWePKDqM8lRlIMSCWmWk9hwHw2lx4MF6g5EZnwy37TkyJxblDiUil8WBkA1KCiSJx/sH\nNOrAZ07fA5HZ+QWb/iEZQxYt7VDxaHmLlgGHJSfN/vA2GYEq/3/23qvHsiXJ0vvM1VZHhEh1dYnu\n6ZrGEHzg//8NPUMSxZ4u1VfkTRXqqC1cGR92VM28sElMkyxU3XIgEYFAnkycHceXmy9bttYPGDUs\ntafV16g4crmg8wtMcyaZkSyGTrf8zP2CVOFt/EQjDWItL8MVYkBr5RfdK+7jyHfjPYjjTbPlH/dv\niLXw/nxkYwNjTQyuI2lmcJ6vtzs+nic+zWdUYec7vrrZcJ4zp2lhEzwPl4kqStRK7xxf3ez5/vHA\nw7iGnvaN5av9LR/OJ85LZPCOQ5rBCNZZNj7wzf6K704Hzktagbtt+IfXr7kbT8wpc9sPNN7xs/0V\nfQh/3o3+71t/A93/t9YYI5+m1R3r02Xky/2eYC0/HA78eD6ArFeqlAtZYOsDtVbOcebb05FpiW
x8\n4KbryXnldpesvDsfOE6R1ll2bcPXu2tqLXz78MDhsrD1DcbAm+2Wm+3AOM68PZ0gZS5L5YurDW9u\nrlhi4mk6E1NeY2HawKZtGfoGo8JxHHn79IQ8z7zfbnfsup4xTnz/9MRxnLDW45zw+W4HYvjD4wMP\n5xNKpXWB19s9ITg+nE4cppnTPNG4QOuFm3bDlCNP08xhnAC46Qbe7LY4A8eYeZpOXKaItcI+tLzY\nDhQtvH06cVhGaqw0wfJqt8PjOOaZu9OJU1ZcFTad5abbspA5LqvdZURprdC2DSZYYowsNTL5hNpK\n2xh8bvAWJrOQa2bfXwhNprErUF+mAUrm5e6e3kWGEDnGnhah8TNPS8vnzSNfb+4QIGP5/vyGmB2v\nh3saidyEC2NpqeoRybw9v8aR+U/Xf6AgPOUN78aviLkDeyYXyy6MGFGSdKBwOP8CrQFp/xURYSkD\nTf0Sow1LPVPiFuMqi4yAIxD4xv8MJ57fXT7ijQHr+Ka5xVrLcRn53O+ZVflhfMI6y1YC/3j1BhX4\n7vGJ3gWWkrlqW6pRqMJnbc8xJt4dLxSbGKThl7fXnErk42GiN4anPOO8YyoJU4U3w5Z35xOHNNM6\nh8Xwdy+uWYry/nwkWEfwllebLY1zLDWzCe3a3KuVl8NKL3x1tad1/s+61/+d62+g++9ZpVbOKf5J\n9rULLQp8dzisnVYL3z498d3TE8YKV6GltwGxyo+nEz9ezmgutHY1Fv8P17dYEX57/4kfL2eC9YQC\nv9jf4q3hEGeeLjNk5ZQWBt/y+WbDpvEcYuRwmhjnuCbudj1N49bYlyq8ezpwulzo2oauDXy23eKs\n4fvHB+7OIw6h6xteDD37vuVhXPjx6cg0z2RVXlxteL3b8TiOvH98YiyVuSzsmy2bzhOc5TxPvDuc\nGecF6wybruGz/Y55Wvj+cOCyRKy1dM7zZrenM4ZTytyfDozTQtsYbvsdL3Y7lph4fz7wNJ5ItdIa\nz+vNjk3oOWjk3dMd47hOVV13gc3Qk9RwjiPHZWQCrDV0TUNnLVEqhzKz6AI2E3xYaQSbWaIgmthc\nHWhCJtjKFLeEqrjmjFL4cnhk8BFvKsEWPow7crL848137O1EsJlPcYcWT+cmLiXw2p34sntaBxVE\n+HZ6xWEe+GJ4pJWFzi0cU8dUB5yN3C17UvL8w+47ohrOtePt+BXTskfdCUzBi1LVUnRAJTPNn1Hj\nFan5DjGVuQy0+Ut27LnTEzUanA/MOmElIAhf2dfs/ZZfH9+vI9cm8E13y1UYuJuPbKRbbyPThcYJ\nosKvdi9wGH73+Ig4g2Z4MQw03nCcFvY+EHPm7XzBChQVfnV7w5QK3x6ONNaQqVx1HUPjeTxPDMGj\nyBrV1ASSVt50W5w3PEwzXXA01vH11RUvNxuWvJq8b5+HPf5C199A939kVVWqKh/PZ9Izt/nhfOLL\n3R5jDP/1/hP/9eETvXNoFa7ahixKrfD2/MgxjjxMC1kLv9y+ZBsaPo4HPk4XalJirmxcg7FwFTpE\nhft04cPhiKjhVbNl33oa5ymqPD5dOKdIa9bZ+6+vr3m13fD9/QN/eDwxiKG1hpe7DS/6gXNc+HQ5\nEWNhzImt83x2s0Oc4XAZeZpmLstCMI79vqfxKzd8OM28O13IacGHwIvthn3bcJgW/vXxniklrDG8\n3m647Tdc0syH04XzOTLVmW3wfHZ9xWAb7qeR47Ia2DTGcb0duA4diOXjeObhdCDlyND2vBk2DF3L\n3bjwdFllc+KFne3o2w4jhlMauRsvLKVgW8/Od/hgORO5xIVcE6lNeGdpWpDcoDlR2gv4xK6dURyt\nga4ZuSwBo4lfvXxH7yKNydwte0r0XHdPNHbmq+aJ1iR4zli7W7Y8LQP/y/4tWzODFD6mDXNpaO1M\nRtiKPKcSVzLK++WKd9MVn/cHtmahSuEx9ZzLBjGZKQce04
5v+ndkDFNteD99xmV+RfYnDDOFQNGA\nqbs1jD5u0fSGk3uHmkyugU15wxvzirfpiaUkgnRE1sy2pVZe2ys+D1f876eP5Fpw1vNVs+WLfs+7\ncURrpRHHfZzog2eJmS+HPTvn+e3xCdVMqfCyHbjadPz4cMBawRrLJSf2TcdYE7ddy973vB2PVFG8\nsbzser68vuL3D3dMWdk1gX0X+Ob6hmOcyVXXIsFYPttu/lbp/tTW4zSuYY8ps5TMF9sduVZ+9/TA\nmNao8sdppqoya+ZF05Nr4cfpkX89PlIVetNw03Zc0sLgGx7imbvpzCktNMbzwuz4bLvhmC78cDpw\niZlOA63xDF3Dm27gMI+8HU/MMeNU+Dzs+epqyzFFjvPEsghpyau8qw+86nsEy4fTkbvjhYDhemi5\nGnqCcyw58uPjkUtcCC4QvOOzq4E2tPzr3R0fDwcEwTaWz3d7Nk3H+/HMx8MTy1IQMdz2AzfbjkuJ\nPJ0m7o4HxFiGxrHvN3hrOS8Lp/OFpBBUud1suelbVCwfz0eOl4n8bHLyarsnOM9pXjgsE4/TCKrs\n2p7bbovawt00cpoW5hqxdjVLoTFIVi554WgmqgHfgSfQGWH26wh0Yxf67YR3hSFELktPWeDl1QNd\nmHjdXIjqkOK4bk7M1UKp/M/772lMopXMXd5yjj0vwxODW3hlE1aUXAVvEpca+DDv+E/DmcFkMpmH\n4jmmDmsnjBSstly7RFRl1sox9/xhvOVFe2TvZia1PKaOU7ymutVT4t30gpfdPQXLUgNP8w2X6Qtm\newF7Yck9pTZ0esNCImdDSK95Mk8kIloDV3LDz/0bfpweeSwTvWzIVG79wFgSHYFf+ht+Oz1xqgse\ny5tm4Jvhmh/nE1PMDMZzSgubtmGKka0NvO56/uV4QK3SiqWxnl+9uOG7w4kxZZxTWh/42c015zky\n5UzwFieGrvEE6/DO8WLouCyRjHLTdbTW8Xe3t3/JvgvwN9D9f7ZiKcw5McbIKUV2oWHOme8OB276\njkOa+N3DPW9PTwxNSyeeF0PH3XzmPl347vCEqevU1KiRX21fszDz2/Nb7qYTgR5Kw8+3VyyaWHLm\naTmSamXJBe8t37hXeAOP6cChTEyTIWjDy6YHa9j5lljWdIdTXGiN4zO3481mQ7SVHx9PTGPEYRk6\nR9cGvhz2HJaJ39zdkXJBqHy+3fPF9TVP08S7h0fGlMla2TSe3WZL2xrmGPnx4cxpngl2Tam93vQo\nhncPDxznGbQyuIarfmAzeD4dzjycZ+a0MATPzXbDZ/2Oc151yJfnZswmBK6bDW3jSVr4dLpwGi/r\n1Fu34boLjKKcpoUpLRzShLiGrrHsfI9q5VM5c0wLhIoLlsYZxBdKhpgUtzvimsgmJFJu8MXStGfU\nZHZ+4tVwxJvC3s2cYsdxafgPu/ds3cQrf2SunlId126dCszF8fftAUslmMq5NBxyy96eaa2yk7Xx\nE1UxkikYPqaWr0KhFSVq4lgNj6mjyhopvxTPxmYWDFM1RPX8y/k122ZksCOTNhxzz3F5QRJFJPN+\nviX4CVUh1oaUOsbxcxaTSPZMyi21BG54QRRlzpGh3vJYLmRT0OrYS88/NK95Px75lCd636IJPmu3\npOcR4C+7Le/mI+eYMWLYOs8vd7d8WFbHuV27eoO8aFtUDOO8sGsCS125f+MMORdeDT1jKixkvLP0\nNvD3L28YS+ZpXrnfIQS+2G656npSLdx2PY1zf2ZE+Hetv4Huv7VUlVgKP55OiMDjuI6wfrXfc4wT\n//nje+7m83PTC171PffpTCue357ecUgX7pYzrW1501yz8ZYfpjumOnHOMzmZlYIwlRf2imJOnPIj\nj3mipI5QrrgNPcXNGOAxnZmXgq0G8fCF+Yxt4/gQP3I3n9HYE3LPy6Zn6NcMr7vpwrxkUsmE4PmP\nw2ssysd04fEyEVOhU8
dN3yN+HWUd54V3xxNzXOiblZfdN4FjLHx3f880TwTr6NuW11dXQOXt8cBp\nKsQysZWO11cbQnB8Oo48Xk5MS1ybYO2Gq21PKcpUEp8eTqgW2tDwut+wGzoep5nj5cKYF4zxDG2g\naTydBMa48HE+M+aIdQ37bk0YPkviPC8sORNDhGAIjeKrx6ow2gvRRLou4Z1isOy6kSUbYoTPb+7Z\nhJnrMDHlQMqeV80BZyN7O/EinDFSubYTc/E8ppaft/cMJnFtK1ENqTq2JiFSibXhpVME8GKICofi\naM1MMEJgdSSLCpWVoviUDRub8SiZylgN92kgSsZJ4VIDxiixOpYaqBj+dXxFNZXOLYyl5VICl/kV\nUS3FTByXPVkUQcg1QHHI/JIocNAzVRukGl5zgzOeu/nCzgyMOZNRgjicGn4eblhq5sfpwsY5lqR8\nsenxNDyMZ67bgakkMsqu8aRU+KzfUkR5ipFtE0i18vVmj3HCj6cDokJoHF9s9+zalh+OR6w1tNay\nazuuugZkTYp+MQwoyqth8zfQ/WtcpVbuxpExJZ7miV3Tsm0Cp2VefRGs8G4+cpxnppLwxnDjB6Z6\n5nfjJ+6WEzVbgvHsm4ZTGbn1PR/TdxzjxMRMY1pcvuWLIfBQ3jHl+XnQoCWoo2krV/KKJO+ZyhOR\nypIGdHrB58OWkSemmFhqIS2rsLwJlhe8IpeFs3niGGfK3NOXHbd9B24dLLobR8ZpwYrBB8s37Uta\nb/nufM+H4xlbZY0E37RsnmVrH44nxiUCyn7T8cXmmrlkPp6eOE6JtKwb63azJfiGcZn5dD4wpYrT\n1Wz79WagCBzOkTmNxKTs+5YXQ49znlgy4xw5xYlYC7t2y8u+ofUNH8czx8uadGyCIzQe5x1UXZUZ\ndWa0GReUxns6AoubGVOkmoobFowvbJqElkCdLf3mSGhGdmEm2IxWx8vmTFZIRfhmeKC3kVt3Iqpl\nSYE34ZFgC50kNiYhCntbKRguxXJlM94oGwlUIFVDawCUpEL/fDU2CBW4VCi6YEWprFrfCBStiChP\n2ZOoOFOY1ZHUcBd3LOqwtnAuLYs6YnWMpUOlcj9fcS4B7zJzDcRiSfM1WjsuciHmwFQ8DoOpPbUq\nXdmgNfCpnHF1jTn6zO65sS3v5gteDEY8pWQG15JS5U0/gCofxolts6pxrp/z3t6dznixOCO0refN\nZsv9eaJqoXUe6x0vNy3Ts0fDpgkspRCcI1jLVdvxZrvh3fmMMYarNrBvO35xffOX7DAG/wbo/kUf\nJf+ja86JUnWdhkEZgudpHrkfR4oWfn+845+fPvGQRr7qrhEMX253fHu55/fzkW/PH7BiWOrCpgl8\n2V8z1Q9cxu85LAup9mQaPm9essiIdQcOfL9OQPnCtW/J8YorN1LdOwrfMSXHXHsa6em6iHUz4r8n\n6JnqBZt7pvo5L9uWbJ54yL/lXJUaPYGGfp94gWPKB570galkivb0fs/X2xvOdeL3x/eMpRBLwolh\nez3wMuwZp4lv7z9wWCZCcdw2G642DanCHx4/8XgeqQ58Fa63O26Gjssy8e39A0tUDHA9dNyGhqkW\nvj8emUuhLpWhc7za79k+O2C9Pz+ypIgxsHUDn+8Hqq5OYe9OJ6ZSyUYZ+o6tD6gWDmnmFBPRZIw3\nbH1A/Dp6ekoLSRZKtxBCxYrB5XZ16LIL0mWadsSbSucqpTqW7EhOGdqFnZ3xkpmKpxiLkZUDtkYx\nUtmYRGVNHu41Y6WytxVv/lt6gwAihaSKyGpMDkrRZ/NvrXjKGhUEOIQqICpkXWNsvImIQlZLrhYr\nlcGOpDoggDMZUVhKYCkOaxPejzTmWeUQA8ZWih05FahSybJ6Rtg00GjLfT1zqQdyaXAqbCUwZuVY\nJ7Qoj2nCGYeUyK3teGkHvl8e+fb8yGDXCKW9C5zSwpwLD+OCitL3jlQgxcocC5MmWuuI
VFwubFzH\n4/xEKoWlFIbG87Ora84pcogz+VQYQuBVv2HbtpRaKKq4v2zQ/b9cPznQfRhHHudVP/r2eODL3dWq\nUpDKPz984FwXgng6E/h840g1YyXzT4+/4+Ny5lzOvLAvaF3HmxB4O3/gh+kdYz2Rq8c7ZdtEJA/c\ntL9BzQeSKo/LnqQDvdng3Bnrvucq/ECsQmE1c3nQPZ+1TwT/A2M1HOKGOW3wOjB0ka054f1vCLIg\n1dLmjnn8mr11XOSOD/WB2TjyEuilx18pbVTepR846pHUg8kN27zndT/wlBd+/fAdc6oYA0NoGfqW\nXgI/HI88xYmkldZZXrQDV23Hp+nCP386kaOCKFe9owsNguGH84HzvA5I9MHw+mpLFyxzKXx7PFAW\npRgYmpZt6/HGcZhHTlNkJOIl0A8NvfEsZO7nkZQio1RqI/jG4cXgimVMiVEi6hXbFhyOYFav2Kkm\nhETwiSEsGCPE3BJNxrnItlkIYQ2tdFSSWooapurobGTnllU3q46iBlC8FEZdjdGvjZJXQy5iVVRg\ndY55TlR49rpFK2dNJMozybD+vDybhIsW0jMQV1YARyGrEAnM4nBGqSpMqcHaQpAZtS0RR1WHEaVU\nYakCBookqo2rW13qhiMLmwAAIABJREFUqFKZzMLTciRqxVRP1YJXh62OWCeOtXDUjBPDa9txyoVD\nnGnMgXOuBAML6+2ptY6HZWapC1UUqdCL46yZSTN30xlrDK82G3JZaa9/efiIefbHtWI4x5nvD0+A\n0ASHt2tGmjOy+v/+d9E/f43rJwW6qRQe54nBB0SE3gfeHp+YibwbT9wtIwed+flwjcNwKRf+j9Nb\nlnohZ6FzLd4r1i54P3MsfwAZUTfj8y2GwOetsvCB1v+OSmIpDd4UPu8vzLHjF8M/0dgDSS3v5xty\n3dFKS9s8snX/zG24Z64eXx17H/nhdMWr9kf2zXvm3vIhXnFadoS6oWkWXL3DN/fsyLTqyKVhMl8z\nWMNJP/HkPzFqQ02Bnha/NciY+N3lR5JbiB2E1vOCG4yBT9OZ75YHagbPyl9XJ0zLzKenM1MqqBiG\nwbC3PVaU95eRmCq2Kohl3whd0zOWyP0xMj+P8W46x43v8N5xnCYWnch1DX1s2x7nDEaV0zxzygvR\nVXDgg2UwLQuRccnEOqGdUjyrvlQMNjuWmCEkXCi4pqxBj8WRi6KSV+AwBm8jVSHXhkkirUm0JpHE\nUmtLZzLTMyAfawMIXhKdFByVooKVFRjutFARdmbdTFZgKZUshoU1bHPF0j9WbcqpwhlZK2TW1yUV\nRNevF9YJwKyWyjpNqKIcS8dcHBWDpTKlgBoQU6g2E4slFw/qqFVJkleT9gJJC6LQlIZYhJMkxnQg\nVWjKmkyiCpMppLoC6KdlggK3TU9Vy1OceXs5MpfMrmlpzWpq9DRPJBXcs4fCGCN3xwvGC84bNr6l\nyOrX652SUU5LpPeel8PAtmn5cDrxTs/0MfJ62Pyl56P9m+snBbr1vw/vU2UbGn5995GZGVHDddOz\ndZ6Py5H75SOP6UCslcyFz/s3BGNR85H3y1tqjVxShzUtnXG4bqLRe14031MpeJN5ijdUAm/8kda9\nY9v9hqzCWFoM8Hl75GAM/1P/X1bZEY73856L9gRt6MOR3f5/ZeMvXEpDJfB198T3ec+L/lv27R2X\nTeD9cs3jvKfRPc6NlPABaU5sqhLUsusS8/gZtVRmfWTuMsk0SG65paOEyuN05HzJZJeR3tCLY6M9\nsSgP48iSFatC4wxXTaA4w91lIi91BQUPG28J0pLywmm+MC9rUvDgYBsaxBgOdSGdziSFanSNxQke\ni2HMMw9zXKNhnMEFizeGAlyWyGgXFl9B1iDLRh1aMqnAYiaMVZxRglFKsUg1TBT6bsZLQQzEbGlx\nTBWCyZxLYFRPbyJGFKQwVUdvFjLCEx2iypWtTCq0
ZJ4Uptohkukk0YhSAIdSFD5pZalCJ7qa7chq\nIRkJXFRZUAL6DMRKBi7V81A3qELGEKSQq6UgJLUctUMRMoZcLRmDmMw5d0R1jNXjpZKKJ1coskbg\n5OqgGqR4VCpTjSQ1pFIx2SGiNMauqSM5kvWCVqETy2AcpxS5jwuNbRBrwIIWQWulZlCjRANTigy+\nhVK5xAwduGxpreVq63l3vvDjcqBvG67ajp9fXzGmxMfTeeWavVsHKnwA+f822fjPvX5SoOutBVW+\nfXpYM7vOZz7bDlzwBDG8O19YlsJvTh9QThgTeNG0ZLEcykdaeU/KE04qjZ3o2g2iyhfdDxi5pzWJ\nxzSgtFgyt+HE35n3vA53KGsD5SltWbB87i9s3Zmu/Y6pWmZtKOp405xojfCr9j2IclLHXRw4l4FO\nDW04sb3633Amc8lrFfZld6CmLVv/e3btiWnr+Thd8bTsCWVDNgvFvkO6jC8GqY5dk0iXHWMcWepC\nbiqIp9WWKwaOdeRuPjIngxgl9OsGMsnwkCJxLJQKzls6gT54Zi0cx5GSAGtoWxispXGOMSbGpM8R\n8kLnHJ1zFIFznFlKApRohCZ4vFUcjjlnxhrXq6yXtZo0DkpmpqJSKO0KqN4WTA6UkqkuEnwkG2XO\nDm9WY3gvlTE7Bp9Z1GJQnBZKtVwIDDbymAc+6BVeMls3szcTi66R7ofiifQYKoMpLFigUgoc6sCi\nQmdnBp4TK1RZVHmoHecasJKxZNQoVivn2nKpDafa0MnKH1cMM4ZFLR/yFbUKF21pJFHVrHQIwjG3\nZDVogVIcBaVKYi4NqXrm6LEGpFpyhUUVKRWt6wHaY4m1MlKY6kLSQqOBWhSxBmMMEdCcmVGaKrS1\n46wLx1hJfgXH3lpyUcY0r9lvYY2en1JmTInvjmdKrbzZDWAMU164H8fVcc87lpzoQsNN19FYx1LK\nvxn5/pe+flKga0QwIohYHDB4hzOGc1n4l+Md7y8nHuLIm27DLIXbdsun+cScKk/6kRd2JOvqL9ta\ni/CBL8NblDU/qrWRa3rQyN+339OZGUPhofQUdWsTxp34ZXPk2q6JC0kzVlqWZPjaJQYbITxyqZYF\nTyqWGz9hEL7erEMXJ1bR/aVs6ERp3IW/2/+GinAuLSkHXnUjOQ/Y8CNtmFk2lqd5w2nZ40vHxMTi\nP1KDYKvgisOHynxW3qcTxRS0AxMqQRtChksulJQpGYxf02gbL5RJeZgzOStFoAlCa6EzcC6FQyqU\ntH7YmlZojEOBS14Yc6XKml0WGkdvLRVlipGsGZVKbQwiFYPgsMw5k0ymWsV5xVZB1FAjqFSKLTiB\nuVgGq1Q15JpXoxwfGasjxX7lbV1k52ZmKVQ8Ywl4UxjsAihLdZxoOZfAqXY4qbSm8MJeyKyUwFot\nr1RLI4mlGqoYulL5tlxxUY8zhb0sCCsZPBXhVLY81A1WKlEtGUsrC6fccC4NT9rTSkJYeduLrLTD\nQ2qpxXDIHVYqFiWrktUwp7URZxS0GJYqCJmlCLUKEkGMoYqy1MRUBIdgasEbYSctF5uYNHIXK6qV\nwfXkUikoo1RSrYgToiZiVk6xYUoL1juMgzkWLjESa0XM2tRrnGXfdMyaebgk3l/OtN7z8+01iOH4\n3Eht7Jq799cKuPATA91SKxX4er8H4Kpt+S/v33GXzsyp0FnPL/e3dM7yfjnyL6ffM5bEpCOv3Q0b\nqzjTccxHzilyFY4YKSw1IHhqLVy5T/w8PFBVKCoEU9k+d55/Hia8rJ3Zk1qyOooqg5n5vFM6WQ26\nJwpGDDELX/lIawrRHTlXR8KTi2djI6/8Izf+TFXhqC2X3HLJG4Jkgl34YjuzqOdSGlJq2YfIEhdy\nM7F1hR5hjg2XZUtNlrMmSjgj1mLUYNVirBDnxCWBikGDYsOqP9WojIsSy1o5GS+0z4m0E8olQqms\nEewOQlgr1Ukz
Swaq4g04I3gnVAo5JmKB6lYe0zcGK6zAuRSiWUEVBFMNpihilaKZrOCaijHrzxFh\nWQQJCTWgallSyxASzmeU8izB8lx0HTld1QsZR0VsZc6eu7pBRLmyF4xUchUe6AnFcle2GMCK8spe\nCBZSFQ7Vc6l7rIGWRK7Co7RsZeJtuuZcW6I6buyZIBXBMNbAU+m4yxschVEbIp6NnVnUcowt57z6\nKrQ2UauQcRgKU/Jrw23xqBq8K6hUYll5bS12LTowxLLeumLNqBoaPEEsSync6ZFcV/P1alYvCVFB\nrGUqmfs0URUG9ThRFjKRzKSFHYakhVOe8bOgAlf9wKuh5dN55DcPdwyhYd+0vN5tqTXzOE3su55g\nDFYEMfJXDbjwEwNdI4IVw3GZSaVyjgvOCNum5c12x2Fa0wL+6fF7znVkKZbWeLaupTEWMXckvsPZ\njEgllz2qE63xrJe0SisRQcm6NjyKFvZm5DNnKayNEitK/0w33PoGoaKqTAoVAbU0kvlZ0Of8rlUR\nsKUw54ZX7kIwmRt35lQCkUAsgcZkbvwTocsUFUbbMOfCmAYsGWMzN5snpuJYqidOA8EWLnUmeotz\nq8+sLZAWS5qUWQvVAo1FpGLLGqUyxTWUEAWCReqaUJsyVIVS1iZRG4Rn+Spj1pV6YO20B2HlCanM\nUSmsqgDr12rMAjVBtWuSbxUoRnFqMRS0mLUaZo3aEYQ8e4LPVAfGVBaEkg2OQOMSIsolNjhT8AaW\nHDhoS2sSV834fChaPsUBIx2qFmcyiOApOBk5lpZLDiCGF/5EYzNZLR/rwKY63sUtIkoVx4290Li8\nmvTkhu/LjoJhY2aSWO7LhmvOfEhXHHLLVAK9jexdxpSyUg/Z85QGnKmkulaygqJaOKeeJVuyGlpX\nMaJMRUjRErMgxZCTATWoqYhUahGogldPfW5wZS0UwzrEodCaSm8s93Hmk70gBRrn6XDrpKR3tGox\nYlgT55QxF4RMCIF91/K4TBzmC6qFKSc+2+5QUaJUMHCYM14yKsJt1/3JSzfX+mdCiP9/1k8KdEWE\nwXu+fXrAW8v9ONKHhuAMlxx5Px/5NJ9XXaU4/uP2G5aaSXXh9/PvuDEzSoOnpfNCFsFIYWMf/zRF\ncs5bsp0xCF4iSiU8n9xr5CFUrQSp7KyQeTbcBhpZxfRbG1AqlcpcQcWiteCk8NqfUISkBkHYmMSY\nGzZ2xBrFm8S5NCQa5txgJTP4M22wrMWfYgTmqaOwXt+bzYIURymWeW5Ro6tRi7NgQBSMBU2Q87Pk\nSQTcmipgUFSFKQMFZM2wxMofZVVrxUsFX8C1YJ4r4rlWSl2/92aN4RFdd35GiVRQwZj13zRZUFPJ\nGarJK+hHsM5i3BqYmKKsDZ+qiAoxeXCKk4J3hVSFx2m9mrcuE0xhzH6lPXzlkjpicXhTeN0eaSSz\n4PmQttzHnkU9jUlkHI7KjRw5lZZjWqvXW39hY2ZmFR5Kz1INH+KOijCrZ5CFrknUqjzlgQ/zQNLA\nzk0ksZxzQBUOueEpDaRiKRiuwowjcyyBS+pXGZkoKKTsgbR6PWdDyYYcPS5UxCo5Fijyx3R1BMVS\nSLq28rIajAq9rNz8TOVUF9QUtNoVLEsCzesUXZmYsVBhWwOxVlpr189zTeRnCV3Vdby+9543w8Dd\nNPE0ThRVGmv4xYtXFK1cUuT9+UzvHdfdX7R5+f/t+kmBLsBSMj+7vkGAF33Pt08HTtPCrw/vWXLG\nGuVVc8WCcshH7uKJUx4J2BVszRsmXag1ouYHoioPecAqeLue+FDZ2gVFSRhOtWFvCoZVPJ90Ff08\nyzVhrXURlN6sYKOsQBQEimY6s/7VDCx1fU1Rg0pl42ZitSRWiVFrMufU4kwCA52JTNlTNLCkBjUV\n72aCDZRiSQW8CkuyFK0kBGkyxihGLHF05LqCWLF1FZVaQRBKgZorFLs+YFnBUQ
zUKix13XhSV1CV\ntZlOBkp6FlNVCM96qzXCG0pV1KzNFFsMBl0d3FC06CoarYABi4O0SreML9QCaRaMCYSw/k5zVooz\n1ARUIatFBIKpWJuZC3wct1gpNK6y9TNz8ZzSyqMeU8eYG0B53R7p7Cor+5A2POSOpThak5g1YFEa\nkzknx33asVTLxi7chhGpcKoNZREe00BSS1IBFVqTMFQe8477ZWAqgW1YMBQOqecptqQMU2wpRZiK\np/UJbwrnAksMlGLQav70OYzxWf9bBdT86as6ZSrrMywVgqyzc8WspvJVhUteBzZ23jPmTFXlscZ1\n1NwFpCqlVKIqS81YBDGwLIWP5xPWWm77nk3Tcpwn/nB4pAuBq65h33Y4K2vDDPlTtX2Oyqvhb/TC\nX90ysv6yDcKcE4dl4e+2LznNC/uu5Q+Xew458930EYPBCWzdNftQWOoDuYykmmirI5YGI1tmIi5H\nbsMnEvBYVmWBfYZKRWnNahWZRRlV6HQFE1AiBfu8VYDnLQOwXglZC5q1IhRler5uV5U/aUBVhVQs\nKrJen8VQsJRisUDMhmoKRSwYwTwDW57XzrfYhHSKrZaazFpRJ6FoWYcEjIAoRgVNFa0GCqzlMGsF\nlaGSnze3ReqqR7UWEqAVmBWsYI1iFXArCNe8vh+LrjyuEaCu3KSu4KxW1optBnECTikmrf+9QokW\nL6uWmALFGExIpArl4hEVmjYz+EhRISa3Vu9ljaMX9XgzYZ7lX+/GPZaBYJSb5sJUA6fUgijn1HFO\nLaXCTTux9zOiC5+WgcfUstRViiYKh9Su1XE1fFj25OdhhhfNmRblIQ28X3ZMyTLXgEEZc8BJIZjM\nJTaUKszJE1ymdYkpW+bsidWQkqNWSIsDV3B2bSpqduszx7A+bKAWyCvQI2tgp6nKLMpJ19tCJ0Jj\nA3PKHPJp5cuNZ6lKEcXZlefFwhrCtFqVCnX1lN4M3F3OPCwji66TaJ16Lqlw1bZcty0fxwtvT0eu\n2o4XXcdnm92zPvmvdzACfoKguw0Nv/74kYpyiREUvtzt8c+n7tvjE09p4j5d+DJ8iTGG3gq/n9/x\ncYxYd0YELBaRLVaONPaO1TUVptyw2IakDSIZIXNrJzKVQzWIrhwYZCqO8CxKByXqs7oCsKos65T+\nnwAXDJXnK/SfANpQdeVNl+qYJKzgi0NZO9ZLMVSxFLMWiEolVksphiUbqqkkdWQMikFUqNlQy6pt\nVm+QVeeP1j9Gd8uKon9cPoMWNIMUB2IRnmkE6vN7NpAFcYJ7fk/VQ4k8T3WBN8/q1TWageLWZ1OL\nYC0YXauz6gy5VgTFVLNWv+X5tY1FJK3OYFkp2WIRrM+AonVtWKZkmEa3SuJ85aa/rAqAOfCxCjmu\nR2aSBh+W9VjUyrtpz920QZ5Bs1YYk+NOepbsOKSOqkJrK5suEszCu/maU2mIxdPaSGsjp9hzMi2B\nzMO8Uhq5GnbtzMYtjCVwNw0YUWKyWFMpRUhYrKuUaomzo+SAGqH1iZKEnBxLLmj26+dH/nhC8/x5\nMlD/m344Wag1YzD4AgXz3HhczXBmhFLBkmhcx9lELjpRomCtI2Qha2WwbrVjlEquEesMc0w4Y9l4\nz89vrnl/PvP2dOC0LPQ+8Gq/w4gwlsIlJZwx9PLXOxgBP0HQLVq57jpUYNs23J0vVCr/+e4Hjkvk\nmGYGH/imucY7y+My8u35yLku2DbR5q9JFK6841zf8ZSFG1NxAql6Fu0oemTnjoAyqeNcAtEmkq6d\n5qSGnSlUKmNdoXO1A6lkLQQxa0MNKLoK7wHk+fu1nhBU5blOXv+M2mC1rg28aigYpmLXibRqmGqg\nINRqyFVIGCKWKgaHMldDSua5el7HT6uuYnh47rDk50rzj6NUVMh57cAgUO1aeVcFt6zVahLIHrHC\nH7tqSQpU0Lxqp62sNw+trCAgz4VZXbn4UJ
7vC1lRV9YrcwaJK8ViZE2HUIVSFCMGzYo1FlHQUPlj\nsZ4Xyzl3UJQ2ZJwoIjDnQC4wLw1mbjBeedWfoMK8eD4UuzalDGQxeFOJ3uApfJi23E0DIoarZqS3\nM4c0cB8HtFQOsWX1Y7A0XcJLZa6Wp9MOMBhRtnbmUFsuS6BW5TQHluSp1eBcpnORxVrOS8O4OGJy\neFdRNRQ1xLrSS1oMKs8Ho3s+uKustxJl5X5MIeqqAilJEOMRhWCFKSu5VkaxiFmlf/GZJkpUqiiZ\nTFWDZNDQUSRTTWYsEVXH47Iqc/5P9t6s2bEkue/8uUecBcBdcqm1i01KlCiJIz3N9/8MGhuNjYk0\nstldzWbXmnfFcs6JCPd58EBWz5jJ9ERJU0W8ZOa9iQvgAsfD/e//5bPbW3aaWK3xzesLLondMHI7\njeQ0cD/veFlWLrXweD4xDQOfHw7//IXgf+LtF1d0qzmHaWRMqeecnfj65ZnbtKco/Ov37zi2lbUk\n/vPz31KbsXplpzs+Gz8hqfLD+sQ36ytjMsyF75df4yiTGkP6gWfLDJbIEoGGRxv4wl65TQuOsxm8\nWGKvSvOgKZ3d2YvjNE4ey6Pq0fE1N7JIhxtikdVcP/7bic63EoXWCRhAHI62J+HBtnClWiiaVhuo\nlthaZqsaHaKFSUtrEh1vcmhgrUFfmiASlcv8iiVEES4dyKVBjoIa33OoOWb/2iBtARF4xmtGa+Dc\n3hVqrb8mdcAs7p+ETRzUurmMQnXEBKchCJ6CYWDmaAUvgqYEqaG5xvJvU8TifiKODqDJ4z05Jy6a\nEIT9tDFIMECO24i48XrZkaRimvjscGLOhadlxw/nO1o/lHJqVOsF0BMPyw47CybKIa98sjvxuOx4\nXG84bSOvyw4XxywxDwUZoJlyLDNP5xkjczNd2NbMUgeefeayDbSWKdUxlJwNyUZdMr5lrCpkQ9Uw\nM2j606msXRzh/YDsX5/Ugx2iG4sbxZRRUmBCnUhQHCSB1xVx2PnMpkLK0HyjuHGuhqiQJXwczrZx\nqRfWnNmWgLumPPBv3r5lP+54Ws784fkpUq0PBz7bH6huVDP+fx1J+d+5/eKK7i5nvn19pVhjqZXT\nuvLV7S1JhU92e56XlYfTmW/WD9xwj2Xn03nPD9sr5woXfhOJs2q435K4kBCKnDhZYydhnvL1+iWG\nMODc5keOXkOEIEox2FxZDW7UcIHqjWcT7lWpgOJc3Bk6texkod/fPFypfiLVeIcRokPdPPV7R0Gv\nntia4gKZRpXEpQ6x8JJYWlXiwmy9S3WEsgkiUfeCsAndGisu3ko/EfqfyYlqZ4heAdzoklU3GMCt\nX+gmYIq2imlggCbgpijSObaxTCRBSZFYEdzcOAjQ6JpdBUn9Z64eOLI0JDuSDNfgS3sBq4qYoDtj\nzMHrfT2PKIYWSGPDJQqxubCdM+dlRFyZxsIurzQyr2VgqcrzZSZ5o5K53124nQqPl8y3rzeoHKiW\nOOSVYsGdPdfMcZs4lQnxiZTh0/2R8zZw3HYUyxyXsRfDkD/XnKjAZc0Ima1mhrEwjY3zMrGsA631\naUSumG3wa/vHIw5FvP+++gLSQmxiAuiGEoeyiSIKozaSwsUunDvxYWLHkIW1wCoLq8FgIKnhJOaU\nwWEYYBgbw0V5WjcOptztBv7i7Tu+O77yu5cn3u0iRXvSHDOde3wm/ectAYZfYNEdUsJxmhtJhTFl\n7saR354eeTyf+frlGRXn3XRgn4fwOl0XHpcLnp6Y9D2O8dl8w8lf+GG95bPd7xg8RACXek+xM1kq\nyMrmcG4jNSf+qcw0Aju+ySeOxCg2uIYe3p3R4JCukENQxpIo1aPOhBFLSEWLa2CQlnoNasAQ11n/\nnruxMGEecAIdE976ouzaL9cWkIbQtfUorUlguQ6IRWeZeqE1uOIemmtcvBVoMdaLCeoey5xMB2m9\ng7SOyI
pNgDtWEgw54IXqtOSQHPdYM6YWvzO0Qyrp+jOjS7WSAorQOBB8AFRCObeCyUBKDU2GDVGo\n15KgCW4OmmEy5rGyVeHlNCHe8JqYcsU0wJ+mmctloFgCF7IaN4eNUiqnmtlOB87L1IsckTQ8NOqi\nPBz3PKeJ2gamYQsutCuvZeJSMqdl4CKZZsrtfiHTeLjseTrN1KKYKzk3xLvCzh2vGiywBiQJylz1\nYJLUDgGJQWp4y0gTIFgmKpXkipEoprgLmcoomQ0wLcGXFSXlESudyidC1cYqQu2fRfGM+0bVE6Vm\nLhuQQhX35c09WRPmxku9kFQ7a8fZTyO/ur/j+9MpYD5z9uPI5zfpn+vy/1/i9osrultr3M0z+2EI\nAxwe+d3TAy/1wg+nE1/d3EY8eRb+y8PvuHjlYT2Bw+e7e27zzGtZ+eGysElBpwt/vHwZDks6U+WV\n3y/3/Lubr6Ops5Gnes+5LQxayGIcLfPaZu70wrctYwjVlUNaOVGoLXKkNouud3PhkGrHcBMnT5wt\nB7dSrp2qUDwKbjVltQSdQwtx/TVPlKqx5fZE80Tt+G3GKE1p1gUaATHDn/7dgzUQo2mM+7hhHgsy\nrjxdBM01Oi+P4oYSGKxXPBsu2uGBcAkLIm8Nfu0goANSQU2QP+nsPRpetPVio4JqsDBIgASNTU3Q\nAozx3FtKiHocKKfgQ4OSxxXp3fCxjLAYSRWpiu4MnRwpzssyoVuibpldDmqXDE4hc6mZ8yWTNF7u\n/e6CNDjViQ/nHZdtxGvGNIrkYRemPa+XHad1oLTMkCuTNi5t5FJG8Ma2jkCj1cQwV1JqbNvAtnQn\nMScWauJgY1D3aj/cpIEprhrF0mPKyBKcXNPrAtcYPFgjkgzRBTR1+Dc+P7vUqM3wFNCDMkXMe/9o\nII2qlYsJ1Qs7JpKFTLjoxmkVcCNL5jBk/vrdZyRNvJaN3z8/czdOzLuJ94cD7k5z5+dcdn9xRTep\nUpvxYitbrbyuG7th4KvpLZ/OQVn5bjnym8cfORdjtcZXu7es1niThe/Lbzm1ysUXsmeS3zHpjuKF\nH9eVnOGTsfC3p1+TgCwDKR357XbPf9x/02kIA4/thntx5hQpDRcfeWwT79LGSxuCHGADh9QNAlt0\nghcbwm/VE6NYh+USmyVObaJ6jIfXdIJCRs2oZFoLfNE7OT5oUaHYqk1olqgWRVkMkjQavZK4cIUC\nxY0kRhOBrXdVbiCQxbAh8EZqAnOkRucqQ8NdCePX+FUkHE8lxuFMLN08Ic0QC94DIjCMH7m5qQmt\nG6J46nClCNRQ7ilRaDwTVLioXZSW0AJ5cFDFB8dSwpshW3RjWGLcV2SOw+R1nfBLH3s1IRlkcqQ5\n53Xisjq1KoM2Mo3msPlAbXBeRpAhXLsOG5M0XteBx+OMN6HWRLou6nM4g61VqVuM6SowDLHsKiXH\n8rPE0u3/dQi1wLhRDSgoxXuB68epRDTeVLfYF2DBwfbWGHRl8THMdFrwr+fU2BogjaUsNEtogOCo\nFVKbOuvMqL6ipqQ6kjSzH4XdBMtifDieGYaBt+PE+3nmtWz88fLKgDJqLFFLMz49zOxypnr7WXvp\nwi+w6M4pcSmFp3VBxXlYLvz5/R2vdaO58X//8B2LxRj3xfyWd7sdIPzXpz/yt8+PqCYaxifjp5go\niY3VfxNFIBlbGyn1DvGJhZXNjYNnPtk1/sv5qzBm8YE5H/nHcuDfpbVfRMJT3ZM8sddw5L/IzEPb\nMUvlYhkX4dIy+1SoxBLMEY5t6iU08DD3PynEdWRpYRJ93adY7/Mglm2xGLvePzbdzUOu8fGC9hgv\nJcfysJWfcFuzlcrZAAAgAElEQVSlkVLwaWtKSEvQOr3LjTRVWlLc8seOWa13ympYSj9hve4ELhD/\nDC1wQBI00GY0dVw71CHdsrPJT5BmctKgcRBYUNK0c1I9QREhjw0rQlsz
qXks9JPhk7CS0KbYFl12\na4ndrpJTWCWe1gkvjm9KSi0MuHeRoHG5DJyr0oqQgP1YKSqsdaCKsCxTSHEtk8bGNFWWbeSyjljL\nlDKGqMUTLQUro9SENUJZ0gSfjSxGXUdsvXKlPYqtg7f+/rohLiSCakaSsFN3GNhIbmwiVAm2wyCN\nrIlqTmJDBZoIRRIumaFDrTUXVllZW2LUfXyakjGNGy+lBnZf4s17M+xIKZFTZhwUivJwvvBu3vN2\nN/HFzS0P5wvfn05cauHNNIUb4M/49osrusWMwzjwZjdj7kwp893pyHfLkW9Pr0hSDln5y909X58/\n8OPlyB9Pr5xrYT8M3AyfkyVTW+ODPfC4ntgNN4xa2eme5vBh2fHF4e9jgDXl1A48rRUlU6WwOezb\nxM248H9ePsN8oJpyGM98W5W/GDcGwiz7ue44S+UmhVO/i/LUdqw+sLQMIlwsM6eGoWyuNE+8lLl3\ntIZJLNkqidoSS8ksNe5rEmOnSDAc1FPvYBJU7YURUu+tWpOPkjJxJ2uLLs3AGEKS2xyo6Bjcgmqp\nQw+xpstq+OhUybgHP0yMaCeT4ao/dXDdUtFr4LVG6iKN+Hm4hFhDDM/96ykeN7WOPyMweC/kDk3Y\njomBDoiqRFcsCTajRWUGYJoKaTQ2lK2O2CWhaoGRTo08hSrruIx4cVpJJG2oQJody+FjUUumtaB/\n7SZDpNFcuJSByyWEDSIJUSPPYKVRa2ZdJGh1Yoi2gJhasFDiRInxw7sJ+kcYqF0XXLnv0TLUxqhO\nEeEaUyEWS9eEMkio3xqJlUxDyNKYdWQxQ2TDqYhkNE+kpoxiZFFWjJNvVBFy3dFUEArjzjkvG+fz\ninjibpj4q7fvyWngXDee15V5HBhTYj8MnQHz87794oouABI4orihkjhulds8MhzuyTlj7nx3fObb\n44nndmLSzC4n/vLuSz6UDzxtZ769vJLUOeTMji/IjFy2wpEf2XylnN8xp0Ziz2rC7+vEv7/9ByZC\n9fTabnjYFHfHuzFLqXs+ya/8zfopxUJpdDssHNtI1sYsleKJ57rjYplDLsHi6t3uS9txqQlEWTyR\nhcBOPZZSl22I3C6is24mtA4nXFVZcTFHx5m04pXAIi2612vhHLXhavFzikLToIRaQ6coZK0oWHSx\niUbWgo3xuGYZ6aIOrGHaICtuKR6/0VkMNQx38KAwJcUrqAe0cmU+Se7Qc4miKtJxzCEewySWhrIC\nGodhTU4eJbLNOmc41YReF3kDlBSdvl8U94Q3mCeDsdE0sVTBzuGD4YSvBQNgsK6Zy5awyxDUNUDH\nhmdhW5R6zr1TV1KuJJziyuWS4sAJz8ug542OZMGrwNZbTgFS+zgl+JYQM0S7CKTmj0rBUDaGf0dp\nIaaoJLI0DqycfaZIIvckjCSxpAwq7yVwXo3PSxInVQ3IIm0UZsQV8ZERuJ2FGeWpbXx3POLA5/Mb\nbsaJtVae24KvC1mUpYSa8Nd3dxzGiaXLjfVnzGD4xRXdsY+yv3t6IInyzeszv76/o3jDmvP3jx84\n1o3vzq/sdOJXh7cccuKH9cQ/PD5zkRMnO3M/jKiMfD6/58X+yHFdeGkrSTOHvMfbGx43oXhF8zM3\nGf7u/CWzVMz3bMCr7/iPt18D4DbxUvdYG0g0UGNxpdYd+1T47fJpJL6acjcukRJrziAWnq91x0ud\ngjuMosBmwss2cy7DR5w32FqKSCwrtgatBbfUugdAqLw88F8nOLgA0hhyENLMhFqH6FJbI1FjZFdo\nlrGiqDliFgm8KRgX1YMalhrQCzcDkBQsIwhiHly1TETuuARe2QRpQUszBdGhiy1Aq2IeDAvwEGIk\nQByrQSmmS4PTHI2feXgTaIvXnbqFhA5GFUOKYuuArAE9kIDJ2VKCTXBPlOrQnDQ0NENLAiXhW3Sk\nUYXBh8BSrWROxfESl57GSYVJ8KMp
qcsG+7YwWcdmM1a670UOGMEsMFz1+D1z9cWQHPS9FoDSaEbp\nfhnFEqowEhlxIo6po9ahiTSSNuPgZ17kgEhA896ESYOVp7LSdKGWgcJI1qD5TQoXNy5+RnMiW2JK\nO3aS2Y2JuyHzfd348fXMlBP/5v497/Yzr2vlu9ORw1a4nyfSz7jgwi+w6F5xz88Pt5hbl882HpYz\n3xxfeb4srDT+6u1nPGxH9nnkDy/PPGwLay3sp3vu01tuhomXdub3Ly/U5KS8cJcGqs/s7C1F/shi\nzkZB6sjkcLa3PBqstbKbXrkfhb89/4pRna3NFJzvrfGfbv8JEacy81wPPK/K0M1rqghPdUYQXpYD\niw9cauZuXHCB1RJJnHMdWCw26yba4U3pHW9iLTFW4l0IAYB1cxgwy3GxW1y4GSNJF55Z4LnaPKhH\nybAE1TO+9gJGZUoOs0XsjEkvEs4gFclOSyGoMO/dm1SSC6hRh94lq6AWuC0tPCxcFFLAI0GNMlox\nGAXpLmVXGbF0HrGaIwOQQ7JKcaRFh5qqodmxKV7ytinSjWOuFGCRhqWgkXlJsCjqhmSP7n3UiK9Z\nM26Gl4Tk1qdlhzawLRp4p3pX9EWxpSm+9NfoiqZwTzOioFpLwdDoCkQRoqOFWFJeedIueFfgiTkq\nLeS85ojGIajiuBtjLpSm3XBnwgR2rMEFzsqaMmyNNDhZB6wYYhFFT9YQu0litL5gTWeKWJgxbTuS\nw5zgzU44roXvl4oysB8Hvrq5Z5dHNmshXNSYrqLD/Ree7s/uFjlpYWBu7mRR/o9v/onFG7NkhsMN\n7+YdW6v89mXjb04/4hjn2viPb75kS4WtVP7+6QPFGxuVHW+5z3tmMt9sD/x2PXOYM7vhzF5G1nZD\naUKe/0gx8Gyc2kTeMqtnzuYUcw7TmUM2/vb8BVngvM2YGlsT/u3tjyhGa1N0tevMkAqiYajw2mbW\nlljrwOaZpSbG3K5MrSjIW6JYplaleQ65b5Amg01ggRX6dTUuFpitSIjDOoyg7iRrMBjJjSopcMdm\niFQmNRgiWLFKRlDUG9lqCBaS0CQFD7eASnCmnViqGR3n7Quua3Q5KQq6SnR4bgGuXFVynYXci7iT\nqkFWcl+6mQIlaFXaIkjRFeqQcGlIA9mia9YiyAQ2tPCcRZBjGOaIhB1ny+ADeEvIOUVkroNOEpBI\nju28bxIFuV2fq8dhZqkr9rTLdSU68+RdtpuwGu9PGBLZTxxcI4yQxOMxWheFeBgaD2oBHXV6h7mz\n0zUOJEYKwZBJYsxaubSE5vhZ5sRiVhRtW3T5mrAMVluwdhgxcdJwobaG+UTxzKjwdooD8WwLP65Q\nmvBZfocgVDNSEl63CL18HlZupomv7u6YUmap5Wcd1QO/wKKbJFRhf3h+xsX58Xzmft7xyZAYJPHN\n8ZnLuvGbl0e2YvzZ/g27IXDex8vC9+WJx3KmSej7/8PhK452olX4zfmJRmPIgrRPKHVGyLzaAw9+\n4Y3P3E4ryXac64Fnc94cvkFVSAqnNtHajJpycaO4sx+3gBfOn4AL521AsrNU5fOhkLRibeBYRn64\n7GOUlyjEqyUultjKQEVZS4zpEc4Zh49XoRal1PBr8D6uajIUCwJ+H2PpSzCVFe9wwOaKLoK4MehG\nmkK+3CzhTRgMkIaKIbvALKPjTWCNQUM6WrUvsDz8F5IF/cs1IuqvXGExhdZwbVGEyZ2PLFHALBgZ\nIiCqiITIhKZIcXL1sKVMwZltOeABaUquQjKFwVmnjhmbkmtMBLKFmUvbd36yJjgrWuI+rREdvHVY\noRBWki0hc39S2aOgmYJZZ8KG+u8jH3rrvGe1riCLb3oVkncOigdOLXpNFTZSjfetSFC/nIR6ZSeV\n5imMkVQ/KgIdYa+FUTeWNFN9wC2R1dnpRulSYPcW3b4l8IS2gmoLSERSuI7hzCYs4rR0CWqfDKjt\n
uVPl/c3MpAMPlwu/f31mVOGv33zO3TSztsKH04nDNHVB0s+34MIvsOiKSKhi+mk6aoyYZs63ywvf\nPL/y0lbe7/bscuKLwy1P68LzcuGfjs+MKfNWbni7m0maebksfLsdWVlRFYSBf3X4jNVeeF4r327P\nJBGGnLH6K17LQC3GWV9ANha54X5e8TZzLHsei/Hl3feRWwVc2sRlvcGsUUVp5kxemXLj2+WeagEX\npOxsLTFkI6tTTMPx6jzFCE94nUb0uNCq07rTWCVMXLIbnhyK0VqitYzUKGZJDe0YY+saf7HoGMeh\noikWMEXCFEYtKEiqhmXDJGFt6PxQSHVDR6OK0Eid5iThxZBax1y7wEMkXMTco7tNdJwoddK/4zW6\nM1Lwi/1Kg6sWHaxEAq+lHGo2GuYJPVs0mYTNZsHxFGyAvAUQnkzw0dnUQgK9CeIplHIlFHTigmin\n2ZnAObpgUDzHc1Po4gU+8ppd++tC+kjCxyKLBDbtXRmoFt26axd3uCC1of1nIC3MzKO2A4WMkVOj\nWEY6ByVZY59OrLZjaYkqEwKR76Zh2dk6FTAOvUwyGNnYdILkFG3QMsl2VCqiC1UrSxWGsiOrs0/w\nZhw5lcL3lyPvRuV2mPl8OjCk+MyJgtfwmD6uC2+m+X9kOfifcpP/DhH5Z8lS/v3zcwRLinAuhf/r\nu285l42XbeFYVz7b3XBIE384PfP10wOO8P3pyFc3Nwxj5pBHfvvwSPHCP16iqB6GgU/mmcdt4bgV\nLvpAk8KgiujEe92z+DOn4rzaiqqzz4VDGlktUaph6YSmws104X4s1DrxXEaW1fnq7fNHtVhlCIqS\nh2+suZBSY8weOJ0pyxY44No08DyNcMNWhcuSI58N+Zjg4M2RApXw0g0CffBic3K8hXCg1SgCqTQk\nGxOFkjPuwVaQFphtUseTY54wkzDWbpBzQ2lUBUthJyklrCyFoORGvpx22aoEfinQAmvA4x/B/w/w\nE9VQaOXcyW19iYcJg4AMXbncAvNN3UNWLSS+dRBMPZaGDtmi+6wh7MIHMDWSCawEjzdV2nCFPvyj\nWbv38Vym6NY/mgRdtdy547DiQY2LJxwibOHqv9krUkGJIEkVp3mII8QreCJ5oRGMFaV2RogzaQko\nCSKS3hK7tKKurC5Mw8JWJzQ7Qw7ryPvpxGIhsBlzY6uZKa2gmebKflo51xkVYcgD1pzbsXL2GU0w\npAkx4bPxnrOH/Ph+mFib8Pn4jt0wcz/u+Ks37zmVQnLhy5s33M4Tf35334Njlff7n0VyxH+zXf9F\nFt1vXl84lcJaK5et8OFyRkSYh8zTEsm7Xz898XA5cd4K4zgwiDCngYdy4rvXV16Xjc2NX93cYmrc\nppE/vhx5qQvPdmRMyqyJ97sDr2XlcV1p+YGUKyqK+sSN7lh44rhVNoWkzi5F7MlpSyy1MY4Lkhrz\nULnJhbUOnMrAZRM+uT13ylTE97xehqCDeRhyqwQm2qqEu1iNrtaMwE2JiHAzwQqUmsESQgtFGhXt\nUIGZog2kxYImawgfime8OJAYWom0CW1UjYvSLKFuDATEUCVRckAPYjGWhqeC4RJQgluXCEsLHLR3\nnkgUbqc3ikrQ2IYpRlIHq8FbTbROHPAYqbX7NxikEg4Vnhs2KqVcu3lBNz4uzwyhzCAtpo5chGbh\n+2Dm2KC4FqRlzByv8T3pnjU6EKIG6Cw8Cb5w0G7DJ+Ljp9L6bEPgu13wEES0+B2ohxUSPaXhSkGO\nxJ7EzPbx/wuGMeDU7iIm7NIF0ZFTTYxDw0wZ08IszqvNDEPFLHx7x1y4bBPzsOIah+NugEtJsVQd\nlIpzOwtLGVEb2Q8zluDtBKXONBduhwP7nPgPb3/FsVSO28on8w1TVv76zZfs84iI8vnNgTkPwaGf\nd//sNeB/wO2/WXR/cfACwH4Y+e3TI1mU121DFd5OM6s1XteV74
5nRoUsmf/02SeowLlU/vM3f6A1\np1bj/bzjs8MtgyT+7vl7/rA+0VokQPxvt18i2dgK/NPjE9UbJ1ZuecuNDMwy8M3yyqk2fHKmybiR\niNMpJ6WOr1xqbJ5XGxhs4LwJD0VZmzPNhZQaj8ueOTe2oqymrGvisI/CF/xP4XKOLqVZ7xDjEiZ7\njWLbQrxgfaRNUrla35iH7NRbirwzr+hoDN5YW4Zukp0toAeZgzVQ6hQbcxMmX1EN8/VNI3GXCkOr\nURSzU4m0C6sZ3FBiS25oqM6ADs6Ce1fNxXhKHmI8d+BqZK6C5tSRB8dKhzTcSW4gQh2i+7atO2VV\nIZeKj5mmgT0jUWhzddQ1IotShI2SuvhiHTAcMmH2PkX3rVdDoKt2Wr1Xx97xQn8/+i1dmSICtS/V\nCNGHXPFqrl/vHT4Z9XK1OCIch3NnozjKxkBjksaJmU3CMW0Q41YuvNoeVNm0t9tXBzLv6SRq1Cwd\n5xeSNRRFhw3VjLeM2R7zypRXEGcrcMk7zBt3ace7ccfqhW8uTxz0hrthx5txx5QSN9MArmyt8rpt\nnLfCm/ndP9NV/7/O7RdZdNfW+PO7N4gIW6387umBb09HfjyfeVk33k8zd/PEMgd/8LIWPiwn7oaB\n/WHi/fwlD8uZWgt/8/CBDedOZ+73MyqKVOeP51eetwtuoFn5D/svKL6yFvin0zEEWrmytzcMS7hk\nPS4nCsbkym7XUIStTrxenGna2JqjKXwYrA5sYryURDMnZSNl57KNJG1YDahg3YSUYrGWeuxP28Ln\n1y0WVbiQpDGKhSS0CqXFaj8TnWtOJeJ+mrC2AS/KkJzkFToMV1sKvFMd9UrOvQgjNJvAnGyNIYV6\nLCSmA6mP5alV0tAd0jqWSw0Wg2oLQ/SPERN9lGeLeiMD4fzVvScUvEUxTWakBENOscgThw5pzI3g\nqALbGBCFizBZPLZ2KXEdwxuBJOQ2QAvf3k2M1l3NMr0s1jggwu2s08M0qLNhLBPFN1gZ/cAwepcK\nHxUNEW7WO+DGT/ZFimNkKpkCPRsv+NmC0JgoH3/W4gOCkr0QcUvRbatFanVAEsa9X3iVAwYsNgKJ\ngTXweFdMKi4JkQmXTBLIbCTJNK2oGKoDe99z9EqRlSMvnDeYZUfOG+9u7vns5oaX88I3xxOfzYeI\nZO/G5e1nngQMv9CiK3ROJGHx2CxoRr+6ueUwLrwZZ56WjeO28t3pyKyZ22nizbQnJ+XxfOK7lxPn\nbeP9PCE68aubO05r4eF05OvXV4as3KSZz28OiGbOl5Vv1nOn1zjzkPns8IZL23g6XTiWApLIgzDZ\nG9pROK1bbIO9YTKw3xXwSDg4XxrTzmnmaIq+pqyBfS5toFpfGuYIckzEn7WFusw8vq/dvEZao5Yx\nmqx+eafUGKVhomw1fTTEFhrT0JDkuAmtaSQaV1BtDBoiiaKKWcY3otBaIWlgs0bCmkbBdWNIRslC\n4SezdmmEQkoBT1FJBwE3pPSS0gS0q9g6vu0GsoRQIrshKRanC4aIoquRmpNz3K3mFLQwC+OccQ0j\nniY9kjxLp9VqwB0dMmgJJGs8fAsTHIwu9uiQbMcPpESxF+En56/OQggUxIEQrVw74LDdDNYIQNjc\npC7QvYagBlRkwETtpTeyeCsjwWFwMhs7WSmeWXWCptSm7IYl0qEZucgYxuepy45pqIVpviejacYr\nYCFHVgm/YmuQ2x5JI0OCnBfGOrI1oxThdsz8+e1bHreF7y4v/VckfHFzF+nFZrg5mvRnz1yAX2jR\nPYwDv3t6YDNjLYXFKn92/ybCEF+cf3h6Qk348Xjh37/7hMMwoiL812+/p9TGh+WMqPLr+1ve7w58\nv5z4+vGRy1o4bhtf3d5wM47MeeDbp1cuZeEPpxO30548OffjzPNl5fy08n074eqMWRjzzOCJdVt4\nXQvFBRUnD4nUMueXibUYJj
Eir6uw2zXcjNKEumkYcJtElLkENqhi1CVRW2CjaXASFjQud8qmFBs+\nKkYTrV93RrGh45ICNMbBUSpNM9auG/Sw25nHq8tW/15TUjVSbiiVMgzBnBBFi5Ok57VLiDosSVgh\ntKCghddC4LF2TbKtvbBdR3GR8MP1K/c1/ATEhGEkFo3XqJriDG6kFtE9xa/WksFdmDe62TqUofs0\nqDDSKVzNg3ygUXDxCMmI3Vs4uH2MxyH+kKtxewpam2j4Gbt1qCFQgl5srvBBX55xtaCMgpx6Mc1Y\nj6dLvagWUn/fDKXRs9FwJjayGIWRxQYaA6JRhM3HcIXjSq1LmGQGK93wfsKTRhqFC9pANeThBcE8\nUX2H0UiyBnfZnFebEIzbfMNtPuBeOHFhymE/OUriZpz44uaGl7Xwum3ky4kxZd79zOPX4Re6SHtd\nV749vpJEKa3x/fFINePhfObxckEEDuPc43aMZWs8vJ55LSsKfHZ7w2aVicTvPzzzum4c28phGpk1\ncTeOPKwLT+cLr8uKoNwcRuZhILvwzcsrx+3Ma2vcjJmUlf088OF44rytXKTg2UiaGFJYUS5ubCXK\noiZDUkMHw6oFi+GqVsrOOHpEyxiUVTAS0lVLmoJdYC5YCaVYcSErZGuk5NCMYhm3FD4HEIq4Tuyv\nNQUWmgS31sfKRmX4KABIzVGrSE9haJqC0mWg1UnJaH4149HgvFYLHqpfi5hgXf6rV9FD60yAj7SL\nTtFyushAYgnX/V5xJ7Uw4cndA9iHkD1jxtBASzAcSBFLYzkgITdhqNHNm3aztW615d2HApw2hFDC\nhGhj+jKPZCFfhpDadtw0am3v1LuwznuRjC/aR8puGNMrEcYE/UV9xN1HNsJKyUKtR7TvUyr9+RVE\nBzbLqFZEEoNs7PLKa7tB+sGWxNmnI+d2QDXgmNWE3dRY20AWJ6mzeWI3QfMMphyGgVN15qkiPuKW\n+WR+x6Vt3M2JfTqwlMZn+7fsdOYv7t7yyXTguG283x14P+8Zh8SnuwOO8Mlhz5yHf+YK8D/k9i+L\ntD+9Gc6cM1kTY0qICg+XM/tx4FQbX9zsEBOO68rf/PAjmVAWHcaBX9/d09z5hx8vvJ4fSSkzZOV/\n/+wrkgovlwt/9/0HkgjrVvj8sOd2vyN54jePD2xb4VJWhmHm395NDAjfr0c+PJ65bAsV4X4+MKTE\nxsLjpbHVKDLD4N1/VSjNYjSrTkTUQB4aiYavwtaU0hIphZGjjq1HkwttEWrLwVpKsNPKqI1m2o1v\nYhmTxJlzjYizZpSae0qDM0rYMl65ptVnrBnqQmbr5jSRXCtdKqrVEI1Mr2IJSxnQMDJrYdIOQktd\nAis9K42gY3GV0Grv/q77Izyq3pDi+chPyyptxnA15unFVmrguYNBqg5Zgp8rElhxVdIS1gdhSRsG\nQYzxUNqj4a54uAwBkSQTWvOfcAXRXkc/Gm+G0s4NcsKvirL+Ej5eph4iDO3TxRXtNXdcEtfc5oR3\n4DpRfWCUSqLgGOqNjblz0AP1vckLS51pnjlbAldG2bDuaOdpjGlBNaLr3UkdElINLw4q4WWRwKXQ\npAEjbgfGFB4fYwpXtmMpZDUOeeSr+Q0f1oXvl1eMgBM+l1uW1jjMEzll7KNL8M/79ossunPK/OO6\ncimFZsbD5cxf3L1hygM304U/PD/Tahgwf3rYczftOAwD3xxf+XA+8cfHI8Uad9OBz253VHOeLmfO\nW+HDy5H7cWYeM//+k094LivrtvHbHx4YxoSmxKd39zHyGnz9+kSrhmPcz7eMQ2LTxofTkbJFUOOs\nQhoC/zw3o5USnVYWUk5oBtGCFYuAQROkezAMQBq2yAhDwiBbHXUjD92a0ZSyhhjCxVExxrTFiOmK\nl0S1jFdhHJzkBXLEclvNWAmK1uitCxcSJrDZSKJ9tIAMOwenMAHWPQJqdKv60/vjnoLFIB5h
mdZZ\nDInY5Ecsbe+8iSKVe2fZiI7XDKqjufvBdh5Ysuh6hy6yWEflKgdLBqMFjGBAG3pd13BQa2vXLtBr\naSKEFiJgGq8H+ZhuQbpCIdc7xBTggwctD4/O/TpP9tw3kQJd1EwXiOAacBBGJkzpqxsTOZgKUkls\nvbwLm00YiUlLeAXVMLtxCXgn0PNGYgMmCpnFooMdZQUSK4kimVYFyQnNcQxIqlRraFJMJtDw5mg1\nUkFOHijym/SWW9kh2ampMo8Da2nYINx2eOFxWfhwOmPmzHngs/3NP9t1/7/K7RdZdEWEKaUeDWLc\nzzPNnB9PJx7OZ07LxpQGvrq/CyVTGng8LbyeFrbV+PLNAXPh/W7m5bzxulz44fnEbhh5e3Pgi5s9\n7sLD+cwPr2eaOW9uduyGkbvdyMP5HFDGsnKTB9og3N2MHEvhuJz5cbmQJXZGt0Om5sriwnG18D/1\nzJAylgpQQDbqpZuXe2IYHc2VPBjURl0iLDOaxMqUWqiUasKKUGsUgywN1Ra2fk1Z2hDYLMKgjTSB\nSMT6+KaYBPd1Sg7S+BgN3/12x9TCaFe8Y42BDyb1j1izIUFbQlHtLAbCCMWIx5D+KXW7FrHe0XZ1\nFlgkTWx9oeYR9imTUIQw8q4WMAOhXtiS4KkHdhqkCoMSfaISSjCEbBIwbQuk4IpoXFVx0qIT9qs/\neyaMdvTKLe7/t2PXaPu4Kozn2v+Wo3u/shM+sjSuND+5JtoZbkMssci4bAi1F2Klhu6OyBJemaTQ\nfGTzMZaMTZmHwqiVl5ZYdMJqMF6GDlVEvjKoBjzi14WgeQ/uTGAj1hLoDLbAVEJ4YiPS9rivpKly\nkZW2OD+wcBhH/vLuHW/GmYs1vr9cuBtGUlLupwkRxX7mUT3wCy26zYwxZd7sOgnb4e8+fGCXMy/L\nyic3B96NO3Dnv3zzHeBsFhHff/3VJ0ya+eblyNcPT9CM10vhr774lH3KIMY/fP+ImfDd6wv388j9\n4Zb7OfObx0e+eXzmadlI5vzq/o5DVp6s8nI583I5stbGJzogU8ZT5aUUlhW2Cpm4AOYhs+YtttSb\n4Wft0csy6SwAACAASURBVNeODi0WVN4olwgjdDTSbmSLi7s2liX4u8mFnAJeUBVKi8LYTDuDwVAV\nVCrVB1qJtArxQA9TqoG9esJaFH4xQ3MXYalgLUUxMcjJoUbqbpGARZoFfck7dampoN0Lxq3HC3UO\nrquEyiv3Lrd0BsCqnTngWO4rwebIGr2fukSXnbVzgIWMMzY6nxZWuTat2pd5Mf0Xgxy6kUgfJvi/\nQUCIyCDcf2IrSMd8oRdcPuLLAYfk+Lp6YBhXiKH/N0pQ/EidRUHALCYebAIJx7eBCu5UGXB3stRY\nglLj99cSlzb3ZaQz6UZtYyxUtRuGG8EFFgtqYQ1ISC38RrJtMXmI9R1mpvVw0TCXiGUlNjJKJomy\nGxLHLXOqG5PCTT7wdjpwKiuvdWGzRpaMj8ZFKn+2v+UwThRrP88l0v/n9otcpG2t8YfnJy6tstXG\nd6dXbvLIfhhZW+PD+UwtlQ+vR1arHPLE292e87bRvPHh6cJxW5gk8Wa3IyWhqXNeCg/HE20z0pj5\n9HZma45Z4x8fXljrxmWrvD3sIClDEv7x5ZXTcmGthZs0IiMkyfzYjkEZW6JDSp2GtCp43Wheww2s\nQ4c6VHQouFsUvxKFYR4gj9GDmkGpYf3XUIYURZOelVVaphkklKQbqTtztZYwS7QiaHKSNyyFHHVt\nAWX4VXjgIOo9mT39JNl1C36o8tE43bpBeAawUE15p1QhEcETDloBLfgVI72aSCDR6Athfi6xLIxU\ni4RUCz+DIWAN7di8FOnR4RGxU1Is3hKE50X9iep1jXRzJyCOjiN7ikMChTbROcNO1XifaHwsuPKx\noe3gxBV6kPYnOC4BPYgF7NA7+Kjn4W+MxLJMtCCkkFyn
QrMRV2PsBTf4D5nmiSmFr27CSUNh85ni\nMOQ4IG+GCxefaa7k7LQmIXTRMMgZU2VjQJOjMmDeGMQpmsgqDGnAamY3jEFoU7ibJ8wG7tOe/bBj\nl5Wvbj/leFlIAm/nA5/uD/ybu/ccS2FOmZtp4nYc+VWXA/8Mbv+ySPvT26BheH3aNrII5sTGtMGy\nrvxwPDJJZr+buRXls/2O81J4XTeejws3o7KTga/u71GMl7Xw9Y+PHPIEDb58e8PNbs9xW/nm9BBk\nz+bcpZl//dWBsxnfvzzz7dMpxjVX/vXbd6zeeG4nvjk/hZGKOe/nhA9wxliak0ujFhhkD7nhWvBh\nQczwJYxqJIUQYpycNFQc2LYEW4yMYwqfVBc6VADNEwKMyUlaaKK0arQy0CLUgZwLpJ7+68ba8tXG\nldz5vEKkE/uVZuTRRYHTpAdjJoCwX3FVnIbl7uQlIYYwi2QCa1Fs/TrOu/QWmPhTe6cpLQqsJbSC\nmNNUYHJkCPkwoQ3oOWuxsEs9RUTtTzLWOiHCekOKhTyh6ySCgXBlTyD9OfYP19VLIfWlHz8V8MCg\nU+iQQ07Hx5OTYFZIPwwciY43gWOYdG9eh9RSiFg85L5ZG+6NKkLxqWO8gbOqbLgOrJbJGuby47Ax\nOJxq4pJGapx0uNWP00b1oJ1ZHqlbN9BJHt32kEiW0aaIzJhUUtqwmrtdZyRWHPYDpVSezdmdXtkP\nI39+85bbYWZplVOpjJpJSZlyQlX/25XqZ3T7RRbdasaYBv7ize6jUOHvPzygJjyeL8w588m8526e\n+PsPP/K7DwuXrbHWlT97d8ubaWZtxh+eH6HCD6czX9zccjdPTPPA7x8eOL08893DkcOcmXcjb95N\nPCwbL8uF719PWK3c7fa8mUcWrzyXlR+XV47nC/uUGKaEjsJrWTmbsS0xr4op05yoCUgrrRhpkej0\nPDNMFbShOawJy6psa2Icw3JSc2RtGUrdUiQuqLPLG5KMRpDmrWRSTxKWKShLTurS4ISbkIbAeEMl\nd4307rO0gGrBLAWVSa6+roT/AVHBzAwfUl/eR2rFdb6yHKO+KH151uGLvs1nkJ+w1Y0AVN2DZjby\n05LNiMgDFywJbaabnctPgZbFo9COP+281Do1TWIvRw4Wg1zTINoVd/XOmrg+Jj9t3K7QwnVRKD89\nz8iYH4jTgI+/N/oiUfr93MIZTwj83aikQYL+BhQS2UP6W13J0kUoLRKlnSEOihYcYfFQK8owULVh\nZWBIwTQolqlppFm8oGQGEvH1oh6/IMmYBc/bvdtTpsxgwqQTs06cfeNYV3DlrR6Y0sRWK6UvrrMq\nSy3oIPz6cMd+uJqa+78kR/wcbyqC4VzWLfwWlo0pZe7mibt55lI2siT+8emZ81IQc97vd5jPTFPm\nw+nCy7JwWRt388RfvnuHjqAufPvywmWtiMGffXpHUmWaBr5/PfO6nHm9XHg/z5CVw27i2+MLz8uF\n58uZQZRPDzcchpEPduZpW1lLRasxarc4PBDYl2/YpafgyogMG54bOZcwcFlCOJpxDuNGSgVy+n/Y\ne7MeSZIkW+8TUVUz8yWW3KqmN85wQID//9cMhrh3MN09VV2VVZkZEb7YoovwQdQj6xIE+ULwAp1t\nQCErY0sPj3AxUZFzvkPOAa2RvAphFCQ6vMa0eQR7dk1vCtVP64PHtFsLlGKIBkdADgG1TJFAK740\nMtPOGW/o4BBth5V5F011lUMVZ9mW2/LMetqDl1xXhsVuNsAdZnSnHLEXN8ELcQd6Y33WG6xL2Zxe\npptrZIldwZVwJQCGrV7gm960uV6Mf8tNaJ0xY6EruXxP1wuznxYkuXvutyIFXkcEt7/3y5p/IZU+\nF/K/i6gXsEZHRMLNVxxMkOKSKsU6j8K1umba2RU9h66zM0rz4741v1EMIZPUmGtgtREsEKSSamFt\ngU0catRwcpmq9iWq
AYoEddpYMWqptAY29ucjB1IeWSgMobnGvAZGG9nHkcMQeTNOfJqvLs2MkX8+\nvuVx2jHnjV+vV3ap8DCOfy+jhf/H65ssukGVKMpfLk+kEPg4X3k/TTwOO3It/Ph8gjaDNKYh8S+P\nj6gpf3t54c+/fOEuDpzzxj+/e8N+TDQq//63T4wEXtaFh92OD/uJFpS/PH3h118urKUhAf73331A\nRfnpeuY/Pn+iUbkuK398eCCI8Fw3fshXSq60eeVuN5Im5Vw2ZoxWjJoraokYe8e3c8kNZMoaiNUj\nXsZ9xrSi0di2hGyeXDsKpLF6YUMoNdAy1ByJwZBQaNE5VqUFShEC7nATaYSd0ZpRWvSNdXBimPQw\nw4bQqi9vzLNuEHMnXdWbw0y63FacAWHi2k8VqK6l5uYku7WLI1+DGrP4n2avMqzXGWl1zq5Uc2h4\nFwPcxASa/XGqCIXqX7cHckrrj9nwjLbeUWvzm8OrguL2yglfl2dfZaa94Aa+3ixubxfldT0v3vF6\nJA9IdPmGiCCmaHOZYLFKw+Vv4J15LYLUCEMmcKORQWOgWSN4SWYIG2bKVgeqCcUCSXzO29RviP78\nCNWSd7A4c7mJYIxY6/9e8G9L1akMUiIaPcE4y+q8DwlkgYRyGEbWbFAKL9vMLgV+t79jN4ystaEG\nUcPrktFDiv/+i+43uUgzM/764hzchnFaVn54eUEqPC0L65Z5O+15POz5NF+Yt411znyeFz4cd9xP\nEwh8vl6xanw8nRk1cL+buBsGPi9XGsLPTydK3djFyJv9gZXG0jZ+Pl05zzOBxv24R6PxtGV+Wa88\nXa5O9MLYTZGlbpylcFkyYkrIhRi1w2IyLW1IabTunArBlQuyq0hxQX0rglglNGWaMhb8F3xZI4lG\nyeqdbVBM1Qlb2ZWk1Zw7rKGQNbkFtXTnv6pbkk06rxeHozS8k6vOCqAXUFOfc3qcevBiQ/8lC7/p\nWI2vQO/b+5z54h9T+/tvneToCzV/Xx/IIl+LG24h1tLLd4WmreMS/Q2huN0AbVgFS4IGN24ASPFO\n+lYYWu9CTYQiv1Eu3DS6twVaD8q8FWG5jYL7hwXtOzfp4w1rJIPaCtEMS5HWs8lEgS10dKa5CVgL\nghCs9jmvukoFX+xJaGhNFIzkUwsGzVgMbFXR4K7EMfmTk5syRB8zqDqfuJTKmMx5xih308C8QhQj\nJZ/93o+JZgN7BqaQ2KTydtpRKzyEPY/7HYME/uXhDVtx48fvjvfsU+JfHt4wpsRSMt8f7/4/fKX/\nT73+sUj7v7tUXOAPwtYqu5B4e5iYx8h9Glm2lcs883zduBsSbw4jD3c7yHBaZj6/XDmOA4/TxP1+\nYpcGnucrT/NM2SrHMRLjyOM08WVd+Xyd+XK6MA2Rx2Hg7eM9p3Xl83Ll4/VKUOPNbuAuTqxD5fP1\nwqlkyIWpNYhCPAZKNtaYaaUhWyRmSCMUKZitEI1Qff4npaHaSNFDFDdTKI2cA6kvfYax21MxtipQ\nXAERoxFjxYKQa0BLZe0kK1R6oXDffm6C9IoTkhdgYp/zqg86zeQ12JBmFHpnE+nD3ltL2vxjtH1d\n/ctvFmgK7PGimvvnVvr5v48Ykv9cqV7VxJx1a2LYICDm88qu2Ko9o82CEZKrJLT2FrYa1Qwm6cs8\ne31c1gwdumMt8HWRdhtNi0u/bl1vB3z5uFe+eiMET59wi7R67py4hMoskcwotSsZpHedRTAi1SIi\nBaH2hZ5rGBSfO/to2ccOq0WaJbRErPYlZ3UDSYy357DPxvvNyRB/WzfP5E2ozRhGfz6lCtp2YAWi\nd+ShKrEmjsPAm2niLk4eUHm9MobAP9+9ZRciudY+XojcfwOpEfCNFl0RYZ8S/+3TJ0SUL/OFXRz4\n0+EIqvz7r5/4r/MLUoyXXPjfvnvHPiZO28J/+/iJQ4z8cr7w/fGOD8d70gD/x8dfUb
nyMq+kGPnD\nwwOHaeKvT1/46/Mza2lc5pn/5c0juzHyXDZ+eHpmKRtf1oU/HI7ekCn81/nEtsDLvHBIIzqOYJVT\nW1lbY2tG3Lp9VgSbKsUKJVakDITc3KqZHO8XpopZoLRKK4EhC0kzYxBa8iNnq5G6gQYjItRRaDc2\nQvZi2UQYxPWuTaqjHIGCqw6couUsWlWfNxq8bvqlBzKaeSeJ8GoR5nasTNZnnv7xVvQr4EZBokH8\nTTcMXrUCvmKPfJWRrTeTQcGiFyrAjQ7NcB6QwEAH7Ljg4DZDdtdVoKWOehCX1frI4Cb9cvXEbSb9\n2rrfxgv9jxux8DVIoo91K/6Qg91uUAHE5W1KQGqldjtyiH14E5RQ9XV0IZZpKgQNPn9HPbyTBlJR\n61Q3RhSnet2kENofjBsiPA8tb/3OIIJKpJohbfDlojQIRmhKEl/eNhq1NdbSiLFRS2MXEndp4lor\n17IxycAhJb6bDqiGbi3WLrVz3GoMf++2CL++yfECwMfLhZd1oZmxlMKXeSaYcFo3nq4X9iHxsJvI\nzVhbJq+N58sVAe7GHfsxcM4brcCnywWzyiiBN8c957ISovDj05lcPIXhfu/pBlWMn85nrstGKZUU\nhP04sLXCL3nm08sZi1C2wnFQssGihZd1dq5LqUzis8am5nNeCuRGwvW5IhWGuWteDcviR/nWiMFj\nzEOQDi2plKzspCEx9K8buotWPOX3NiEIfdmjwtZ8m27SF0KhUYFauoMp3CA1Pl2sqBfdfgY395EA\n+BgEz3B7jW0w79KkCBY8Uoe+QKJIp37xqry6KQbMgCz+Yu6SMpJ/TUGwDa+ezccLN4WA9GLqsjF9\nZaS7FDd4sRVei4R0DW6VLlnr7+f1cfKqZAi9Yb7V1B7r9j9gJOTrt+3uYTOS+Y2sqfVcNEdainr3\n2rELoM4by81bZ0kuIRO6xMuJHAyibNZIoSdPCIgOtGZ91NFoNGIKniyEEmMk10yKPhsZIgxRyBvs\n4uQ3KjGOQ0II3MU9oRtIHncT69p4tzuyj4n7ceJPx3vmUkga+LA/cBwm/nB/79I9lb8nytg/xgv/\n16u1xmEYwIwxBP528pluUiHFwHd3R2jC+Xrip+uZxzBRpPK7h3v2ceCyrvxyupD6SXgaRj4c7lnL\nxum6Mb8UxhAxa/zh/SNbq85t+PzM3TQyCbx7+4YMfFlf+OF8QlVRgff7e2xfOeeVl/mKLRtBXdCf\nxpFaN66KF3QRdAuk5EsQ2W8UqWgLrmwouCmBjA7u3CfAuriFV2tiP628nqTxY2deAykpNWbXQETB\nak+AbYqq9rluX+E0oXRspO++tB+5m4O7pWIoEnwBJjesoQiM3o3ZTeLF7f0V9oKqE7Ws57P5IqsS\ne65ZU/HFWg9wRM31ub3ZM0BWH0Fo6JOMg7xCcSTj1mV6EKT4KMeK9Lls9XieqOgI5t5jnwcbXwuu\n0G9Q/mR2AidVjBR7V41/jvYx783n0SW9JLzZX3CjiERHQo7AWrw4Sh+Vh27cqMUfT1So6ieFINId\nZG6GUamv8q7WEtEDi7q0zPw5Fs+g8Dip2m98ECQxDbBlo2QhEmjiIaXg7x+Z+mxaiOpmGWmRhynw\n/f6AGczbxsfrhSEEvj/cE0VZauZpXZhi4HH8u4jp+X+9vtmiu0vO1K3NuOYNa8bv7+6ZYiS8KD+e\nX7ACz/OF3x8fuBsmPugdf/nymZc28+k6czcOvB/27A8jP7688PPpmed5ZWuF745H7vYTX84zf3t+\nYd0y51z4w/0d+3FkpfC3lwtbK/zw9Mz7445REx+OB35dZ6618PE6s5OI7pQpJk7lSrbKGZ/9jRI8\nqTa55GuNnmkmNaLVwbBNKzJs1CpEgbwGdMv+wgyFGCpLixCUuhgmjaEODEPxo6wGh45n87ZKpBOn\nvABYM7cBI2hUWrht3+1Vk6tOAfQZ8G3JRnvFGk
rwrlSq0qyCKJpyZx24FK1Wh2ZrcDSjpu5UKO7M\nsqYQG3FoWO+k24prf0uXWY2GhYbUm7MN75ijIdJoyQukFIeMm/lyUgQY9dXLYObfI60X6G72uAFr\nbuSzoD5aUZEb9xzx8bnL8UJv1qWbD6wz2iO+9NNey1eQRB8fVB9NmDIglK69vc2GXVbsZo9mnUfe\n57y+Bu0zbHFrdOppIXTanpR+ktBI1PaKKtYSutHFfKEryjEmlipukEGZt5Xj6IkkYwgcw8hqhTUX\nhuSutftxIobAFNwMUVqjlMqlGe93yrdwfbNFV3G5Suivv9KMuVae55mn60wpDhvfj5EkkbmsvFxX\n8lYYhpE/Pt57AoAoPz4981JWKJUP+z2mwjCoB1jOK8uWeTiO3NuONESel4VPlwtzzgwx8q9vHkkx\n8Fw2frl4PtpWMr/bH737DI0v1xOrVUpuHIMf+UOCc600gVyMUCJaGtGUpuV1oVRa6goHRbVgooTB\ngwu31qhVidkXJ+Oo1Ojx5M08Q4zgXVBTH3aahL5fklebrAqYttflkoi47lX655t81f9LIw24xMzw\ntGCh4ycNw7m+pQpalGqNoIqmwhCbS9yaUM3Zws5iqEjohWzrM2RRTBth5woLkgdwgs9gmwhM5q1l\n9gLtDZ63qBK9uzUFKU5Fa01Q84SHFryYi9CVFfb1l0uh9Dl1NI+KvE1ibmaLYH0SsUHoPN7uWiYo\n7BosBpa63E1gtEAtMNxGul1xFf0H8HVB56iQ1x3lLY3ituvDxEE9BKT1vWN/+CFArhCiMkjgmiv1\n9u/HwKDO32g5+LcahajC3TCwayMNI0jCpJJzZQ4FE+HtNHJIY49kv/J+t+d+HPhwPIDhy8pv4Ppm\ni24x481uzxACpVb+ffmFn1+eSRqZ68Y/PzwyDSOn65X/+PSZY5w4LyuPhz0f9ndYMP7j109YM9a8\nkULkj2/fIgI/Xl74+ctCEsUa/Mt371BVnpcrf/78maSRZS384c0jQwy8bDP/+fSFZsa6Nt7uDoxH\nYc6Zn+YLuW3kbCSN3E2R2hoXWbjkRsE7jX0csKX4PDhUmjVoQsgRyW4LtZhdZzq68N2KYTkisdAE\nwh7m0lO2svoBVI2gnvzqJ2l/kTb1bX37beJB3+qLgozAq360jxSiMQ3+sSJeCHwX5SGcw+AuLDFz\ntYQJOhRGaTSRXoh9/ipmxACqlWFs1M1NArXgS7jWCGMPkAxQNv9ZeNU3ZGxoR0K21r9m9cdvyTtL\n1+uaZ551wI1qc0lvn+PKbUSi0M/xr79jnvPYRxi/maYYr4k+rq0NeJfZjMGfHSiGJOf6VOt8H/Ou\nN0YvxrcyOv2G9VCLN6op8Tqj7uCz1wnIbQQuuP741bEsDqZPqhRrrs+tQgywi4mtVkeGamKtG6JK\nEiNaYDClhUAMHmlUq5Pi3kw7/ml3x1wrl3XlkyhJlYf9kWbGNWe342vwcd83cH2zRXcIgY9n5+Ju\ntbLUzB8fHpjSwN088Ok6o/PCr+cLb3Yj+yHx/cOBXy5Xfr2e+PXlCgoPaeBPHx54Xhc+bxdeLguX\nuvF22jOGyNvDni/blfN143lduR8G9sPEh/sdT9eVX5czP59OHMNAiIE/HO94qhvXsvLD6YUpBKRG\n3hx2zDVTpPG8Xd0eW+AQ1ItKqVxj9XwwBLboHZZBSJWqhliAGpHFu1vVAskwSY4zLH7EL0UIMVMJ\n6BDIzZBmZItEMWrEt/amvRRr72iFEHmda94spxIbQ4IUW89m8xd0a57hNk2FWrV3X95STSljnUuO\ndjNGTdSipNRIu5uQN1Baw7gdf52De5gqLSu5CbkIBENohCS0qSsZtm5KaAYWkKmhDZo4YYtuiPA7\ng58a2k250Hqr2YutQe86jRuIwaOAvOg2vKC/5lS2rsKiL8PMRwyJQEC4htLNHMIBIRqsNEr0uXSs\nPvEwvNASvz
J2hl7Yq35l76T+X4PXWXfqD/8G9Anmqq8QffkZTAiiiCZGTbTqi1TX+nqa7zU3UguY\nKXkrhNFNGTsdmGRws4kIYwyIKcdhIGnkYXS4VK7GJXvR/f7498/ShW+46EZVtlYprXVuiqCibLlw\nXTLn68L9MPG4nzhOI6METsvCaZkZg/DmbsKs8TDueJ4XnuaZpVaOMbIbR+6nyDk3nuYrX05XxiHy\ndjdxmAa2Wvl4uvAyb0Dln3YOL68KP11PXLaNJTc+TDvUAuGofJ5PLGLMy8oQR4I14pi4tkyhcakg\nFgkYQzZqLLSqbGOFIlgNxC31JFpDppVWAhorkgVbXXNbHLVF08HTbzcvrmIBDUaN5lBsE8wqtICM\neCJE6lKuzsoVEcJQSdEXNa9nX4NhcPtxip5aENT8iFm8cO/iRm4eoll7xMUQV3ZRumxJsWZsBWoL\nxLGSoq+qalFMPNqdakxayRYYD5mWPSuuVMFEPfooCjYUv5FsQGtdAqaQ7NX661dXTUiD5IYJ+4oR\nA+v8CAH6LUn6XEV4RTlQghvhBI+AHzWwUb5yfFU5kFjJtNqQMNDYGIJT4JK65M2Cy88E72ZD7FLh\nvsST296x/1sDzg/e6Ko2VWJpHIZI6ZXXUEIL7IaIaGDLxccoGFNITCmSq6A1MHYZYVIBiUSUbEaK\nCUWZa+ZlXdlF5d104JBGrnnj03xlP4y82e34cDhQW+thsX//1zdbdKsZj9OOocPMwfjLyxdGEj9f\nzvz+/p77weU0//3TL4wh8fk6cxgC/3T/wC4O/PX5E39+fvJgSDP+dH/fxwWF/zw9EYtyWlf+9P6R\nSQee68xfPj0RRHhZNh4OibvwQNbGX56eyFZ5vqwcx8C74w5T5dN2Zb4uXLcCKrydjkRTTnJl3jJb\nc2nVIUaan63ZYnUVQgQpwIYf//ANdhkMmjuJZFM3MlhFhoYULw4te4KuIRDVgxiDH1BbdUaCqm/z\n6y3ltlhv8oQ0+tIr9KBMq0KMlSg+Fphi6fQxfIGGMg6VtMsutW0BUyG20imOwi6uWBOWMrr7DWVM\nheOwOqBFlNKMIpFSAzFU4lAJaoTaxyN95DOqka2hk897a1Zq1VcouqqDySUIrXg378W2n/Whyw7k\nNjT1ycJtutAB74jQmhFukBxe8b0oQkRcs4rRxA29gv+8gvnyLItRaYwkYlfLzK1gfRE3qCMpG4EF\ntwwHFcc5SmQxxz2KRIJVUhLPvgu+EDPJKN50DEEZRKnSiBIwE2IIJA3UUhg1unoBX9hJNeKkjGEg\n1cpeE4tVtpy5CwOPKfHdfs+cK5ctozqj0hkdrZJrIbdXXcc3cX2zRVcFSq2sNbPkyjlvPO723KWB\n3RiYc+VpmflynhETBlX+5e1b5rJy3jI/Pr+w5o3JAo8Pd9TauLTKz6crl2VlnwYOx8Sb445Lybxs\nZz4+n/yXWhPvP+w5r5mnuvDDpxMTMETlf333xsMwa+Zv5xeSKKVk3g8HShez/7Q+gwmtNUZRQtdp\nnjAsGrVFqJVYA7E0ijVKLxTVuvZ1u8V9e9S2Dfqa7kvzwiQKJI8ip7qSwJNjG7JzOpkfrR1Ko8FI\nydBQvMgYXYbn/87dmNFQe3ZjoFZhjJXdkGmmXojV57VaG9KMcajswkozYauRmcgubbTmttRdzKjB\nyzp1w1ogxMb9OLOWiEnob6+UKsRQkYMxxIqsnZpWPe0iSqGpOsdBwTZcYte+zoGt3Sax4u3iDbzQ\nbbxyk4vh0isB9DcF9zZjlQahIxyd3xaIUtkZmAQWLV1ipuw0sQdmKiqJjeyFkK7X5avdetepC6W/\nL2hi6GjOpomlgGgkqefZaVOyNiQEQjOiKaNGqlZnJEef145xoNTqKhEVogiHceTKSrCIlQohUDsl\nbJcSGtRHKk0YU8AMUojsQuL9/sh527jmzNM8M6VI0n+oF/6ur6SBpRROeUUQ
5m3jD3d3DJq45syv\n5xcOacCCcZxG3k17tlr46eTKhkOKVAl893hHE+cwfLnOHOPILkXe7Xag8Ol65ZeXM6qBfQrcT3sk\nwOdl49PLhdYy9ykwhMQ4KJ/nhWvLXLbMnkhQ4fH4wDmvZKs8Xa+EEJFSOEwHcs5UaZzz6oupTdjd\nBoy1MQ/mQHPF1Q0VYhNqMkrrce4WkM27H8cUePE2M3doVaG1QHMmIxa9u7Pa+bcJCMaw64swfARR\nMgyjsR8zSQsxWGe3GpNmNDQOY/b8tCasNVGqsIuVu7hwbQNJu1dKKlGad63a2IdCa8ZsI3MO7IfN\nxaUa3AAAIABJREFUpVMNQhQGqeSWyBXXCQfjzbSQc3SLrTncvebgKgXN7qbqI41SHDqDNf9+zWe6\nlObD0tJ1Xx3r6BuxrsLob3hdouFluJu8ekH2MULFnONjoESsg8q1f7wBSbSrP/zmFd0GQ3K7BJVG\npWAIfqj3zzYatXrvHCVSNTJoY7LALN54BFEGjYxENu3CNoPQAsdpZGvVXczVGAhMKRBCxIoipREJ\nDCEyBg+IV3W53U4T2uOwnsrKTiPvDwemNLDllV/nK3fjyF0aebN3xGo141sou99s0S2tcRwHHnY7\nWmukoPzw/MJE4KfLmcM48G63Ywx3/PX5Cx8vZ54vK8WM393dc78b+OX8wt8uzyxrYbbM97s77qaJ\nrMYPlxfqWvh0uvD9/ZHDMCFa+fPTCTB+Pc/sk/IwPaBR+eH0wrxUfp0vjBp5P+5IUfmyZq658Gm+\noqY8DBOhRWxXuW4bV8tQYEwTNWfvfqywRWFDXMdrjbD5prpWY5vwjXxQyBMq5tE2qdGyIUOD7C/0\nmsGS0YKPGQyPjAGQHmomasTBoGtDxZRhV5j2lbtpc0WDGVEarSpDrNyPKylUkhTWOiDSeEgzQSuj\n+khgXzOXOrJVYYiN+3HlXJPPmAWGaGibCeaR3VP0KPpTHVlzYIyFMYKZx7en6NS01iHp1oTjbsGq\nslmg3PS7mBsNxooGw4qT2KyZF9zXrVPfVBW4uUvcYuGuvVu6RfMRdNdxCNXcxNDH5yQSSZRz2xBJ\nKMaEsidxobDiphQ1YS/OOb5apmIUGhFhRyQTWDtIswI7lBaVubgkMoTGjkASZTWfs2tUhuZStNYr\nnqgQY2SyiGJk8XFS0MhekufXiTM2FWFSJWlgJLLDxxnugIuMUXlMA3PxZbUnFXvS85KzNzavSoxv\n4/pmi66IdI1oY2uFrRZSCOyHgd/pHcWMXI3P52euy0og8v5wZLOCBuOHL888XReCerDlm3gA4Mt6\n5dM8Q4UxRv71wztKbZzLxk+nE2o+W/zT23tybWSr/Pz5RLVGscLvp3tCVAjCf728AMZSVt7FPVUa\nY4p8WjdqyWSrTJIgwsTAVQGpzkYtxiAeyHjTbNYApSvypQlhwe3ANEooLpGKgi4JFcjVsCG78UGg\nFT/uGg29a6DNi1KXMlmEMLjVeIiVQbu21RrWhHEoPIyFKXhXBsagzc/xItzFhRQaExvnNmCa+BDP\nRPrYAWMMmVOZPLVCjeOwksPGpfhaStQ4sLATZWmBMTTEGqc6seZACMZOO/ylBVKEdTFfvvXRwJgK\nIvjCzeIrnTGI0ULv9s2XjJi8OtvoHTsdOO6fZs6m7VPQm5Q3VIhB2HrhxBpRpNMR3FK9Ubq8TBhM\nX9Pmt45UG0n4+UOo+DZtNB9nZHxRPIowEpiiMrdG5uYag0kiIoGZ6vjKIiQNTDFwaYWCL1cPMTKF\nwIwnN5vBFAOHmLiWSsRHSX5fVgYCQXwMMUpkCgkzWLfCtIvcjSNvxx2nsnHZ1tf9SvzHeOHv+0qq\nVIwfnl9IIfDT6cwf7x64HyeucePfPv7EoIllywzjyO93R0SMv3658PmSX/3xv79/IAblaV746fnE\nFCO1Gh/uDhzSwNM689eXk+svzdinkce7
ked54/N8Zc4ZaY0hBr6bHrjUhZd15XTacMx34HeHe5bS\nOJeFL+tC2TKY8Ljfs66+YFssYxJYt8rQz/YpjiyyQTLW5kfE2IxQhNxZMFXMPfw1oMRXMMtGdvA3\nznUle3pAiw0ZnUxlpk6qSi6POhxWVI0YKiou1UoK09AYtbCPGyrGiMcBNRN2IfN2WBCUtQXUIAXj\njo1Z4KALURp73XguI0V2fBjPjFJ6plpgFaVJYM6JILBPK2rGl3VHs0AlMUjlMGy8lOabfyrXMrGU\n6CCflAlSWasDHZpB2dIrVCbESohGXY3awv/Izh3k9X/psJybnswXZ25j1t7p3mBe1RpVIJmPeJII\ngULBj+oFaDT2Tb0TprHRMA2kJl2RMLBSfEmq4nN8FJP6ysoNsY8r1GFFSZVcfcRWBaJ4sS3iIwvt\ngPMxCGahg9VdJjbGSGmNqEIzYxeEQ0hkM1SUYo6cfEgTaylcbSVuwqSRd4c9En2s97Jt7FPkOO24\nHxJdbPdNdLzfbNF1hafwx4d7SrctfryceLrMPK0LgcAxJv5wd8/Hy5lfljOny8a1Ft4OIw+HPWve\n+HKdybnyaT7zZtpxnPa8Pwgf55nP55lP1yvHFBnHxB92d/x4uvDTy4mfTxfGlHgII28f93w8n3nZ\nZn65XgG4i4Fd3HFqK+dt42nbaCUzBX+7AbI1ssCaMwNAaxwlOoowCFvtHVsRdg20eFdaA774ieog\nF1NC7QszjEyXd1lAWoLo2twWsyMdq3pIZBWaFNLYSEPpBCsXx++GTIyV9/srKTjjKoXKnANE4xhX\nRi0cwkoVZbCVKbqSYh8yh7Qwt8S57rGeMPEmzUgVRnH265u4cS4jWXa8j2eGoVBMWRnIRZhSZS6O\nexlTZaeZ3JSN6C49MY5pJknyBGAztAhL9u8hhkZIldKU0olptflG390kOMqw8pqE3DFA3vb1EWn4\njWfCcEyE4egCaa4ZF4ONClWxUBhRdggzSlZzswuVPRGILGRf+tFITRgkkntRlqikpn6aa8YsjajJ\nF2UkBlWWEGgd0juQuAuRrXYLjDWGEDjEkdY2t3KLItbYp8BS5HVOvFEJGvz7VmMnkU29+x5iIKaB\nMSTXUhfYrBLVTxVLrTwAqjcF0bdxfbtFt/+QhxBRqUQVcmscUuLIQGuNfRz4cp15vlzZipOUhqjc\n7yfO68p13XiaZ44p8X5/YDeMVBo/n84e2SPC94c9IMQh8J/PLyzLRinG7/Z3IEYcIn89ncglc1k3\nHscdYjAE5dOysNTKWgqH6NbLaUi8LBsmjdWMkI2dBiaJ1FAoubBGpeSKSCAslYMCFrAk1NoBMQZp\n8y6nNWPrCxACDDU5D9yg0CVbwRDzSHcBl1qlgoTqG+oKZYsuuQqVFAuHITNIobRIMZg0cz8U3o1X\nklaiVaZQOOcBCY27sDHEwl1YOLeRu7Cx10IDklT2urHXlU/1rocywn2au3nCAyYPujDXyovueEgz\n348vNIRLm6gmrgsu7dXUsI8bsgjXOrC1QLPALm1UCeTo3b0Ww7KbRkQbocvGSnbzRMt9lFB9CXmz\nRUtfpBk92008Pt1QUsUjzxHmlkmiVDN2PVVDRKhmZCl9LRaIBIRIbhXF6WaluvJBgyJUYnClzdoK\ncXSzxVoKowpzjwRqpgSMMSaqGZvAiuew7WIkhkDZsvfqJkwqTDFheaVVV9BMmjgMAVsbTYxaGoeQ\nGKJQWiPXzBQjkwWOceBaM6eyebL2OLIbEte88bKtzKXwYb/7JqJ64BsuukGVGJQ/P39h0MgvlzPH\nNPDd/oiZ8W+//Mp1OzNvmWrwx3dv2AXll/OZ//j1V4YQeZoX/nh/x37wdNO/Pj2TVDnPK28OO+7H\nic0qf3l+wq7GvBX208jD3cBmxsfzC+v1wlZ86PrHx3tyrpxL5ufrTLOKFePdbsIqFA08r1esCnMu\nPO4H
SoMYArM5v3YDYmv95RmxBK1U5uRuI2vGWIRknoZbcTzhbcN+0+neqGP+QhBYcWB1NEr0EYBL\nHfCxQ4nsjjNp2IihEgRyVjZxtX5SY5c2Ji0cdWYjkltgrxvvhhODNqI0IpVJsx/vCezCSsS41yvP\nbccUKr8LT1RTFJgkc5SVH/IbHGYo7ONKlMqpTM5Pl4bKwqUkH2eMM1jjpewpPa9Mq5GkQSiMqZJD\noyxKqcEtyVrZDda/HwfEI9Fv3sGXiAzA6pplytf5ZBV7dYuF6p2wItAaWSH1RVVqPtO9auZawaLz\nio/46GeRymIFi4FYYCqGamSRzNrhvbsQaFYpYsSmbAEmDQSNSFucFxKFRGTQwIqhVF+mmhPCBGGI\nkaiCNCWFRBBX2AyaaK2g6uqXKSZ2GlmiISi5GENI3A8Day1c60YsgRQDj+MIopxyRsStxfs0cOzj\nhW/l+maLLviP+e1uR2vGd/sDp23l4+nEy7qx5Y2RyB/vH5hr4bot/LKsnNaFwzBxPw683x94zhvX\n65mPp5lRlTEF/vX9G76sC1+2lR+fnhwSQuTD2x3nXHnJGz+/vODu0sDvDw9c6uqaxWVhzYVI5HGa\nWLK75q+WuSwLU0i0KByiY/g29SIdTcktcwjJfe9ZHECTApsZMTesCsECNfiMzoq7mcBBL3Sh/kKX\nn5pLihKOSVzE3FoaxcMIwZdsh4qOmRSzP6v1K8T8blw5pIWkxU/czdGQQYwpLNzFmaiNB71yaSOt\nKUEq7+LKagMNJVIYpfKoV57ankgjaONeF17ajoXEH4bPVALZlEEcgnNtI1uNBG0kKbwbshs1TKkE\ndjGzlcCikf1uAzMudWJrkXbTo9bqjAm6nCm7a86Xio2QvHO0HjekeCS7J+X+dqHmT2rDmRPNvMhZ\ng6E//xWYpZLFm4KBwGawmBGk0MyIBCiNKsJiSq2NIC4QqwbXnJnGwCTa567exfqqzbvlRGBtq8uM\nVdhJ5Dgk1uoIHFUlNjiMI7m5dLBJJYhymBJtaYSmSHDazhCSGyxESNHVGcG8Gw7R9RylVGSErTpx\nrAlsZnyXEmNwLfW3cn3TRdcMHsaJZsYQAn+7nIh9Jpdi4O1+hzR4vi48Xa/cT66CfHc8oAJf5oWP\npxODRPYpcbcbiFH4dV75fLlAFR7SSEzKcdrxy+XMy7KyrJUpDkRR9lPi6Xrl2jLXrbCXiIrwuEuc\nSnG4zebQ5wFlDIFWXTi/ZX/BUysa4KCji++DsqRMbQZbY2xKLUbooBTw5dmt84qizJ0QVjCSKRrd\n39/MWMRYOh0rAUMRiJUaXYaEGAmhbi7nylKZ7q8c0spdWjBRSo0MMWMGx7hyjCtJKgOFtUVMhINu\nxNA46EIl8MgzpzaR+2LnbdjQAlcbUIxBK2/kSjZPHVbJTGzMbeQTRz6kF0jCaorgM8nPumduN+QL\nPA5XSoNirt/1HIRM1sRh8MXgskbmNlKKRw6pGBo8KUgiyCKu/Lixe9UbX72pGorT2By763P1KJ7G\nkYBiDcG/nvZVWzLPETPpcHgakcFTPWg0a5TQnYHVkycKhsRGVaNUV69UPMsumXDFcZMq7kZLClk8\nemctjRjcvNAEQgjk6o9hPyak/1yX4mkaowamlLjkjc0cir6PXkDnVphrZUyBfQwMceBaCqdt5ThM\nPIwjQYUlF57mGVXhd3f3/7++9v9nXt900Z1i5L9enhGB52UhaeCfDgeiRv7zy2d+nS/Ma+W6bHx/\nPHAcBg5j5ofnF2o1nucLHw5HjkMiJuXPTy/UWvlyvnCcBt4cj4xT5M+fP/G8LpzmhWbCh2lHHCNf\n1tlHDKWyrRvvD0dCEDZr/DxfwZTTuvF2GGgNdrs7rttClcB1XdmniOXKPo5U8xf0pbpEypoxScSl\nCcYyKqV5PIIUOEhwG6n1UULPLTviNlABztao6hKkQxUk+PzOFGx06Z
SaocHQrAxTQ9NGCM1l+9XY\nSqKaYla5H1Ye4sJjuniRa8qoxiiZnWb2IRNoTFKYW0IE7sOCSmOPsSK8CxfGVli6GXWvGQvCL+2u\nL4AaUWeyBa7maoajVTZLfCkH7uLKo8zkpqw2uOkiNObs880mwt2QUTGWnMg1UCygUh1/OBSsE8Hq\nMjhY3dwkocF6h8xvUozpWC8Xo0oRT6rIvqv08UL3XgiMCKVUVry7NC2MNpAkslCZuzQsmBKqpytn\njCbFpVwtIkARY20et5TEGIIym8+NNwyVwKQDuWzuHAOEwCH4ciy71Q3MXCUhigRjUmWujRACKj7O\nOIbEpRSaVaoJA5H9kFi3ylkKBwmMquyHATBOa+ZhNzENA1OKDCH+Y5H2rVyiDhZpGGNILqxfN87z\nidO6gRgPu4mHacTM+LSuPF2vlJw57Cbux0cKjZe88euXC7lWdjHwz+/fsORC1cZ/+/VX1ryhTfnd\n4Z65NopWfjw9k3PGWuQuJjiMZBpP6+ozXoSpCW8H1zhuNE6Xc/9lN45pgAp308Q1L2gYWLbqL+AK\no0aKuij/XH1uJ6pMRUljwFrl3FzIXw0vttEoxcjRKM15va6PVcIQyDSuZMdhFSFlRcXnnRYb4847\nQ8VICus2EqcTxyEzaCX1qSu1pwljJKmMofA2XJlJ1OYv7r1uvgEXI2ActbE03+Y/BC+cI7AQ2OnK\nO4zFRpTGXgsxnvjP/N7ZBNLYSYFw8dGGGpNCaMalDgSFd+OZZnCurvdVjGypQ9dhoKJjo66KmbDm\n2IHeYKk59KYKZB8baHM9tslXbCLgrjWDEvrnAhTX7BqQW2E2HFtpGRi8OLsJnIr1MUXXTLcOzwzO\n3lhbYSS52WX0hVmtxmIVqUaKyqh+A16tIBHGXhQ3q6zVjR3TkNinRG6VKo53TBYZ1alzrUKRRiAw\nDgOlA4IaRhRhH6MXajGKZUQGdiFyzhnVxpI3ppg4DuM3o8+9Xd900S218W6/J6iylcK//fILzzkj\nVVhb5vfHO6JGPl0u/Pj8whQT81b5p/sjh3HgeV74+enyCjV5GLxAn+vGx+uVXHzTPenA4/2OOWde\n2sqX04VBlNiEh4PrGeeWebou7AZvP3eqiApFhGt2DWYTY8SIRreQ+gzPLFBzYbRKJBJ2gW0rLDWz\niqHW2JkwMrDsGmUrrCKMYogGRolI9O9hSQUrnhBwiAlTYRa46kaT5uD3piQCORhFK1UaQ6jkdXAc\nJI27/UwaN+7GlcF1aSStXPOAJuMxXBikkqQ6htaEgYaog8wV4yBwtluaLezFehalu7mOaoytsMnA\nMWzcsZGssjKwivEhnri0AQNGrYyy8mu9YzEvmFGMx3AlNw88q6ZUCnNNVIkc0oyIcN0SG5FWlFpD\nb1q9qIZkDodv0jXO7vO1iFuGe8IEPgKmqo9sVD3ORxto8NnrLUU+KqTi2Mu1Fk9cphEYiK1hql3D\nW2ghOFvBYMWIEbZWsaCE1ti0p0QEpdWCWnTgjjQ0DbStOHkOh+aEoJ1V0Wi1utMsptdMoYpreo9D\ndM6PwSVviCh3KaIENiu85I1DGNgNI00c4ficN3Yx8m7akWmU1vg8X4mq/P4f44Vv4xpj5G8vL2Rr\nLNuGAb+/u2eXIj+9RJ7WmWWpPM2zp0hMiff7kR9PZ355ufCpy8X2YeTNd3f88HTi53Xm5xd3nj3E\nkYfjxK/rysfrhed5IefKQxrZjc54+HS+UEPjtGy88UxrUlDmlmmt8bIs3A2RtiiH4KmGFeHSCqMq\nVhqjhO58iixm5FJZSvPjHwMaXHvZVLGq0Cr3g3MYUoiU0rhWlwjtdcAiDBLI0lik0iR3l1F0MwWZ\nOlbQQsDQpMQyEt3qRojOD9iHQqmRax7B4N105phW7sNCVM/gitJYbppZ2ZxCJo57jKLci8v7It4y\nvrfGSxNyd7RNCrvq0fABOITG0B
Zmiew0cwgb1iozI1dG7uPM1iKbBU9WkIJm41zHLpNTjmGlBWGV\nQMM742CNxSIxFlSEUmHZBraKs3WtIxUDIA2rPp2VWz4OPlZoXZYXNsOsozTBRw5dydYqrOJxR0Uq\n0YwoSrXCIqEXYJeVRcR11VYIpp6urEKpmTWM1FYZNBI0YqGwBiP1k8gokUxhtUoIrj/epcBSKs08\nuNKksdNAUcOqMzWsgsaIlApBmSSwWUbbQItC6i60eStIgEETg8KUEhhc+kkRM8YYSRo8/ugbub7p\nojuosjWfhd188xh8vjgf97xmdkPiQzgwpch1y3y8zpyWjRiE7w4HwCEs//3TZ87rhrTG99OOTfyF\n8l+XK+f1St4a99NIkcIwRr4sV87bRsmNg4w8DgMmDkm5lJ5PVhqHmGjmhWwuhRgTOW9df7mx08hW\nCtNu5DIvFBotN/bqs8b7KXFpRpFKzsY+RlIIDA2W4JE2S3CYSQyRECK1Nq6tslpDAtylidoqIsaq\nxiaKFYEQGVBMK4XCMM1MoRDECGos654ar4TYSJoxgSl4Xs6ljL7ITAv7sHDU4sSy5oue2nW4gzi1\nqierE0R57NZjcJ7tGy18scjaY8xHNY628dliX1rBgYVM/D/Ze7MmObIkS+9TvfeambvHAiQya5le\npmWGQuH//yMcinARkt0sGVZVduUCIBZfzO6iyge1QBYf5oEzPSNNQdoLBAHAI+ARpqZX9ZzvMEnn\nLm8xSrCF1TNFB3fcqF6oOEkg6eDWFtyFRgoqV6qsXhik6G73r9eSf+lktTuG7iMUQCWswezS4PA4\nMFQi7meE4iygPvHzV2WP2YnfMllk0nWcIsJgBOzcZY+Rj4dSdWNOQqZ8yeVRVxqd1WJmn0vosfvQ\nUCtI7DayR3BmH6DqTBJwG9uM5rEsnXOMHPqoIRtDWJJy0GDxVm2I5fi3krEStIl1dO5KYbbQJFcF\n2Tbul4WHOYrv18LSha+86FaLaJyi4Tb6x0+f+NPrC5jxXFd+e7rjNC9cbiv/9PSRQuLT5caHu5n7\naaGbB093DC63zqyJu/sTAny+nHndNrw71Qe/Ox3pw+hF+efLGe9G7c6Heca1YNJ52kJtubVG2efN\nLkIzxbNQB/hoJHvzzxeGD3pyzrctOkJXTtOM9YpNiefaKCnj3fkuZVQzKSmX3hALQ8idFlJaAOO6\nw1VQOJlCVjbimNpxhjhzUkQnmkelGd6Z5kFtBe8lUgqmimrjbrmypIojJDVetxM1vZLEKSlcbLNA\nJlFdMJzinVmjg48rNKB5txtkBFUheXR5QwbvZeDe96MzHKVxlcrNJ4wIYDzpyic7UT1ed5ZQU0w6\nglrmladxpI5M88Qh192BFtI4UWcbBdzpe4s6z866dtwDAfkWZOkpZtfIno/29kO3Txti28WXWe8g\nxsETsdicMliPvLIN3wttMHpdhIpRSLjbFwKDq8d76CPeMTeqCMOISEp1ZhXMBlUizHSYkURRTag5\nKsY2hGkKpcaswpwy6zC6xMllTpklF4YZ2zBUOkXgqEss4Ubn4o1jLhyKUs24jYaOzKEkvj2eONfK\n2hofr1eyCnfz/F/7dv9Xc33VRTeLsI3OuRp1NFpvfDgeOOTC/TZzXjf+/PzMx+sr2TLLXPj3HxY+\nbis/XK/89HxGBErO/MO3D7yslXOrfLxeMBvklHg3FTaHJ9t4rZVrrSyaOKSJxxmeakVl8PH1yjEX\nVJWH48xta1hWnrdbGBkGnKTQWycviUttzJJZ26AkZbTB3TLRhuHZqWmCHiqCgwjv5hlFuXjHaicV\nAOU+LzhKE+fcb4iE5fadzrThtOQ0u6HisTFPThKlpYZLYCUlQ+mZNOZ9625kbxynytoLm5Wdqbqi\napzSxiHVL86r177wQTtJLAA4u3lglkL3N0RizHFV5MvvRZSJzOoWXAMRpp0/8yrKXarceaW70EnM
\nqbNYSMoMyOIcU0W6c7Wyl3Znkh6Lv+QRvw6sYwrpmgZ0YcqDNnZWrwVxi7f4njdFWuC49q82rh0x\nC3W3ChMhkCPtKoa9vW2B1GDbWbhI1Gjb1QQOtLDCMDAWUqRFpJj9d3e25Fg3iiaKxMjiYjF2Kjlx\nUKWKsPaKpjkeVstCv92woXQihHQqCTNHCTuvIMyeaB7+5rcTYwKyJKQkEnBrlSlH+m/JE3OO//y5\nrhzLHFjLlMjp1/HCV3NNKXFrjdoH5k61weOy4OasdfDz9cohT9yVhYdlxl34vF744fyKOxyWQkZZ\nSuIv1xsv1xu3rXGaZ7pk7ubC07ZybiuXtXHIE4sWjnOiNuOldi595WhhlQw6lXMeneGDUQd5b4tO\noqy2Ms0T594wc87tyl0+sI3K492RaobkxMtauc8hZ/r24YFb73hJnFsNgbzCyTM9O0Nk706gaGD7\niobUyLXReuewW5BzVl7pVGl7qkQi74uigSFT5Zga4KTcaduBlm9UlKQN9cS35UbzTOsBjZlyB21M\nEkhH83BvVcuYBjowiX2RFC0y0d12alc0iosoa+z1gXiYxqxzL8QY6s7ZCgfpnFKlO6wsdFeKBH2r\nekjlSupkHby0jHtIxrIMhsefiURUULPEaPuSTPaER9vFunvmW4yef5EvOPwCiN8/PPaG3ohQSUvx\nd7T8MuMVDT1t1vh/QYRYujoDpxLW7mDYgklEDtn+yo6iO5NXCM5tF8CVkhNJEl2N2hpZlSzOXJS1\nDbYtGB5FY69wdWON2GSOuTClwnBhE8dx7tLEXBKX3th62OEf5sQkErZ2AfPKt6cTD8sSXIZfxwtf\nx9XdeJhmyiFhHrDBP74848P56Xbl3XzgUBKPc+EPT88BttmunFLhfo+T/uPrE3++vHDdGr073x4W\n5qXw+bry5+sLoxlr23gsRzQF6ennbWW4UevgocyYhW12s42myvO1xizTwlK59oEtSt2EWivqAbg+\nppkg6Ck3b3Fzr8a304x1Y75PfGwbuSR6XXlMU6gAlplX251uozGhFBVKyqytMlA2Gq7KIRdWdcwH\nN4yBRaChz4whoMZWKqM0xkggg4xiXVn37vYht328MHjZjqz6SiWRteMG7/OKATeLRaFIFGJHmSXY\nXH2f8zrsC6TAJcZeSpiI+HHzsOGe6Gx7rAxEV1t4SzSWcJvZYPNMEuNdujIcnv3IsGgtVYQ65EtX\nPOXBZZ1wlK1ncGcqzmoRaw+ySxBkL6i2W4QVdptuuG1/6XL3pE7w3ZIdr0LaO/bEvpzDA6/5Vz+/\nFZCdw0C2AOOkiE1/i00yGVzNSShDdzm2RKz9NjrOIGvGbcQYRaF2mKeEaKA/51K49ZBQDlWyO8ek\nXPdRQkrxMFhU2UaMElyFYyosU6EN49ICUjSXwrvDgde68bKuqELRzPvD4b/KPf6v8fqqi64QpP91\nDHrv3EbnfppZUuI0F26t89waP7++sG0DzcLfnR65WqUx+F8//sDWKjaE3z08cKsVc/jT9cwUv0sE\nAAAgAElEQVT1ttGHcyqJw+HENhqXPrisASPPmngshToqqPKprXFbduE+K7Uby3Hi83VlngrPbSO5\nwjCWMnEbDY4TtzbIGkL0u5yZpkSehNcUaECSMplwSBOyz4g/bTcsh131lKZIiFDlag2ZlFE7i2aa\nGUim+xrzXHznHRRWuWHT25IlEmFlLKztiGMsxwuH6cJmBTFH3emW6CZMqfGoF4YI5spzO/FBBoMI\nnRQfHCQQgW1vgN46tu6DWVKc3vfvYsTaxLKty1sWGiwer/nGZLjTjY9+2s0aMXvNuwitoZgoyYOQ\n5Q6nVDmq8FxnKpm6U77NPKKOiIeOB7YtkjeMSH4UoviOfcOGx6Iw8Usk71BcAp2ItqiuWfARxQ0I\nqE2L2PW3Wp5+ecWdoe7oHheUEJoEn3e1hiBojg5+jM5aEtkD
DlRSorp9+b4mAgDVZIsHmmWSxFy3\n75bf1js5aSRAjA1NEsGh7sxaGDRyihNb7bAUD86JBOskin3jNM04Rtb8q073a7qKKlvrvGwrIs55\nW/n2eIeKcO2dH68X1IWswv39kTlnrm3jL58utBGWoqKJx9OBW6289o2fr9fIelLhN8eFdTgXW3ka\nN3QooplFFSRx8cpVOjreyE9BF7uOTpkKL6NjSbj2lWOaaN04HAo3j67h03bjXme8d97PExtAgicb\nUTS785ATolFgnnsLiElOFAmClriwyaB7uNXaiHmfmYI4ncGSEutwDjg3lNUaXWJ2+yWE3WZML+S8\n4TrADe8zazvyeX3ExXi3nHksF64+kTyO4Y0UHF0xTtqIHE3hOmaaBLg7BPbGLEKRfZa4u7gMC9fU\nX40EBcIaHETZWFy5sJE5SBDENleGJ5Ydjt4oeyetJAl99WaRoWYI4sbwzLRDzHEYXmhtL1e+2z3e\nGI67ZveXQ3OIWr8oGkYA4GXHVAL7LCG68GEe7EdTcv5Fq7y/UlgjckiBHRhmuArDGlknmofGdgxn\ncqdrFN9EzHuzyp6EIfis2IgFnfcR7AUtJARzY2UgIgHOSYnaB891Q3VXPmhmHY1r2zB1Zs/kktl6\nZ2sNE/jmcESQ6ISbMazyu/s7Hub51/HC13QNdw6lcJrDnjilzJ9eXrm1ysfLlZITS8789nTkz+cX\nPl9Wfri8kkR5XI68P8z8eLnyw+2Vp3Vj7ZGAuswzY3R+aBfq1jn3ylEnckkcDZ5HBYliP+3fgkMS\nmjtXq1y9k0eH7kyqoSYQpyfnRt+jWyKORcXJS+YineHOdXOOKKTCXYpUAFGhGswpYz10wtUqPcOl\nbRxyZtRGcSULsYSh4+oIFnE5qrTRI+FXlMRMrzHf3HLlqo0JJ6VB9nBLrXVGjp+4ny5Iir50azOX\ndOQnewSMh3LjId+4WsjBIHLUVkIKddQoEl2c1YWDS/SmZmx0ZononkDXypcOeBbbORNvl1M9Y2/v\n2V4Ab54oOki2Usnc2G29ONUy3QIVo+rMafA6osh2eyu0AnkPl9P9fD7kS16aJMd3l9o+4AUcSW8d\n8H6pwOhR0HfOLRLBmBu6S8j0C0gHQu2g+ssIuQ9HMwxvqGbYIenNhOT9i1tsuNFsBPMhwWRExHqP\n93rYoGhE6igloPc+oEwkF/LuYFtb/Iy4OEUzU05sNrj64A5lzoW5JG698rxu3C8TD9NEycpqxtNt\now3jUCbSV9TtftVFVyQ6QN2fwK1HCN/dcuBQMnUYt9b5x6dPfFpXzAffzkeaGJMo//H5iee6Us14\nPCzc24yq8LmuvG432nASg7tcSCpUq3zsNeaOFnEnwyIG5tO4hSoAoyB0Nx7zzGVUpMCzVbIo6zZ4\nWA70NFhK5twHWcGbU0Q5kpmPmbV1WoLaHd2M2QtDjTknntqKLsJWjeOUqM15yDMXC4NEtbCLbr2T\nROm0gMKI4bJDWoiY75tEikUR8Hbith7CClxWtGxsKI+5YxbH96ftHpYfechXUKe7cu0zF5n52Asu\nxqyN+7Sx+iB7ALLNYXNjc+MQRElublw8VBURX+lsxPz1S4CD86VbXdg4y/wl1RiBujMgosC8WWyV\ngXLQShMlwn3CqdVH2LK7685bGEjbO9q+f2G6T4572uPZ42Eie5FVfim4jkSe2pu74q2P1ehtXcbu\nRnDcchS5kejEA9KIB7C+IR/ednc0umTcDIqEtpg4MSBBNtOkeO9RfIeQpsxUEt2cNgaWMpMIpWRa\nc7ZeSVrICU6lxN8zpw+jJJjSEpD7/SFoo3GYMkvK5FyQ/RR3n8MSjAhJE0n/SlL3FVxfddF9I9j/\n6fzMpImfblceDwuzZm498398/EgbjT6Mx3nhOGfcnD+8fOZlvTGGsQ7j/XIEH2zS+cvlDC6hhCgz\nJgUfjU/9xsDZrDN5AKNr
IMIxNUb3WFK5Yhrd5EvasOZ0h5MnugqnMnPrG2VJ/GwrE8p263xYZtY2\nkFPmY18pWbHNOJWIXxF1KkaXQWVQekZHbL1FjRc3RjZ6M+YysfXBISVuPiia6d45amEdnZGFZhsj\nWRypkzPZjNg5OKs6wipsylrvebk94gwO08YhB6D8N+WZbhlH+bG+4+/LE3fphgtcrHAeC49Seelp\nP1pn7rSz+Ub2ibR3pI3Y1i/qu0rL2ByulnZpmVCJbtj4Rfc7TOgoE6Eh3jzTTMGV5m9JvVFA3oYE\nzTI5tVAvOKwtBdbR9td9S1hMsUNzZU9dCM4CefwVZze+tl9687+Ww/31ICEm1mjk+cUsx1EJspyN\nhKXoZodHsrDtRW9K4BJgGpWY8Q+B0Y08pS9z4XDSKTTDUpDIppRI7kEy89BmzymTRekeyptJE4dS\n4jXduVmPRW3Oke6midtoTBqY0mqDW2tso3Mohd/f3bGUQu2BrUxfiWzsqy667kHB/93pRLfYev94\nOfP99srH2xV3Q5Ly7z584KfrmZdt5YfrhXV0Dmni4f7AtW583i58qhvXbSVLYkmZU1Ze2ooofO5X\nlIQCjyUyrVyMl9G+dCuTRiKaJrgR+lobQkmyH/+MnBI1G62FxKn0hKTEPMGTN3wWXvqFg2fqZtyn\nwjqcaTbOrXMqM9feeZxnWjdkUi6jMqPcLBZU5rD1wcA4S0iuXIPjOtqg7IU4kVDLZJxNBiuG6yCV\nTnFBxBn1xJhuIIakQfXMbMJLu+en9T0uxjFvHPLGk838m7zRPDSbP7YTv02dRSJC5snTjoeEH7vh\nRPjiQQaDTvKQr7kLdVcCJAHHMY/Msyc77B1qlDJDiCxdAiXjiUhIA3FnswkzoVvAxCMNIhi4w0Ld\nEGyi0KvGH0igHH3verPvDAYNPRjypdyyY2veUoJ9L7aJwDmmDHTFdlSkKDHK8HDCDSI2ScQhKRa5\nkYzdaiwOZp0umUwOZUkK1cewjlnc/illfIQCwwiYTtFgSSjhFqvDuFsyPkLNPJcUS2CZIspdY/dh\nw6nuLClSVkrJbFvj0/XGcS58OJ7IKqQUiqE2BmkfZXwt11dddAEQYUoTOgZFG8NgypnvlgOrBfjl\n+8sz35/PXMfGUTOSMictfK43Pm83nvqNSZXTceYohZtXPteVi2zoCCmWuKOeOOcVs5ib6Q6ZLkOo\nGVQHFyLifDM4SSxVNDlbMmQMRhfuNHgJ86xs9DAwNGcaGiqCJGgxXml0MV43Y8oTa6scdOLcGpJh\nHT0gPm1wKJmtGZqElRbdfmscVNnMQW3P1NqtqBKPkdV7WE114P1Aa4aKMFJEClXgbr5F5LnAy+2R\nb8sVF0eTcbOZyRvnceB/Oj8iGCV1llz5eST+rrTd+QafbOLeBrMEQOfFhe6JBefTvgRbrUTEj7N3\nstEBVwuJl+039w6f5GXMVItuMBF8BSTgN4XOSvACwhJg1D7j5vQR27SchNHh/201e/vZctR3mS58\nEd3KF/bjGzZH915aKNp3kPtuWpP4kyRg0snquAf5K8YSFrB6QHLwcfOuyrEUCypXZ/herD1KvCbd\nI4Fsnz0rWZWUhDaMOgwVZdLCUia6rLQxQGDRspPKCoPdpi3KnWbWPJgkR8qIwymXLw+YkjPLlPlw\nOPK03qijs5TC/TT9ao74Wi4Roajyp5dnVISP1wtzUr6ZD7g7//jpZy6t8rptZFX+5vCIqvDD7ZX/\n6/KJdVTWHlrbnIRqzo/rC11CfrZIRqZYgpxtA1m5tlAMKBqjP3WadoYGZEQtbuZZlDENJO16zB4T\nwYMIjUFa4MlvUQqacNSCmZNn5WKVnIVbN46SSZK+ZMA5N1YMbXGjjB7JvddheBrRYalw641TXli9\nUlJYOedUuI6GJGjurFL3rs05ekHdME9UHygZLStrP7K+HGOOmRqeBjdLfHt4DQmZJ/5y/Z
Zv784Y\nkJLzYkdkODfJ/IfbNyR30MGsnZ/H4O/KHvuO8DISVZyTDpBB3fGQmyvNYha7WmLWKBjNlKwxt61W\nvsx3mye6R6d87Zlqe2yNDtYRKR1bz4iOmAdje2TQmw4X6Cm61gLaBljCbV+2CfziS9uLpBDpEtLR\nQJbjss9yVQjxmpJL3+e1eR84gCZDxDAxkkbH7m7oLgV0A/MRhhs1jIKpURkMN8QziO9y4ciYiwWe\noKqkorjFabCNxiIBLh9jYMmoxJJ33uOYhjsXHZzKxIfDkUttrH2LJXIufHf3yNiXdN2Mx2XhcQ63\n2td2fdVFF6ILebfMDHPkeOLpeuX7lxd+2q7U2mgY//DNe55vG+e28f3rEy+tIgIfjidsOJde+dQv\nPK9nkEQSeJinvcgNPvcLugPDj5routO1CP5s7456dFAZsLmTRGjegxE4NJZlqliOrbO64F0xd46q\ndG2Q4LNHYu/a4EFm+jAOS+bVN8qiXJpxlwqtQ5Gg/OcEtzEoZMwdt3Bc3diQBEOiC9rcWHJiJRgP\nZs6SE5t1OhISp9KQHSKb2gHPxrBEk0Zy5ZgvPLcHXto9ag4lQDdPduRvl0+RvKuZP90+8HDc4nXF\neB533PuVg3T+59sJlWDLLqmxDeeonSAUChfLXKxw1IaKMZg4W+J1HFitYCMKcdKAu/QRBa95CpeZ\nhyZ1NcVMGR7QmrF33El9X8BCa7t+zAXSQIcEmdwU07eBQRTS8CuEa4t9/Ps2bCjZ6H0fIUik/05J\nYtCg8TDLKXgJ2RIugXHc6gBNmPVditZJmvfPqPSx4zglpGCTCivKwBgW47UItTQaYTcXFY6aqITD\nre/g3yUX1qSohBJm9QbJSCmYCnMu1NZ4va0clom53DFlYU6FQwmL+61VTtPEkvNXpVj46+vXoguc\nyrwP8pUfLxcQeF9mzh7W10/XK3+8PPG53pgkIRl+W46so/PZLnweV8yNMhcOkulubL5yky06nRzL\niCwJyx31zkgjxPwOs84xoSxKlxWxsOarz9hwluw0WckaMdlCoXviUKAZtNSp5qQsjKswqZBEqbLR\nM1xHRRFaV47MbN2Zi/BiGzknzgMOOcXn0hwutZy4DmNxoY+OCsF1SOyYReEwCr07RRMrjWIFr0IR\npTO4ORQz8rQyuSLqrJcH+nGwWsZzQ6rzbn7lc3vgqd6FOSH2WfylPfDvDz/SLChff6nvmQmma2bw\n2Y/csbJI43/fHiNSBmFJneqJRQI9+bYoe+4zb/e5i7Ba4twn1l4CIjOi41WNRZu6sFmi79D1jHHr\n866LjaGpSEimohUNu4L7HuvjQBq7nCDGEya6kx4jHsd3J2TvMaOddOCjo0lQGQwpqCvIYEjYpkeP\nxOBhQkpgMphSEMJSSZEQnIPBkEvMqzuhHTeE0QdpKmgSah9hnCPUL6pgPShz2ZVpmlg00YkFbNpP\nh0LMgnUfS5xS5jTPvOxAHnf4cDrw4Xjk4/XK2hunaeK3d/dMX2F3+9fXV190l5z5w9NnkiivNSL8\nvj2dOKTMH14+8/P1zKftRnfn96cHksBrX/n++sxl3Lha5ZhCZ5ik8NReGNJYfaOY4DkQfJ0BulKt\nwpsjyqM4pukWWWOAWQrf/UgkqTADNNLIVKsoE2BMWalyIWVhc8jMWBMOs9DMyXPjMjxu2BZKDROh\nsVLVWR1EFDNjthyb+uxcrZHmMIccSjAiSlHWHvPf1SOw89bDfjo8kgYmzWGz7YUqFhv+XKntyBiF\nooEe7Ja4uXA6nBmWMIGn8zccjrBZgjywDh+mV8524D+8/ttYLIlgIvyxPvI/HH+KZAt3Pta7HVAT\n885nO8SCB+GP24KRqEOZsu3IyESRQXelWeK1RwabO5hrfLynoIlJLNR870kji83pXXc5miLD8bHP\nI0eoYTyB+m5nGPLFEmwISRyx3d+rho0giaUUul+V4G
GEDiwhlkipfRFGBHjXSTlcZL0nssAYu2lD\ngsVmnXC3OWgJVq7sGMo8KZJ2oUXS0PJ257gkphw4zEUy3UcU7pxQUyZNVHOaO7M492Xm3XGhjs5L\na3hK3B8WvjkcGR6zfQEel4WHeWbO+atamP2nrq++6ApwLAV341QKZsbztvJ/Xs/ceuPcGt8dj7z3\nhW0M/nx94tN2oRJqgDuZ6N24ycpT/0RVR7xzTBNuHRVnswsljXBIeaGaMCEg4Wc3GmMEXCV7CN3n\ntOJp4CYxX0zOLELKG10c4QpW2IYjfcLFyNmodCjCuUchdoelaKAAp8FLh5Nmts04pMRwp6XBRiW7\n4h624exCH44UoWOUXb85WaYloxDIx6Wk2EBbYpNYng33EN2PiaKJ3hZeZYSwf7rxst3TRqGIsJlw\ns8xDz3xzfKVZopH44fIBXZSbZVIeNE98M71SmfgfX/92h9QoQxRl8N8ff/widX3pB8594ZArKs6V\nmToGdWSuNtFRtpHJKbi/vQtZnWoJ88S1TXSLxRx7MXYjlmcOWZ1hb6oI381m0cnuu0a+IMfU4zjj\nGd0TkT1Fdxzy24T4QL3vE19Hs+7TXCPnDmKoJvp4A/0kVPuuLU6IWKQRI4hXNBVUPRImWrxmjEMG\nOcNoIf21FIaUaQoIfsNpvcXHSsF3U0zRxGqNohnRTkmFpaS9sDpLmYNalhNTSjzMM+7Ga90AeLcs\nzPmrLzVfrq/+nTCcx3lGJCJ7Xm4rW2scUmbtje8OR9ax8f31iR+3V5IIpo3fl3u6DM7blbO8Yl7R\nHHpaUWVYw/KGef0lgXWX7RQ3km6IBrClWaTbZh2orhF5zaC2CScxCwytLMUZ2vFRaEPIKQT9JVe6\nDLDolM0TOvKeXQabDFyEaw/7b7XOVBIblako526cfGKzMB5sWLiG+hYdjzll6B7VAuz61dkTtOi9\nmxjSFMuwJKGZIK5sdsZ0hDhKIW8n1DdetzsqHdXBPF/51B7YzoVJnZsplz5zyCu/Oz1RhwITP1ze\n06YzmxVS6myj8M38yqLO/3b5HQpROFXZhvI43VCMGxPbKPzp9sicejj0PLH1mNOufaKbBuibkGC5\nhwJg62EFbhZLtjF2JoIJ+0prVw/skjEyjBHQhL0Ca3KSjKB+7a+bvqRJDHLqjF2zLNoZwyklHt42\nMmkWrBuisCSneo9kigSyxzqnEaBz0X0Btu/ukkbXbbvmzUzCcXYMLTEYRRKmRpY9HQTHUgSbZg2N\nc0kpuBcl8+5w5G7KfFw3bn1wLInf3z1wmidet41tDKaU+JuHdxxK+W99S/+rv776ojtp4vvbS+SD\n9cE2Bu+PCw/lwJ/Pz3x/fuWn7cq5Vx6nYM+eJPPzduZir9zsRhElZ+fej1zshqaK64VM3AhKwXZ2\nV04rTmzRlRBfFm2Qjaxtd1Ip6zYzZwt+LRf6nhXWtwUBiijOhbsDcQwcMx1lRmluTMfGtm/FW1Wy\nZGQIkkN32SW276sNMhLJD5pYpe1qBWcpidpDGN8wsirXPjimTN3h181b8FAtUXLCe0QIFYVNKqlN\nSOkcVKgVxshsekOmhlqAwcd6R7eNH9sjIwU15ni48Knd0V4zi3TOlnkdC8Ph7+8+Bss2Kz+vD6Rd\n9pSzce2Fxzn4v3+4fkCAagGWGR4hmkkGlcRw4S/Xhzj9S8xt2wjkeBsp5p+mu6QsWAXmgo2QlAEI\ntqsXJH61t1y00OCKhJ7XPUwumVAMGMKU38Ir431QH1+0xe6GqiKu4C0e2Cn4DiahXoiVXIx1dlcy\njn5xHCfAC5A1RssmDIkH6qRKyYXWB6N33Jz740zZlcP3aeLqnaTRvZasfHc6cWuVc10peuTD4cih\nZHJSiure4Tr308Rxmr762e1/6vrqi24IsxXUKEkoOVF75395/QuXVnluK3dT4Tg9ggs/bWd+3q6s\ntpKS85AODDeSNK
72M54rMDikPf6HicGNOVeEDSfRXPaY84aIM+mNOgqNhHoQqx6WG3h0SA0lWWJt\nhdPcgU7STh27tXOb90UdeLlwkGAtqBfGiAWeY8zHRpWBWmbrwixKa9F1bQzKLltr5vjImMamfoyB\nTkG/WlJ0xAGeGWRV6rCI7q4dVaPtmWZFMnOaYCM0nmpU7Yx6YNbKos6tK9taSFpJ07rLsWC93pHL\nxnN/R8ux7Hk4nXkdB/7j+VsWbVxH4XUszNL5h4ePIRVLmed65HWbyTpIydlG4lgaKvDzesIdNtPY\n+uO4JLIOes+A8mmbqF1BItomeMExZghZF5G2S7wnwV60v+IrDJJGBlpyw2zXxoqTJMYDeZeemUHJ\ngtBxTZQ0MEsh3cJJOijJGCO8akMl8tEY5KjeXxgLqpHGYTXMDVmgD5DdNKEZjpK4yGCYUDxkXyml\n3b4rLCnH918iuWTKiZyUKSW+mQ98UqFbuOGWKfG3D4+svXGpja0P3h+O3E3Tf+vb+P9X11dfdLtZ\nDPlTpo/BVjs/3y4YztYHj8sBxPlxe+Kfr58xd1ZufDMfSTqzjc7n8YT5jTQ1ZhyRst8wK8gzIkG7\n+pKB5ZmsV6ZcydJYbULUKENRvaAIRTfO/Uh3UAuo9PvpGvM7lG13Sa0jcTdHh5wZYX/titkEsgcY\nzldUhNYBmxhuTEwM6UyLUIdTJLNVo5DpBrM4q23MOX6vPSEeRUMJwX3kdCkTxrYv7aoMJg2r6ImZ\nta44ke+lLsyWmTTB+cTFB0MGI1detyN3cmNW49ITt3VCDs683GKmmpTX1zt62fjUHhnZcXMeT1fM\nhD9e3nHQzqXPvNpE78Lf3T8HJ0GVWy/8fDmR1JAstAE5OVmdS0tgmXWk4OKa7PpZoY1ws9WmjLEv\nQGPau89zE0gPT5k6tDj2RypGOMtUJLTG7M61PVmXffeVGPF/fLMyo0ypoe6IZ0QCnGMqzDK4uZCS\nI6ZIifm5tng91cS8n+i77GOtKTFaZAIuy0Lftn1MBiUr91Nh3QHpQrggEeN+nvjmeGDSzPO28pf1\nwsO88N3xkZwyY4R9d0qZh/uFu+nrAtf8515ffdHNqtxq42msDHdeW+X9ceFvyz1Py40/X175YX3h\np9sLk2aGDL6TE5e2cfEnbvaKKuTcOHBi0HFWTM7MuoGEUcBFgsavZ2DQPTrhnJw0Gk2NZXpFJeiv\nr+PAgYDJKLe98xl8rndv0npy6nyTN9w7ncI6MoVMU7ibRjAQHK4jYZboVuIbLgMpN4QYc4hMDDMm\n1X0hB9WdjLJ2o1gQ0BaNlOJjjsyso8coYxYisTYHQDu5RDacDFQSLfXdYuxMo1B9w2kMoqOaemGS\nQns+cbHBUMOnjXM9QHYWGbz2xGU9MEw5Hq64JUYSXq4HMvDpfM+OG+bhtOIOf7nec9DGNjJXm7la\n4t207cUwMYby88sc27cUumNcyMnwIYwRRXfsHay/LcxGzHQFwXdEpYggth/5ndAgS0i61COf7E0H\nEZCaFOSvXSvcEMShm1J0oD52IFOMNQZCEScni1GBCuaODYs4n/0kr0kZfTdeGGiCY8mMfem31jC7\n3C0zNuCQCscyMerGXDLJE3NOfDjeUXtwJkrOvNcDh6kwp8TdNHM/z3y83TB3HuaZYylflavsv+T6\n6otu0hDZh7bSwnOO83+/PvFpu3LeNlSEv7l7ZGLiczvz4/aZVz+TdGPJofFcNGasSV9J3EgaYYgq\nC8MbS1rJXMnaGa40n3CHwuCb6ZmbhfZWdn//A5WuW1he1RmmfG5H7qbIFktcd2uq82m9jwAChJIq\nxxw22XUUtpEpWiKAcjE6YV5Ym6Ij78fM/XiaI/bbiWXMGJE4oDlufrPBJBpzXlNue1e7mnBKwU+d\nJD5XIUDZBxRnInWle+eGhVNrNmYPXm6yhd5XOp0h0T3mOqNp8PrxyEcc1YFNlZe6
IHmwpE5ryut6\noKTO/enKsFAznG8zvSmwMDToYKeloTjP68KcGrUnNk9UT0xqiMds1hFeryVmuqooUYBlFyIM2wu2\nxyAkOt6EDMMHwTwgospFIFlDUmKMcNsJ+9hAAuNoZEw6ygBRltSoveBawA1JoefFHDNoXSHFA1gl\nRUCnOrbP65HgUhyKBI7X4UDhpnFKKSVkgnd5QQr0YWw2mFIJXkLKPE4zvzmd+Hi9UvvgmMMI9Df3\nD7zWyq03lpL57nTkcV5+Lbb/H6+vvugOM+6niemQg+IE/NPnj5zbyrWFhdGTcRmdP1y/p/tg5cJ9\nmrlfJtyFp/ZE942kFxatuEwkVxa9IfqZSVYm6WSxwEh6ZiaixyfZGEQEyjpm3Iyizil/5lO/iw5l\n13e+y9egR0l0PtUyz+3AqdTdwX+lETlnT+sRdsxOzhtLdtw3fETMTpaCqbOkSPMtHvKzbKGMKA5D\njSk52wi58Biyb93Bk8YM3DRYDN1RlGqDnKCJ8rg/GqbhbCOsrBELNEE9IGZsHrActYlcNkqKoz9t\nZvjK8FBeDE+kNgGNp+u3kX+mhsyV3pVSB0tu1C68bBGdfn9cA6PpwlbT7h4TPnPEhjFNTlJja4ks\nHp2tayzp2M0NO4prbIqNty51ELBGBY80XI8KSsJIu0lNJFpN8yCNmSVKMkhRTPsIJgdIsB9SgMi7\nwKIDM0VcKDKwFHCblJw2LPCZnjDdF2sqWHdSCeRjyYXiM7V1Lq1irnw4HTjpTO3OfdszcFcAACAA\nSURBVCnchpGzsJRCUuX3dw8MM2698bTeeJgn7ueFrIlDLkw5c0c0Ku+WhekrC5T8l7q++qKbVGlm\nnLcrzYyn9ca744EPemTtnR/OZ35sz/zz5RNdOq7w7XRHt0a1z6z+RE4h9TqkI4c9/mbzz0x6YdIb\nkzrNMyrOUS8cuSIyyG7M2mmWuUjiPr9QtFHoXHzmXja2EUmzKQ0W2fihPn7hBSQx3uUrniK1taPc\nuvDaZ+6nmLeRN6onxkhc2rzHcQmlbJHX5TH/HWOQZcJlkAu7AytRR1h+W4vl2WrOKQeLYbGYB59E\n2bwz5+DhLjIxmoWG1garGkkz6nCvwnWAuO3FQzimRJKC1IU+BjIGW3LcF9KyUoA6DLuVCGnIoYEd\nJLRmFOfp9kD3BMnQqeNm3FpmyQ1DuKwzoyemuYeeVp0xnO2WQBJXl920IKEM6B5H+B4OM/O/Ging\nYZl7ewCZhulBPAI9zVGcEQ5dMh0hI6UjFiME9cIkLZZnMuhk2ojxQxZQj3/jonRROrFky8nIgVQA\nC/PBJMbNYyx0yhNni7GSijBl4X6eqRWKF+7zgZs0rj5Qg3fLiXdlISXlYdnpc8AhTRzKxN8/vsdx\nPl1v3FrjuGec/Wpy+M+/fi26+8xsHQHsduJN+VQvfFyvPNUbm3V+e3fHQb/h5p2XeuHFnnF5JSdQ\nKRyz7jfKZ4q+clQnS2WWJaJoqKi8ctSVTKf6TCWBKN/kMwuRr5bcmFNHLZJwv8lndGeqPvcT93mj\nWgnlgzrZB9/XUFZ0ElmN9+UCPqjMbEOhH7iZcyoD8xj4rSS8JW4tk3XnAKQWigVRep/CIWUa1LId\nsD0R+tbZCXsuxm0M7kqkLMwyUdtgksTqnfu00GWXmdVIHWZ0ukAuinjigHLpLdxewyk5Q+okZrhM\n3KwGQlIJMEyuZHGaG30rAYYpI2a1DtSwqd5ejry+QcVLEHJ7F+ZsDBdaz7QmUHYGgobUy3s4VJpn\nRGLJhcoXFpiZBhxc9Ev+m1icNXBHkjEsTBTZBpajcIdeNqMqKB08Ia543p1fClmNPsAl8uyw0OIW\ntaCYvbkfkwV71313Q0bHWrIyhVCEIZGn92G556yVRTMisTB7P83oJHy7HPnN4Z6fbhd+ulz4cDjy\n3334jlMptDFoY3CYCt+dTnx7PJJUf+1u/wuv
r77oDnemnPn9NGFE4uo/ffrIp3blaVvJKXE3hQXy\nn7ePnNuFdaycSuGuPHLUOy7jlcYrnU+c0iugJFUO4mQuHPTMpJUsRqbhLBSM93KjyErCmL1ytkPI\nxoDv0jNJYmmmboH0SytJnKJnwGiW+dxP3OdGH4kjDduZpj/VBxzBPFNSp9BBbqwWKMNSMzU5BzG6\nd7Ik6t6B1RZyooEwlcD8zTljTSj7Q4qcyckoUsDjYypCs45KouM8loKgTAhbjcyubpV304KIs2jh\nta5Ylx3yLRxKRlWZh1LdcHXSCJF+YoAUxi1YrgnFdgiMAFOKaCFvU6Tb5ugIDYNWd9vsFEoLYGRF\naSH/2htXN6VVIhIp2RcuAd3wDoNExjAHcJLZrtkV8LHHnLPP8x3/K2a5jcSUW0jJslF7QtwQDx1w\nSi0K7XiLDAo97pTh1gRJgUvsQ5jgCzpyyodIlfDB1kIjfX+c6bWwaOGYZ9w0lrY58V4W/s3xHZdW\nI7VBhftp5rd7PuDdVPjN6Y5PtytOAM3vD/OvyoR/oeurL7pvwvWXutFG5/Nt5TBP/M3ynr+/h/O2\n8Xnc+NP1R652pdrg/XJgzsKcOk/1zwxWVG+8k4X79I4kBzZ7ouiNSV45yEZYEJRZGioXssQYYdrl\nRqsfmPVKliiQDeW9nnkeRwZKkcpvyzPf7+MFgFkHD2mlSCdnZ7hwtZlXnznlTjdlkkGXYBx8rifc\nleGZkhvFBGhcbcK7MDEF31eF4Y0kE5vDTMIbFI+YnmXKVOs8TDN9k8jL6o1TSTiJY87QFSGWb3Vn\nDagI30ynsM2acW0bM5mWOx/KEbFYCt1a42pCdoeUSHNYYGfPtDZoorhPSBImnOadcQvlwiJw0zem\n1wg1iA8YKey4UwSADgZptFh6DYkHjhkjpV1Bwm5wCFQipjsnYYDFgi2eFBFVroSEDUkoffdHKCKD\ntM98s0acRCfhIyR+QyVIaW3CPKO7nbvkQbeQqHXPCImikcmsqmSU5oVSYJZMT4lEZ0qRoPyNvscW\n4TYq29gQz7zLBxadeDwe+O3dHZ8uV9bRUITfn+75t+/ec+2NOjprj2TsgI7/Wmz/Ja9fi+5+VHrZ\nVmbNVNvTbzHW3vjYrjy3K+/miX+3/MP/09659FaWXuf5+a77dm4ki3Vpdasly50YsQ1nECCD/D7/\ngCC/xIMgoyCTADE8CAQEsQQnMSz1raq6ung5l733d1sZfIetUltK7ESi2qr9TEieQxLk4eF71l7f\nWu/LnBOnEvg6vuKQDqBGANa2Y7CGVu2R8jO8CXg10ZiWThnAUeSezsx4mQBFUBbB0asZw4Si1LEy\nFZmLQZSjt/tqyyeZUTxbO7JPXR21UonBjHxeLinn6dHGFIokGr8nUx2yDqXjUDy9zjVxogSC1JPz\nY2jQaJLUaPQEKGMgWyxgczVFL0pIOeOsJue6IhqC4BWkkulbC8Wwaz1hri5mpzCzci1ZhN6cexil\n2g1OBayq4YUXzpOyYtSRaU60xoOJ+HNgYZbCsQROKWCNxWhD62r0kSoFjSFQzbS1KFaqIEYIs6WI\nojEtU6kHb+SANpmaiKBQQVGsPeeqCTYnrBFyquJYEyWqwY3TtS1hqvRhynlxpPoxni0tz4JNjZPX\n559J6YIxhSAWJWBUQZLGuJpZnLE4ndClEHEoDVY0s6hz5Hr1enBa4Z1C2QavOiiRMcwUOrZti9dr\nYpK6ip0tygsr35M1fLy+xCrDXRi5myb6xvNhs/3GerGxlsYappjYNdUvYaluf/O896IrUs1Avrfe\nUkQYvOOr05HX4z3HHBicIyjHs34gcCTLxJvpLYnMpW+5av6E+3hAZOSUf05vbikqoDC0pqHXhkbd\n4dljKDgyVju0UvQCmolGJ3oCx7MZi2DYmEQuAc5y2uqIFM2gM5vmFkUVgPvSMuSZkzTnf/JEY2Ze\npm2NmsFh
NaxMopjEXFw1944doUBvFbHU2eBQCtZ45hmcspQCXmeCQNMYcqyLDUoZXDbMMYJxOKNp\npPZRw1itDHMyDM7XvnTXMYeEUsKxBAbf4kXT2+E8eqVJaQapqR0r79mphoBmSickKQbbYJXFuNri\nOJXEXAKHPOGNx5kGTGHMc93ETZpG1faKzjCYDEYzx9oycKrllAtK57q15+QXSbtzAe1w5HpYpWOd\nZDj77KpM9c1AnVN6VR0rU4I6pwA7Vc2KRBeyhpIM1ig8gaJq5HzWHpEae69VQUuqZglZUYqliKut\nIV0ICqw2eNWRJNMCWWuiqqI6JbCqY+dWTDpyKDNWaZ7YNZd2BV646Fo0NRutN56Vbfjh7gqrNW/G\nA1MKtNbxYrNZVnh/i7z3olt34/W5J1j3zPdx5ooVP3ANXmtu48TPj1/xJt1SRLhqe8aseNG3CPe4\ncsfN/IbWaHbuAm/+gH26w6s9WT6n08d6wKMsToNXBieHuqZKHSMr1H9KkYijgC5cMHIr/rzdpLnQ\nM7M01SJKgdOZWISNC1wyUaRwEs+9dLQ6M2HxUrBqphjF67Cq7YXi8BqcLQhHpuJJWVNCTynQ+4aU\nElpZ5iK0tqHMit5YUoTWQlSR7dCSo7Dz9dTbZsWUMivb4BTfpB1PIaMUWAzrpmPlHYN1TDExxsgY\nAxftwAWC8RDmBCgkBaai8Q42rkEbw3GcuJMJK4rG9zRUQxhdDCc14ouwLzOD9gStKDYSSqkTBqnQ\nyBrJI6q0rH0myam2CoLgaRnPI2BJCqLKORG9BpJKqWKpEfQ518HoOs1gzLmdoDNW6sxtURp3zjcT\nfc5hKwYx9jyJdr460ZE5a4qyWGrabl2MSEhR5NLUFF6lcLalDZGcA4HM4Fqu2ivenmYGa3FKMSl4\n0V+RsuKyWfP91Y6vTge+PO552m/4ZHfJrh3Ikgk50znHRdvzpB8W+8VH4L0XXajWc2+nkZgyReBp\nv0LmwtrVwe/GOm7CHavmOWvf0yrLp9Nr7sPPmcprFIZdM2AVbJxCcYvoN8R8y6A9g3lCpiWUA5Z7\nrNzRqrq7prA4Xf/5tMwYfb6cFF3H06Qg5/q3ABt95K70deyLws4eOEZHBlCKhkynBbETICQR7vPA\nofhqOiPVayGUSMFwOzvqHJanM1BUAT1RztYnLnmsaJyvcd7KVOP0retQSeMNjKHGwRutedI16ATb\ntmNMAacUY8w1ZdkovDakkhlDQgR661h5T+c7Oifsp4KY6vR21a256ATjCjEJY050zhKCRRvDuvOo\noefteGAvRzrxDKahF4fx1e7yVI4EiYwFVtYwKSi6o0gi5oxVFl08mkCmYaUiiYSRRBKNE8skqjYw\nVF0iqZNiBp2rOJpU/zZKp+pOhkXOK75FNE4SRgtWFYJu6qQIGq0L2mTEGCR7FPGcqWZoXWDKiqI1\njVtBBAikEsBA5wdMVLQ0tKrlsjFkFSgonvgdH7bX3M0TQ+vwxjD4hhf9Gmccu7bne+s1N/NISBkR\n4emwWuwXH4nlUQYaa7nuB3IpKKXqpZURXp32dTxG4I8vX/Dl9IoLvwFA1BM+Ky/5wP0xg+0wSvHz\n8W+YyxuUvMaoBmsbejPQ2T3kO7K+QSTidUejmnr4IgFDHSMrSpERLBav6uqtqHrQYZVmKtBq6NRM\nof6DTWLo1MxROkBjdWajjhxmW3u6aKwq9NbiORGUJmTFPg2MxdIbV4VYYCwRoy2HGTptseKwRnGS\nmvBbsKyMQRVH7xz5HD5YpFaEu6ZBF4tyhVPKeBzOGi52dch/0zQcY8AaTQyRp6sVSgSnFWNMHMaa\nQrxrGmLTcbXyKIG7KXCbR1rtGFY9F6Un6YzkzCFm1rohkOhNS+89onveTrcc0shKe7QdmDhhjKXX\nhaNMlOwgWQajmJWQaFByIkndlIt4jGTA06tUI3vOK2eOhpAFI6m6OOra3y
1K1UkErWqWWqnLFxmD\nYAkkvCrEc+hlKZ6kFVIczlmcyWTJqKJJqkNpR+tqJFLOGmUKiMOgeN5dczSZIDMxR1DChd/S645V\n0/BsGGh1vWpLrvDDzRUf73acYiQVIeRM7xo+2nSL2D4yy6N9xmr9S6e01+2alW1JUvDaYLVmKiM3\n4b7uo2vLv9j9Icf0lsHuAJjzC0r+Owb7bxhMQyx77ub/juR7YI+mQWnB6QsaFWnlSFHVP9erDk1H\nZKoOUiRWCEephyig6bVhLFAI9RJXFVKxrGxhLRO5BAKOEw6vC6M4HA5lBJ1OfFUGkljmYnFK45wD\nF9mnmgxsZEAVw851RKlPjpASjTTMUXjStBjRGAOHOLNuW2yyXHWWkhRb23JMsaYBp8zgPVtfrShD\nzkwp4Z2jU5pn6xW6wOA9x3mufq+ieLIZsCiyZO5iJMWMM5oPhg0hZS7bhqlk7ueJN+FAbxzbdsO1\nbJn0jIqF+yxcmjU5ada+p7eOUDpu0tccQ2SlHEo7Oj9ilKYVOEkilx5yojWu2jjqCSlTrVbRJGmq\ni1exeJ1J51ncrBTGtGTJGEnnF9PqOaxEgzIIDq8KuZTqQKYbkmnRKp4fZ4VYjaWlAI1qKdqQSJzm\nCdGOje1AOpBSjcStozGawbSkovlk9bz6JpeJUwp4b/mBW7HzLb21dM4xOM/tNNG6+sK5CO7jo+Sb\nfOhfyf/xzveNIoV9OiEi9LajSOAnd3/JXEYE2LhLVPqvWN1jdXdOofgr1vK3OPsUr9bE/Dk5/y07\njkDiLG006glWRZRMFOYaH6M0RToix7O5ysy+FI6oasyiavX5MnWkUj1lixiOZeC+DBgsc47speE+\ned7GHUkcSjVMuc6N3gSDlIZQHJ1p8aoBgbdTBtHYPNAoT5M9oqpJj4jCFYsuiuf96uztq5jnyKbr\naKQmx4KwbVpOMaKVJqXIVT/QW0PJiqPUy2mtobUereqLX+MMd/uRWISQE9tuAA0pJw5hYgyp/kHE\ngCpsGkdAeHk4cBcOtMpz0ffcTDOTnKAo7tKEFXh53HPRDFhlmWXkPr1lKmCKQtEw5wlnIWbhJIIS\n4VQyg3LEDEGNIAExhhLrIkmhIKpDmElSfRREO5T0oGacnlA4xM9IaerihNMY29WpFTVxCp6usWga\niipcdTBGTSSzsp4xZZ6tr1iblvv5SONAZ8e6GbhsNigj/Gh7hRbLV+OBZ/2GtfV8cnWF14Y344lt\n09E6x65p6Bf7xd82v7Yxvoju/yexzBzTHQrNyl1wip/y8/1fUK2oCxfuh+j4H7HmOUpZSplI4b8w\nkFD6CUaviel/oMsNDRHOVS0UstqAzMAJIZKK4oQm05LlRMGQZOJV6jlKXUTQypKK5/PYEc6+DUUa\njmUg5o6iPGOK3KSGMRnmtEXEn0/zI1o8x6CwdJQiPHFbrHYgwttjoFOOjerZ+AaVBeccUhS2KJSC\nTjuu2wGFYk6RUgqbrqG3FqU1NmvaxnKcZpSqovKkGRic55SrQbYSBcrQOohF0TuNRnNzHJlLppTq\n2xpzPTSLMXCYahsmlrpavOlbss58ebfn7bRnMC298dzOI0HNSIbbOKELvJ1PDK7FGc+YjozlnllA\niULTENNEUZCLMFNz00YpdNrWVWJGVK5pC0UMTkNWgqarW3JqQiFYZ9C0JBG8HynJoUzE2ZaQa6yS\nVy2xRJQN5GK4aDt2fsPb6cR122BpGWXmxWrLGArP+i3f6zd8NR5ovWXrei7bniftgCBc9T27tmMf\nIpdtS+ccbplMeAwW0X1M5nzDnL/GqI7efsBh+g8cTv8ezpv76+Zfoef/hDIfVveu/AbiTzGqQekd\nSnXk+FOQmcyR85IqIpqj6hEJZLknS10nvucSwRDzWFd/y8xnYccoHqNaimoR6fh8gjlbomiUWpFK\ng9MtWRoOMfL1ZEjJsFI7nGnwxXBKCS
MWyYqOpjqu9RcYVWPA7w4Tu6ZjYxpWtvoH9N4zxUBv6pzs\nxlk2fYcqcMiBkoWVryIoKqNQdI3jNAeSFOJceLbtUdpwmGdup6l6vWqFV4qMojXVKvLt6VSFOAm7\ntiOXzF2ciXnmLkScOFIJWG0ZmrpS/MXdHffhQKNbBttyE49ECaSs+Ho+0VnDzXSi1R5nHMd8Yk5H\nkq5CnIuvVyO5ms2cznsSQaAxdaIhqYTkjDGKyHk210GjTO3NqsQxK7ypLZAxTnS+bvidJLN2HSFB\n7z3Pz9tjmRNON2zMwPdXVxwlcN33PGnWfHl/x7NhjTeejzYbPtxueXM6ElLmyTCwbVqGpbp9TBbR\n/V0iIsT8M0q5wehrrHlOPPw7cvzxuedn8e7PIPwVmBf1a/LPKOkzUANiLilFKOmnnCQR5HAeWhKQ\ngbfqKanMxPyWIIqprBjV92vOWp45MbAPEy/DRT140xcU3eHwfHYciaUhiabTFxjRXLkthwRzityO\nGSeWZ/aS3tWDrVwKpmgUmrX2eG34YNggCnLOjCGwazt60+INxJJpnCfkRGcsWoS+aenO/cS7eUQV\noe8bvKnVuRTBa8v9OIFSlBx5stmQSuJuChznudoaUn0HIpnBeXIRvjrsGWumORfnKYr7NBFT4utT\nXYI55ZneeXrfcJhnvtjfcB9HVrqnsZY3YU+UTAFupxGXNXcl0BiLU5ZTmRjLsVpqKkfOmqJnpNga\nXy9nQyGlWPuaLoKaCKXeZm19Ee2dpXMdpxQRJlC17/rR5prXhz3OFta+45gSHw5bBEOnPD+6uOB+\nntnHiet+w7O+5/vrC8aU0FrzfFiRpLDxDaumWcbAHp9FdL9riGRK+muQEWV+gFItcv/nUF7WsCu9\nQ/QFkj9D6SsAYvgJp/IS9HOUuSTnPXP6n3ydG05lQuGr5aF6xq18yJxHjukNs1iyXFD0x8QSSCWz\nzz13YeI29MRsuG6fY2hxSvP5/g5oseK4cBusUnzQbTnFQoyRccr01vNRs6HxnpAi+mx2OBiLQdN4\ny5OuJxVhTomYhF3XsnIN2sIcMk1jmGPGn5NsW29rzprA/TSiCzTe0XlLKFIrR225P50AISXYrVqK\nCLeniVMIdfrE1lTjmCK994RSeH1/x5wSpWhWTa3E384jSuDN8YDBMknAa8/gPfs48uXhhqlkXDZg\nHPuwJ+qMEs1tCOhSGEtE6xarFacyM6UTjVFo25KLQqmIs776U9hISAVRmqf9mjlElMuIFLISBt+T\nEzit+d5mx5giogtO16DRf35xzTFFjIXvby65PZ3YNT2D86x9w48uL8lFeHM6cj0MbJbq9nfJIrr/\nFJBygvTXQAH7R1Buyfs/BxnrbfpjjuUtCo/SA6UUpvhj3uaMNR+gzQvG9Ipj/JTP4wVTCWjVM4sw\nuB9xyNcc4oGv5ldkelp9ydp+yCGOiFIcgud+PpFCA0rzo/4F3jhySdyOE6ZYBttw5QeMUXzQbjjF\nREo15mjTNly2Pa1VnELCW0NMhaFxWBSNtwxNA0WYYiLnwqprcaam8mYy9jzH2zpHTInOe4zRpFK4\n2Z/wzuKMxqiz3WNKGGM5TYFSClEKl6s1WTKvbveEmOpKrdKEnMlF6BtLzIUv3t5xTAmNoreeQ5q4\njSNGFK+PJxyKUSIOQ+ccx1J4eXrLNEesrVX1mCdCqdtuhxjOsT11kUQri7KZr6c9xhgG11CKobPQ\nN459SGAEjUIMfLJ5wnEO7MvMqvWkmNk1A523ODH8s8vrupo+nbhebdhZzx9cPEEQ9mHmqutx1rBp\nWrZN+7t+Or/vLKL7TxUpXyPxb1DKgftTYvgJ4+HfnqNfCtr9Ga/DS5x5ilaOUiJvph9zUy5p7XO8\nfsFN+Jyb+IqX8xWjzDR6wzEpPu4/IeWBt+HAp8dXODNwbS945p/wdTjRmZbDVLfGiIrBtvzhpp6G\nzy
UzzgEjmrVvWTtTY136njkE5lRQpbBqm7o4YTQxl3NOnKZpNCUrGudwtla6U4g1/cJrvLU1paFU\nR7A5ZlrvySXhTE3YSKVwtz9hnaXxth7exYdQSMX9/YmsCkUUK9+QcuTN8UiMQpK6GFGkkKTQW8dx\nDnxxf8eYE5pqMD6Vmbdz3YD7KszfzDNbDK3VnHLkq/HAnAvWaVJSGJvJGrQossoINfb+uhuIImQy\nY5mx2nDdr0gCc8o8G1YUEbIkLrqO+xD50foCawzHHLjuB4zSGKW57ga00ny03bH2nrenE7uuY900\nS+z5d4NFdH+fyPlzcvo7lFpj3Z9wN/1n3hz/AnWeemj9v+az6Y61e4pSmlBGfnb4KZM8Zeuv6ewV\nn59echMO7MOKMc+s9JapCH+6+RijHK9P93x2uGewLR+0G678wD7WGPppyoSccEXTWccH2w0GmFKk\nlBpJM7QOh8VYTesdpQgpJyjC0HdYXddgU6kx8UUEb6tYGKMwRiMiTHNEa42zBqM1KRVKKSglpCJY\no8m5YG09kQ8pcZoSKKE53zaHVH1yFeyPI+VsML5qakjnl7d7plRIJeNVNWavkehwCBMv70aOKdTZ\nNiNMuXAznbBKcxdOJIF4jl73xtAaxRfzoR5i9T1KaZSGvrHcToEshd5Z5pz4aHOJ05q7OFJUQSlY\nuY7LpqcgfH+9wRnLy+OBq67DG8snV0+4bDpeHQ8YrbnoOlbOs+u639EzcuFXsIju7ztT/Bkxv8Ga\nHa39A35+/Etenv4bSmkE4UnzL/l02nPpa3/4mI78r/3nOHXJVXNBZwY+PX5NyEKcaxLEzvQUpfmj\nzTVWKV6fjtyNMxvnuOpWbJ2vywZNQ0y5bvQh1Wu4a1DU0E9t6ql/2xgUGqsN1lZRzaV6GFtjsdZQ\nihBTPkfe1PhkpRT6bJ4tUu9XCozR55Tjgkg5JzyUOllBqb/7OYxxnCIA3hukCKlkUhRKhvvjiSnX\nQ7PBOYpkPr8/MMXIHBJG18cwnEMoQ8l8cdhzP084o5gEsmT2MZ6TnzXaWEKObJuWQEYVxZGZnAvP\nNms643k1Hrns2zoRQeHpsOJ2nLjsWp4OK+7nCW0Nm/Nyw7N+RSiFi67hxWrLYZ4ZvGfdNMuSw3eP\nRXTfN0SEffySUI50Zkdnr/jp/ie8ml5hzlaEH7Q/5Ktp5sIPANzMR16e7lnbFVe+pzWeL457HIYU\nMgVhbTzOWF6s1wDcnSZiznhj2HQebwwhFVpvSbEa91DAWUPX+rNoCtYqpIB39uz0VoUV+CarTmv1\njdDmXNBaUy121Tdf8/C75lJF1+jaeogpk3Ou3y/LOaZIqte4CHNMjGP1f1C6puqmAikGUhK+uh/P\niRUKY2sM++vjgcM8M+VcJyakzoqFUkhZ+HI6cj/NPFm1KFVTJLyznGIgqHr4OOfMrulYtY4xJUQr\nphRZOc/3NlsOYeaq77loWz7f77loa0bZi/WKH24v62jfeOT5as3gPRdttyQ5fDdZRHehbtTdhBuS\nJFZ2RaNbfnr3BfdpRFMF7oPmiinV8SuA27l6r25cx9Y39VJ4nHDWUGJG0LS6xnQPrQOEKdRcMq2E\nznn0udq2Vlf/WW1ACsYYjDG/JKpQxfZdUYUqxA/3Pwjxw+d8++3D55cilFIwpibgFoGYIgpdgzR1\nbUcIghTFOM2MIZFzRhDS2cjmECZSgVf7A8cwY4wGU83DX057bqeZmKtxTBThwrccJZCoh2QnyVz7\njouh5e/u7+vBojbEknmx3jKlGY3ig9WaQ0rEkhiahsum4YP1lpBrCsXz9ZZcCp2zXHb94nX73WYR\n3YVfTZbCXThREAbb4JTls/0toZSakaY1l01Xq9XzJtNxDoRYncVaV6Nm5pCwpqYrKK0wul7++/Nl\nbwh1ikBrhX1nI+pBSB8E89tC++3P+zbfrngfPn4Q6XI2MXq4r5Ta
gtBKE0IkggHvngAABxpJREFU\nlcI8R3KRmjwcVT3MUkKI1Q1tzoGYCoVCiKU6t80TU858HU7cnE60zjFLwQGNtxxDrO2FbuAuzfTO\nsm4akmQ8hkkKISU+vthxjIlV43i+XvFqf8Rbw9p7Wuv4oyfXoBSvj3t2bc+2adi07TJ3+91nEd2F\nfzjpHMMtQGctGsXdWOPMoaZtdLZOC2hdn1tzTEgBa/U34lyrV87Cp9G6iue7bYQHof1NXSI/PJ8f\nhPdBbGtror59aDvknMm5EHMiZwFR9f0k7MeZmAuRTE4FozRBhHmeOcbIMURiyQSdOM0FTeHrMLNP\nM7um5X4O7BrP0+2aT2/vUaYatIcQuV6vEKCI8NFuy2EOHGLgomvpnePj7QVKa+7GkauhpzHVrKYe\nyi1i+0+EX/uHWrrvC38PqzVr3/zSbduuJZ0rT2cMSBXa2kutPVv7jqD+AsFai1K/EMIHHg7HfpO8\n+/3eraIfPv6mz/tN31hjRKOQ8wtD9YPo2oLPgjaKwymSSmKaRuZUqjWGgrX3jNpTysQhjSgprK1j\nbT2Rmkx8DIF1VwMitVa8TpGh9aRUo3ys1jhn2OiGy74jS532WPsGROhdPShbeb8I7u8Ji+gu/IMw\nWv9yD1FB4+w31a/RD73UX1wcaa3+ngi/KxyPJSLvth8eKl6l1Lm/rMhF162wXM1yjNG0yjHHhFGa\nrhVS1swp4XWdx4WRkoQpjNzME60zFKvpsVwMPWmqcfIUmHNm6BxaG54Pa666jpCEm2kk5ITVik8u\nrrnse746HdlPM1ZrLoee635YWgm/Zyyiu/D/jNYK/a2rqF/S5XcOtR7e/21Ut/9YHtoZDxWwNQZr\nLCnV7TpnFdNYf+aQMxTBacPQNpxCwgJD6wkh4YrhSdvirGbOmVAyY8mIUTxxPV2jOdwmnPV4q8/z\nyrXdsutaLrueuVTxb507J00bdm1Ha+3v/LFa+M2ziO7Cb5RfJRIP1eWvu/+xeGhvvPsC8O7kg9Ya\nYwRUvcRHIMTIlBTGCtMYSbnGpueS6ZqGlSrsJzBWs24bxpTqpETKRJ9ADM/WPc/XKwyKTw/3ZKBI\n5sPtlo+3O27nmZvTiTEGOue4Hpbq9veZRXQXHoXvSsX27gvAw/shpHcO2gxFCt45QohYa3FNoqjM\nuu2ILhFyYoq1JTGmyKEEXDacSmbXtFz2HeE2Y4xl8J4pRqYQQWme9wNP1mvGGL75mTprWG+3XA0D\n7jtwJbDw22UZ9Ft47/j2tIS1BmPqgZZIXWMex5lCncCwRtNYj3fV92Dwnq5xiFKsGs/Trqc3Dl0K\nc4lMpbDuPM/XA0/7FVkEDBQlrNqGj7YbfrC7QIBDDDTW8XS1whuzCO57wFLpLrzXiPyi4hU5V+QK\nrDMowHhT/R7I3J1GlDZwXl2+6Ft0UtyPM94Jh+QRoPeWKIlTikwp82y94sVqzZwzsRRirh4LH603\nPF2tliWH94xFdBfea5SC/LAmLIK1BkHOJj2FItVgR6gVbtFn5+AiZCm83h/JZJwxGGd50Q5s2477\nOdSZZKNBKS6HAasUr44HUsmsfcNF1y2C+x6yiO7Ce48xqrqjqWqwo5VmzpFcquCiNa0x6JUm5OpG\nNiVNFqFvHDEZ2tYySeYuR+ZTZtc3fG+9xWjNzThymmda53g6DDxbrZeDsveYRXQX3nuUUmc7ydrD\nzWfrSG90nTRQgjGa05hJORFyxmjNtvMoZ0ip3n7UgaiE1ltCKvTOsWnbanxjLZumYb1E57z3LKK7\n8N7zrn+D1gpB0zj3jfHOHDJzyigNThnQipgzThvmnNiHmTlHjDV8vNmw6zpeH/bchxmjNbuu49nS\nu104s4juwnvNu9tqD1MNWkN6Z67YeYPKYLTHWkspwv00cYyRnAveaqxp2IeZxtYxsat+wBvNRdfT\nObcI7sI3LKK78N7z7REy846H
L5zN0rUl5F/4/Pbe44wGo1k3nizCV8cDd/NE5xydc0t1u/ArWVzG\nFhb+gUwhMqdqfG6NxhjNy8OBwTmUUpxiQCvN9dDjjV16t+83i7XjwsJvggdDnwdLy5vxxN08A9AY\nw/Wwwi7V7cIiugsLvz1izggsK7wL77KI7sLCwsIj8mtFd7kOWlhYWHhEFtFdWFhYeEQW0V1YWFh4\nRBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhE\nFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW\n0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhE7P/l/l8bI7ywsLCw8I9nqXQX\nFhYWHpFFdBcWFhYekUV0FxYWFh6RRXQXFhYWHpFFdBcWFhYekUV0FxYWFh6R/w2rIyATdi9ssQAA\nAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nOy9WcxlW13u/RvdbFfzrrepblftBhBQ\n4LA/9fh9H+K5+JQLNaDBiArYRbwh4couKolBjQY16JVGxUAkATQmGk00GhTsosZoOOboATe7qV17\nV/s2q53NaL+LWdSR0KiAm2bPX1JJvWusOdaqetd81pzPeP7/IVJKjIyMjIw8M8jP9RsYGRkZeTYx\niu7IyMjIM8gouiMjIyPPIKPojoyMjDyDjKI7MjIy8gyi/53xMdowMjIy8p9HfLKB8Up3ZGRk5Blk\nFN2RkZGRZ5BRdEdGRkaeQUbRHeH1r389Fy9eZDab8fznP5+3v/3t98Y+8IEPIKVkMpkwmUy4fPky\nr3nNa/j7v//7z+E7/q/lwQcf5H3ve9/n7XwjX9iMojvCj/7oj/LEE0+wXq/5/d//fd785jfzD//w\nD/fGL126xHa7ZbPZ8Ld/+7e88IUv5Gu+5mv40z/908/hux4Z+cJkFN0RXvSiF5HnOQBCCIQQPPro\nox/3PCEEly9f5id/8id5wxvewI/8yI980jn/6q/+ipe97GXs7e1x5coV3vnOdwKwWq34ru/6Lo6O\njnjggQf46Z/+aWKMALzzne/k5S9/OT/4gz/IYrHgoYce4o/+6I8A+K3f+i2+8iu/8mNe4xd/8Rd5\n1ate9Qlf//r167zqVa9if3+f5z3vefz6r//6vbHv+Z7v4c1vfvO9nz/wgQ9w+fJlAL7zO7+TJ598\nkle+8pVMJhN+7ud+jieeeAIhBL/2a7/GpUuXuHjxIr/wC7/wac838uxmFN0RAN74xjdSVRUvfOEL\nuXjxIt/wDd/wKZ//6le/mn/8x39kt9t93NjVq1f5+q//et70pjdx584dPvjBD/Lwww8D8KY3vYnV\nasVjjz3Gn//5n/Obv/mbvOMd77h37N/93d/xghe8gOPjY374h3+Y7/u+7yOlxCtf+Uo+/OEP88gj\nj9x77rvf/W5e+9rXfsL39+3f/u1cvnyZ69ev8zu/8zv82I/9GH/2Z3/27/4/vOtd7+L+++/nD/7g\nD9hut/zwD//wvbH3v//9PPLII/zJn/wJb33rW/9DlsGnmm/k2ckouiMA/PIv/zKbzYa//Mu/5NWv\nfvW9K99PxqVLl0gpsVwuP27s3e9+N1/3dV/Hd3zHd2CM4eDggIcffpgQAu9973v52Z/9WabTKQ8+\n+CA/8AM/wLve9a57xz7wwAN8//d/P0opvvu7v5sbN25w69Ytqqrim77pm3jPe94DwCOPPMKHPvSh\nT3ile+3aNf76r/+at771rRRFwcMPP8wb3vAGfvM3f/Mz+j/6iZ/4Ceq65iUveQnf+73fe++9jIz8\nZxhFd+QeSile/vKX89RTT/Erv/Irn/K5Tz/9NEII9vb2Pm7s2rVrPPe5z/24x4+Pj3HO8cADD9x7\n7IEHHuDpp5++9/OFCxfu/b2qKgC22y0Ar33ta+8J3bvf/W6++Zu/+d5z/i3Xr19nf3+f6XT6SV/n\n0+HKlSsfM9/169c/o/lGnp2MojvycXjvP6Gn+2/53d/9Xb78y7+cuq4/buzKlSuf8PjDw0OMMVy9\nevXeY08++ST33Xfff+h9veIVr7hnV7znPe/5pNbCpUuXOD09ZbPZfMLXqeuapmnujd28efNjjhfi\nExcTXbt27WPmu3Tp0mc038izk1F0n+Xcvn2b9773vWy3W0II/PEf/zHvec97+Nqv/dqPe25Kiaef\nfpq3vOUtvP3tb+dnfuZnPuGcr3vd63jf+97H
b//2b+O95+TkhA9+8IMopXjNa17Dj//4j7PZbLh6\n9Spve9vbeP3rX/8feq/GGL71W7+VH/qhH+L09JRXvOIVn/B5V65c4WUvexk/+qM/Std1/NM//RO/\n8Ru/ce91Hn74Yf7wD/+Q09NTbt68yS/90i99zPHnz5/nscce+7h5f+qnfoqmafjnf/5n3vGOd/Bt\n3/Ztn9F8I89SUkqf6s/IFzm3b99O/+N//I80n8/TdDpNL37xi9Ov/dqv3Rt///vfn4QQqa7rVFVV\nunjxYvqWb/mW9Dd/8zefct6/+Iu/SF/1VV+VptNpunz5cnrnO9+ZUkrp9PQ0ve51r0uHh4fp8uXL\n6S1veUsKIaSUUnrHO96Rvvqrv/pj5gHSI4888jHzAumNb3zjp3z9a9eupW/8xm9Mi8UiPec5z0m/\n8iu/cm+sbdv0mte8Jk2n0/SSl7wkve1tb0v33XffvfHf+73fS1euXEnz+Tz9/M//fHr88ccTkH71\nV381Xbx4MZ0/fz699a1v/bTnG3lW8El1VaRPvV3P2PBm5FnPE088wUMPPYRzDq3/vR5RIyPA2PBm\nZGRk5PODUXRHRkZGnkFGe2FkZGTks89oL4yMjIx8PjCK7sjIyMgzyCi6IyMjI88go+iOjIyMPIOM\nojsyMjLyDDKK7sjIyMgzyFheM/JZpXWWrXMAzPKcXA0fsZTS2PhlZIQxpzvyWSCmhA2ezntWtiXX\nBhL4FDgsJnTeY4NHCckszzFKEVMipoQUAjmK8cgXH5/0Qz2K7sinTUqJkBKnXUMfAstuR0Jwvh76\n2G57S0qJWZ5TaE2IER8jdZaxsxYAgWBeFGgpcTGSUkJLiZKj8zXyBc0ouiOfXVpnOes7trani56L\n9Yx113Ha7hBCUNy1FQiCRV1RGU2MkVXbD83PqwIlJSFGQkxopei9RwiBBPbKkpQSNgQACq1HIR75\nQmIU3ZHPnJgSO2dpveOsbVgUFY13HHc7JirDp8C/nhzjQ2Ke58yKgqNigvcehCDXavhAJajzjEwp\nXAisuo5CG2ZljhAC6z2R4VMrhbh3zF5R0IdAf9eqqLMMfVeIR8945POMUXRHPn1CjLgYWPYdPgZ8\nSNxutxyWNUYq/vfxLTZdj1aSCBzlJUaqwSZAses7pNRUmWFe5qSY6K1HSkmmJQjQUpNphZLgY2TV\n98zznMIYhBD03hNjuifeIUZSgllRsLM9PkaUlMzywaoIMZIAdXd345GRZ5hRdEc+PTrvOWl3uBC5\n0+y4OJ0gkVxbL9naHoFk6y2kSHQJlSnmWU7fBFwK1DrjaFKSK00Sg4dLEmxshxKS3Cj2ipJN3xNc\nxBiFUhKpJFoKlJAIAS5Gdr1lUZUYrUgp0VpHBHKtMErhYyTGRGUMm7uesRSCRVmihLhnVYye8cgz\nwCi6I/9xYkqs+47GOU66hv28REnJrd0GnyK5VDy5WrHqe4yQCCG4PJmytY4nzu4ghGKeFUgluJBP\n2asKCIld7xAKJkXOXlYQEaQUiHHYR6wLniQStcnIteakaRGA1hKlJUoKFBIlBCDogiMmmOYZSklS\nSmw6CyIxyQarwt0VWikELkYgIYVkUZbA8KUCg2esRyEe+ewxiu7Iv48LAR8jy74lxiFFcH23oVCa\nic54dHXKtfWKXEoSiUuTOSnCzc0G6x0COGt7FnnO0XRCrhT0Eht6yjyjUJoiN5RGkwtD7x0uRmSC\nWVVQFhmBSPABkTRI6KNDIKiznETgdNdDSiglKDJDvJt2IIGQ0HqPRFAag5SCGBOrriU3mjobtpW3\n3pMEwzF3rYdEYr+scCHQeocUktoYjFLD+OgZj/znGEV35JPz0ZztadtCgpvNhlmWM89yrq6X3Nxt\nkRHa5Km0RiVF691wy+4CT63OkEIxKzMyKZllJYSIRBAiTHXOpMiY1QXOR6xz9C4yLQxGaUplSApk\nSiAF3iaM
kVRVhlGKzg7CG1NE3fVzpZRkQmGDZ9X1ACgpmRYZO+eQCCSCJIa8sBKSXOt7C3Vba6kz\nc88z7txwxZtIZB/NEZNYFCU7Z+m9RwrJvCj+z/iYMx755IyiO/LxpJRY9R2t95y1DbXJmOUFN7Yb\n1n2LQnBme1a7BikljsiFqsZaz0dOT+iDZ5blWBL3TSZIJzjuGrZdjxaS0mge2Ntnr8zxFmJIGC3R\nGWTSUGSGQmsa22NtQEqoixKtBs/VJ4+Saii08B5tFFWek0jYPhBjIjMKqSQhJUQa/NrWOda2JyXI\nlWJWFpy2DSIltNB4EVFyEEsjB3+4D4HeBypjKLIh7tY4RyKhhaQw5u6CYmSW52z6/m7CIrFXlORa\n40IgpoSScrQqRkbRHfk/uBBonWNrLYlEnWXc3m6xITLJDDd3Gx47Oxt++zJxvp6RC8W11ZKTdkul\nc9Z9yyQrOCgrGtux7hx9a7EuMqkMB/WEShumyuBTwNqElDArci7M5wQS26bFCIHJJKUuyIxEa42S\nwxW0c8MtfZFrlJIEH/AxoJQcFtgkKCXRUuFDxPlAjJEiNxg9PCaSJCZPILHpLS5GpllGkWlubbYQ\nh7PDZAp1V4C1ECQELgZsCEMVnR6ubjdtT0iRSVGQaTncJfjAJMvYWIsUEBMsioJca/rg8TFh5HCl\nPfKsYRTdkSGK9dFcrJaSs7bFp8hhVXO82/KvJ3fwKQGCMjPU0nDS7ggp0jnHre0OFxIHZUmeKVQU\ndNZx0uwIPpFnCi0Nl+dTMjTrtqMPln1TclBOqSpNTEAcbt1FgiRhXhTMq5Kmc4SUkALqMqfMzBAT\nQyAEWO8IYYiGVUUBCXpncSHcsw6kFHcjaIPoBp9ADKmIzGi8T/joSVEQxfDlk5KgzgwJwZ3tlhgD\nUmomhUEoifN+sBAi9CkigMpk5EbRB8+y6RACDqoao4e4WucchTa4GBECQkwsypJMKbbWElOk0IbS\nmM/th2Lkv4pRdJ/NDDZCP3iZfU8ScFBWnLUtt7ZbHIG2d/RxuGV3wTMvC1oXeOz0lJN2y2Feo7Rk\nnudkSXBrt+W0GxbcYhI8Z2+PSZbx1GrNruuYyJxpmTErKvbznFleYH2i6x02efaqkoO6RBpF0/co\nIkZphJAgBLOyIjOabdvTe4dWillRkOeK3g7+qxCCGBIpRaQSFFmG85G270kxkWca+dF4mEhD2XIA\n5wJSCoyR5JmhtRbbR4QEbSTeR5IYrAkbAre3W2IIKK3Zr2o8gV3bgxgWFI1WCCkwQpJrQ+stq65H\nCTg3nZLrYZ6ds+RSoZUa0hQhMs9zlJSsuo4EVMYwybLh33b33Bw94y9IRtF9NuJCYGctrXP4EJmX\nBZu+Z9l15Fpz2u64tlpjg0MgOagLjNQ8uVxyfbfGiKGCrFCKeVbRO8fNdkPbdnQ+sV8WHBYFLRC8\nJ3aRM9sipWAvr7g4q9grSm5ttiQnmGSag7JmNi0QSAqtsX5olJMJxd6kZF4VbPue1nkyqSgLTUIS\nQ6LOMyKJzlmsDRilmVU5Sgn6PoBIiLu5XmJCqkFUu97ROktKUGXZ3SKOhHOBlCCSCD6RiEzKDCkl\n66bFh0BKiarICGmIn2VK0TjLadviYiTXmkvTKWvvONvs8AQKnTMtsqFxakpkSrP1jp0bInYXJzOq\nzLBzlm3fY7RmanKUEHTeMysKUkr37J9SG/aKAhjuVmDwrsc0xec1o+g+W0gp4e42ltn2PUYpdn1P\n4xzzomDdd3zozh3WfU9pDIU2TPLBRjjrLF20LJsO6x2LsqbWmjY4TrYt675DJsm8yAlEDssaBTx+\nesambynJmJY5D+0tcDHQWE+KnrkqyDJNVRgOqxm98xyvG6SGvSzjaD6hMjkuJRIeQSIlhVKCaZlT\n5zmnu5a26ylNTl1qJMNtvBIKT8DagIuBUhvmdYVPga73IEBJcS/6hUgYpW
h6y653QKLKMgqTYYOl\naT3IiEAgEngSZZaRUmTZ9HS9Q2g4rKcEEbHe4UOicY4uONoQmGSGC5Mpt9sdNzdbAoFFUXFhOsWH\nQBsDuZJsXU+8G0W7b7LHNMs47ras+47a5Fwop0gp6byjNhkhJVrnQAgyJdkvK2D4ch1SF3q8Kv78\nYRTdZwMpJdZdjw2Bzjl67zmc1LTOcvVsqCALCUIabrX76CiUYek7rp0tObUNM50zMRlaKXz0nDQd\nZ10LQEiRC/WUS9WER9cn3NxsKZJBIThf15RZRmt7rA94lwgispflPGfvgEWRcaftaV1PqRTzrCQ3\nmqLIqHPDdtexsnaInBUlF/fnWB9o+h6tFZmUaGkQKpFnhlxqTncNXWfJc8OsKNBK0vtIip6UBJGI\nd5GiyJgUGa11bNoeGDqd5Vk2RL/ulhBv247OxXvjRaHZdJazdYNQglybYaFMQKYUnfWcdi1dcGRS\ncmkxxwOnTUPnPH1yFCqjj5bCGA6riic3S242W0IK3FdNeXDvgMZZTvodWg1+tBGaROJiNeegqHly\nd8y6b9jLau6fHKLl4CVnUuJjwt/1mZWQHFQVCQZxBkpjxiTF54ZRdL+YsSGw6y2987gYmRc5vfcc\n73YgYO16jpuGTd8PHbyKEiEjj5ydcmO7QSlBJjVGKmqj8SFydXfGuu1oQ2Rhcu6r53TRc9zssCHg\nXECjKIxhnhlKrTlpOm7sVuRBcq6ecjipKDKN84loh9tiJSSLScZzFoe0IXBns8Ulz2E+4aCqMEYR\nBWgJ29biY0QoxaIsOT+dctZ07GxHphSzvEDqREqSTEmUUpxsdxDBGM1+XZCSGL5sfIQkEARCEhRG\nU+YZ67Zj21kSQxSsLnJ654d4WPC0LhBSwIbIoqootOGp5RnrZvDG98qMPC9ovUUk2LiOs7bHE9Fa\n8dB8jz55rm5XbH2PkoIL5RTPUCm3V+Q8vjnjrGuIBC5Uc16yf5E7zYYb/RKtYaYK9vSEHstRPuOg\nqPnI9il2tmfPTHn+3mUyadjaDiETEoX8N5vCHFTV0Kzobml0nWVjkuK/nlF0v9hIKeFCuNeToDCa\nzjmWbc8sz1g5y6MndzjtGmptKE1GnWfcbDfc3m5Zu6FJzNZZFkXFPMtY2pYn1qdsXI8Rirmp0EKi\nBKQUud1sWPUWpRQzlXNlOscFzxObU3obmSRNpjJmdcZRWdP7wFnT0sfATBuuTPfZLzJ2IbK1HToI\nMqUos4y9uuKoqlnZjrOmxQXP/dM99icVKGh6e7fbmCQRcSTO1xOmecGN9ZqujxRGslcW5FmG8/5u\nPwbPpusJIZFpzeGkxhFZbRusd+Q6R0lACKRIZNJwstux7Ydb/0VdsKhqNtax6Vp6m4A4NOFJgUVe\nIjV8+PiYVd8hhODiZMKkzDnpWra2ZxssNgUiAaM1z53tswktH17dog09tcl53t45fAj0qaNWmuvt\nis4HhIgcFDO+fP8KT+7OeHJ3nUxL7isPOZct2IWWOss5yGqebG/gfKRUOS/aex6lKjjr1yQCkoza\nfNQXHhIoIcZ7PSomWUadZQDjAt5nh1F0v1hId62BjbX4EOmdo3We/bqkD55HT05Z2naolEoSpQQ7\nb4kxcdLvuN6sOGka6ixnbnIyoTmNW46bLSvboZMhycBE5VyoZjzdnfHUdgkkMjIOdc2kKFjaLTvr\niCnRdpbS5BxlEyqjOXFbNr3DCMm+rDlf1RRZhhGSpW1obQQCe7rkuYtDVCa5sd5hvWOqFfOypMpz\nMp0xzzJuNVua3iG04P7ZHkeTmsZ51n1L8kMFWRLDrfl+VZMpwbWzNb2zFDLj4nyGlIImWDQC6wM2\n+HsLYeemU3a95c62wTlLofVQhCEl3nuESBzvdux6RxKCvbLk/GTCzd2Wm9sN3jsyrZjkOVsipVS4\nGPnI+s4QD5ORB2f7HNQVjy6PWbmGTr
QUskAoUEiuTOZs4pIntjfoYmBuCl40f2BIloTbKASnziNC\nhpSBaVbz0r2HeGR1nRv9LXIN95f3caU8zzZ0SJGYZxV3+mNikGilePHeC5iomuP+FBsdmSg5LOYA\ndN6xKEpsDDR2sCYqY5j/mwW8BJhxAe8/yii6Xwz0ztFYR++HnrL7VYkLgVu7LS5E1rZna7u7Vy/D\nSRcI/MvqFjd2a5SQVDJDa4GSCZciT27PWPY7fEoUynApn2Ox3OpOaO4G+/NYUBqDkYIMyWncsbEd\nPgomsuLAVGgj2LqGTdeRgsYkRZkbHijmSGW41p3SO08tMy5kMy5O5vTCk1yi8Q7vA8ZojqqaB2cL\nts7y1OoMF+EwrziaDJ5xIqElrLqe1kWMUlyeTTmaTjhtWo43G3KpKU2G0RJHYmIyvE/c2W2wIZBL\nxZXFDCkVy7bD9h4kpAhBJMpcsyhKjrdbbiyHJj91nrNf1wQCm66j7Txndsg5I4fdLw7Likc2p9za\nbbAxMM9Lzk0qVqEnxMAu9txo14ToiCrx4HSfc0XFP6+fYOM2GOmo1XzwmmNkP5dETrnZndEFxdQY\nXjx/Pjb1WP8RbErs/ISMBchIYXJeNHuID62e5NjeJleSy8UVnjd5gK1v2Pkte9mEnW8gGqRKvGD2\nJexnC55ubrJzLaWquX9yASEGX3ieF4MQu2GBM1Oag6pCMNhaiYSRauza9vGMovuFylDx5Ol9oPOe\nymhciJw2LaUxrG3L1eUpx11Dpg0zlVMWmqd3a25u1yxdCwgaa6nLjLnOOe7XPL67Q+N7tFBMVUWh\nMhw9XWzYhZaNjUgkpTbsmwlBWk7sKa33qJShY8lUZ9S5ZGsb1t7Sx4iKkj094bye04iWZdjgvKSI\nBbUsqUvNoZ6w9h0nfYOLjj1d8ZzJEYui5Mx3dJ3FhzQ0nDGavaLiYllxp9lxfbtBCcn5csq56RQj\nYGMtCsnOWiLDQtuFasJemXN9s+N0uyM3mkVZUOgclzwGQWM9p02DDYFCK55zeIALkVvrDb0PECDL\nFMpIUoR5mXNjvebpZoNzkYnRXJkvaILn6vqUNnraEDFSkHSk1hmzouRf1jc5btcklZhnJc/dW3Bq\nt6zslihaNt4ikShtuVLNmOmCR3aPklJDpXpqdYjSE3rfsjAdRp1x3CXWMaeWkRfvvZAu9Rj/PzkL\nhi4eUKr7STKghOBL6vv4l/V1lnY1WBPZJb5s7/mcNmtO/AnTrEYLQy5K+uh4zvQy54tzXN1dZ9U1\nTMyE583uQ0tJ5z2F0oS7/TpAoKTgsKoRDEIdGbq2ZR9NjDw7GUX3C42PNlTZdv29hjS73t2zER45\nPua0a9BKo6IkzyXLvsWnxK12zc12zaltqJVmL5+gBdzszzi2K3a+R0WFFIoi00wzzdKuOLYrbIpo\nIZmJCaXKCHLNJu4IATpv0EIx0yUZmpYNbdzho0TGjCJOybVESU/rO7oIPkoKlXOo5+zpgtthSFGQ\nFHtiwlTXlEZRSMVZ17HxliAS582U582OyIzi6d2KtvdooTjMCqqqpJSSWVbw9HrNadeiheR8PeP+\n+XwoTmg7gvcIFFJKdKY5V5TURcFjx8dsu54i0xwU9ZANdo4UEpuuZ91ZQgpUZcZ99YTWB55YLelt\nD1KxX1UILfEhIIAbzZqTriOKRKENz13sc2p3fGh5hy5YogzMswqpI1JAJuHJ/pid69HGsshLHqiP\nOHW3aMIZRnR0QWOEIdcdFwpFoQqeam4gaZnrllodYOURnV9xqJcUYscNO2EZa2oZeOHsCi5aTPgX\nbrgKG+fMihfe7cq241JxxIfWG3a+Q0nJheKI/zb/Um62S653N5nomrneY8/s0UfLQTHn/uoij26f\nZtk11LrkyxZXyFVGd9eCSSnda9OZUuKorodUiO0JMVFqTXXXN34WMIruFxKNtTTWY+/uontQDy0H\nb6
43dMGx9g7rPBtnsdFxWFT0OP7n6Q1ut2sUikIbKmVwwtOFhhvtko1tcCQyqTlXVCSZOLE3aFOP\nJBJiQUZBYRJa9PRxRxs9PomhUIEKiUKbU7rocEHS2BIjMibKoEVkHVtschAVhoJp2keoiBdb2hDw\nPiMlQ60KLmRTZNQ85U/onScThkO94Cif4GXCd55tHIonjJYclTOeVy/oY+Tx9SkxBCqVc2W2IDMS\nJRTSwXG3Y+17jFRcqmdcqecc9y13NhuUlKikqHONMpKJyVFJ8Ph6yabrKbXhvumcWZ5x2rfsGsvG\nWXpvQSsKIzkoK1a247Hlkj4EtBScv2t/3Ol29DFwq1/T+4DUkGeK++sZd+yWJ3a3ickiTeIwm5Hp\nSBdbcmVZuS19DNSqZ6/UnM+OWPunSKwoRUsTSzI5IRc7jvIORc61tkPhODAd+1nJWbpCY0+4YE4p\nBVxzC26HCbXyPK8+ACx9f42nXY0SOQfVS3FJYtJt9rMp/7oV9F4glWaRTfi/91/C1c0p19obFDrn\nfHaOK9U5tr4nU4qHJhe4ujulccPV+pcuLjLVBZ23dH4oJtFSDf00fLhXCr3shl1ISj3E/T66aPdF\n1EJzFN3Pd3yMdG4Q2j54apPhQuSsbciMZmVbnl6uOG5acq2YmYIiUzzRnHJ9s2LtWhCSEIduXJMs\n46xf80Rzmyb2qBSZZxVG5EQ29GmNpaMNEJOiVJKp1ETp8ZziUqQQAetrEiWzvEXQYaOlCQafNJWC\nkpo2OHK9xAIxaLb9nExm1DoQYmKXLH2QIBQTWTDjiI4tXdhioySFAh1LZnnB3FT0feB6XBFtolQl\nl7I5i6xmG3qa3tKHBElQasO5asq5rGZpO261GwiJvbzkymSfLBsao4cusLQ9vU8URnK+mnGpqnhy\nt+bOZoe622tiv6wIElQA7yPXd2ts8JjMcN90xsQYrq7PON02tAQUirIcSo2N0py2W57crvAiojVc\nqmdUueaJ7QlN6GlSh0BSZwqtIvtlzrJfsnQbjLbkOrKXLchkTxfXTFTL1idcgIm27OeRmdmj8TcQ\nsaFSDp80Shyi2XCQbZBJ81hfIhMc6J4rJdwID7J0t7mkT1EYjsOCp92CjMCDkxJNz0l/yk03wQjD\n/fVL6VMB8QkmqubJtsLFEikkE1Pw/x68kCfWS661t5Ey41Kxzwvn97FzFps8D1SH3Om29CEiheCh\n6T77RU3v/b0dQypths1IQ6AymllesO46YgQtBdMi/0L3iUfR/Xxl2A03su77oYGLd2x6z9GkonWW\nR05OOGmbIeSOIM80Z67BBc/17ZpbdstZ3zAzGZXOyJXgan+Hs36JTRYtNKApDZQmsPUb1n6FBwoR\nKGSGllNycZsotmjpWbkKH7m+rmAAACAASURBVDW1kdQy4OlIoickQS0cvZ/Qx4r97IwkHC4JTmxN\nQjFRCRVLbOqRqsVFhUyKnd0nEzmF6mijZxcEfdBoodnXBVncZxnPsKnDJ00WJ5RiQq0VuTCse8eJ\na5AoZjrngeqQUuQcuw3rvkcFhRGKWVEyMzlzXXDS7bjRbDFCcFRNuX8yR0bBcd8QY6LrLFFCZgT7\nVc1C5jy+XnHWNQgxLFwdTSs6F2mdp+89d+yWpKDIDAflhFIqPrI65djuIEWMViyKCoujCZYm9Bzb\nBlSg0IKDckou4Ia7g8cihIOkmRcKIztKk2h8R2stle6ojWOWT9H02LBmrhtWIaN3GbVyHBYdE5Wz\ncg0SRyE8Rka0ukgKWyZ6+J1+uNlHJs1MOZ5TOW7HB7nZ3eG8ORt+P2mfW2GIrV0qIrkQPNU1nLgJ\nRipetPdldGlO4x4nExmn/QQt5ggMmVL896Pn8NRmw9XNCUZozpVzvuLgCo33bH3LpWqfTd8PTeuB\no6rmqKwJKbHpOkplKLOhabwP4e6O0eXn+Oz8jBhF9/ONj1aPdd7T
OUeIiYO6pPeBp1dL2hDovaX3\nkT4FbHAsioqt7/in5XVuNRsMmjLXTKShp2cddtzpz9j5nj5GaqU4uLsbwzLcxMeOUvf4mJFSxcQ4\nCrUl0dLGRJ8MC91TCEXjSw6yG8OquLTc6ud0sWRuHCUdloSNAo9ioVsaP6ONBQfqGC/BRsXtbkYk\nZ64iMRk6LD4FSBKNxrnDIR/LljZK2iCxPiNXOUcmJ9gJJ3GJjQ4wTKlZ6AVCDrnhTevYeYeWhkVW\n8tBknxgSN5stW2cphWKiCw7KEoFChcSytZy5FqMUR2XJ/fMFtrfcaIa4WBKS3BiyXDFRGTpJHl2d\nsrIWk2BaZhxNpmy95bRraK1n63q0FpR5RpkZBHB1e8LK9cgsUpuco7xiG3csw44QPF3wGAVl5oft\nhoRkE24jhUPLgEgZe0YixZZJ1tG4xM4qSuNYmB2LIsO6RKRhrhvWLqdLJaVwLLKOPR250WtclJTS\nMtUWo87jfIdRO4xwPNLtk0JGrSwPVD279ACPNytqucYIgZB7bOP9rJ1nkXdksuSpxtH4CQLBVxw8\nhyiOuNk9hYgJFycs1CFKKIQSvHjvAqeN5er6DCUUi7zkv5+7TOs9a9tzUFSoJJECcqUpjWZWlMi7\nG5Hu19UXclZ4FN3PF2wIQ52+c4QYmWQ5rRu2NC+14dR23FifcafpqHPDVOUUueGxzTHXt2vWfYvW\nwzY0WabItOSkW3K1PaYJHUYE5nlBLkt8WtNzShIN1kssmqlOzE3ExoBkSRBwPlvTx5yVm3Got1S6\nQSvLiStpY84Fs6GSiRNbs9BLnFAs1JZbfo+1m3BohpO4TZKlKwlIzmc7NnZOGzT72TE9Bhs0t/o5\nUDJTjj5odjFg43DiVVLj7TkCHseO3itszAgxp1YZR1nBbqe4E7d3iwAyDsyEhZ7hoqf3nm0f8SlQ\nqoy9ouRyMWHbOW43W3rvqVTOflVxUFfsOkdwjnXvab0j15qDquTiZMambbm2XeFCQuvB+83zoaG6\ns4GnmjWNtSit2KsL9rKS037LzXaHY8gATzIzHKPAB8+dfoOPDmUSk0yyyCc04Ywm7ZAi0DtBoaA2\njroAGRNt2KGlJZMRJSQTbdBixTTbsu5L1jan0paFaTlXBjbW4JKjVh191LhQYGSk0pZD0/FEP2Vn\nNaX07GcNhT5PH3YEOiSJ63ZOCBVGOS4WPUJc5n9tLZIWg6I2U0r1INe7brgy11NutRERSyKC/+vw\nPmb6Ao9snh5igJRcqY7IlCGFxMXplOQFT23WVEKRK8NX3HeZ5BOts0yKfIgnKsWiKr+Q/d1RdD+X\npJTuRmwCq3boZbu5e5V7NKnYWsu/3r7Dqu8oswwpoCwylm2L845rzYpj27DueqZlxiRTEBPXumNO\n3GAjFEoRk2KiFbnpWPs1O7fFIllkHbUCnyYU8hStBr9w4wo2sWbfdBzoll0cxDxKeCi/g0s5N/t9\nLpolSgaM9Fx3U3Yh53K2pBCBO26CkZGYEhfMkhv9HmdhxoVsTSTSRsNtWxORXMx2bPoZ2ygp5RYv\nNCEpTro9EiUT2bHzil2SdEGhhGSuDMIfsA0dHS3eaUgZeSzJTc5ESDZdYul7fISZzjhfTJirCcuu\npfUW74at2Kd320xOlWG9bbnVtyQkU204mk6ptOK4aXG9GzqzCZhow6TKmRnDSdNxc7chCIlRgkVZ\nkEnJKlgabznrW6wLFLlmWuUYKTltt5zGHSkBKjDLcjIFTll88uz6FkGiyCOFEUy1wcc1TnRoEem9\nohSSiWmpcjd0H3OQC4tRkWkW0SlHpzPqvGdpK5ZdQaU8c9Nxqd5x3Ne0UVKpDh8kUpphM6NkOcwb\nrnZ7rF2OTpG9vOew2GfZ7+ijxwvJ2lYosRhKpTPPzFzgn1Z2aLkpNItixoPVQzy+3dKHhkk2Z9NH\nJqrCB3j+4oAr1Xk+cnaMDZ5K
5tw/3eNcNcO7RJ1JFqZi1fdM8hwlBRdnU2Zl8bk+dT8TRtH9XBFT\nGkphg2fd9RglOahqtn3P9dUaGwN99PR9oBeemCIznbF2lv+1vMmdbtgYMjeaicrocJzaJbe7JW2w\ndDExM4r9UtHHxJm7haNnlu0Q0eBSwVQHSn0Gsqf10KaMB4szKhk4cTMO1ClSJS6YFce+4sTNOZdt\nWMgd21ixiQUuwZeWd/BR8aRdcElviUJQyJ6n7IxVKLiSnZHhueVm9FGjVOA+s+S2nXPsphyaJQ7B\nLubc7qckNBezHet+wjJIpAggBALJrt/HpwwjWrZO0USJ95pMZuxnBuEmnNmONlqIikLm1FSUQiMT\nbDvPNnmIsJcVnC+n5FGzsh1b2yOjRCvJQTlE1kSEZdOz8R0yKqrccFRPkclz3HU0naWLEalgLy8w\nhSLFyKqznLQ7IoEsz9kvK7QU3OrXbEKP844YhzLbohL45Ng5Rxt7hAwYDbUyZCrh5BaBpXUCmRKl\nSdSlpxQSHxuiCCgZCEGRC0mpO6ZFTwqB467ECEeuYO9usgE2FNqxcRVLW1CQqEzH/dWam3bK0mXk\nwoKQTLXAhWwo6sgabtopm74gCUmlIvdNFpy0HZvQ0ZMTvGGencfFISJ4lC341/Vw16BEzrmy4v85\n+hI+dLa826ynIkTFBTOFKDmaVrxgfo6zpsHFxKV6Sp1nnJvUaKWIJPar6nN9+n4mjKL7TNO64Wpp\n2/cIKZlkGauuY912TPJ8CPqv1qxdz16eU2iDUZJHN6fc2W1Y9x1ZpvHeU+cZSSWOmy1PNLdpk8NI\nz7zI0WQ4duzSGYkG7xMOwyLzLDLPLkRS2pJE4sHqDj5qztycA7WlznfkyXHqa9ax4MvKW0xkz3W7\nT6k8McHziyV3QsF1O+Oi2VCJnm2ccNtXuCj4b/UtQoTH7YLzuqFPipnaca2fchKmXMlPIEVuuj02\nLsMoz6V8w4mdcmwnzLMNNmiakHFmJ8RkOMwsK5uz9Bofh50ktDAIv0fnBUJ0NE7TB0WKkloVLDKD\nazVL22NTGIo+dMmhnoBPdMnT24ANCYPgoCiZFvkQMWsbrA9oMfTfPawqQND2ltZZNs6h5NDSclYU\n9D5wq9nQ+p4kFEYqZnmOMLDzLcvW0YQeIwR5rimLDEic2g1tCkBCpMTUGEzmcMphncNFj5GJzISh\nMEREgtwilaVpc0SKlBqmVUcpHa1PRECREDFSKIOWljJr0CLy1G5ORsSYyFGxQ6ahfaVQnjYUdC5D\noTGq43y5ZWUn3OoLjAhEJEcF9GHGJrRU2rKyFa0riNEgteByPWPt4KTbEFMByfBgdQFLhvWWc/mU\nO41j0zlKShZlxf93+XlcXa042TVcqicYZXj+4pBplqOU4qAuydTQZW0U3ZF/l5gSIUZa51j1HUYq\njnc7pIDDesKqbfmX28d0wVEZBSgmuWHtWvre89R2zalvWHUdB2VJlWfE6Hlsc8qJ39KllkIMfQZm\nxpBUy8oPlU0eyX6xY6IlPpVosUaIDYtsR+s061hxMd+yb1rWTmKTJCJ5uHqKIBVPtQuOzJZMeGoF\nN1zNMpS8uDxmploe6w8ISSFIvKg84yxonrQz7jMbpIg0MedaP8UneGl9kxDh0W5BrTq8kCz0jhvd\njOMw50K2JETBTTtl7QuUSJzPd6z6KceuotQdfZDYaNi5ghQqKuNorGblJD4IlJAU2pDHKbs+0kdL\nCAKSIUNTq4xa5/jOs+wskcDEFMxMyb4psNazi47ODpVnudEcFgUJifNDTrp3HmMMs7JgP8vZWc/S\nNViX6IKj1JrZpECqoZ/und0OGz1CCLLMMDMFPY6zuMW5SJ+gkJEil5gcXIAuNtgUUCKik6TOcqTe\nQdbhe3AuYRRoHZjmAkkkihajLLumQAC5gWnRMDUtZ12BCxIlBVo4Jpq7u1tYtLJc3+0hQkJKwW
HV\nUQrB0iUCw5d1ChojKjwd87yhDQW32wKiJAnJhUqhxD432g2IgPUZKRYYJkTgQlXgY8n17QqTFJqM\nlx5dIaPkrGtYFBVZUjgfOF/NmRjDSy9ewvvE1vZMipxMKi7vzSi/sAspRtH9r8aGwHGzxYfEnd2O\ng6pimucs25an1muIidY7XEzY6MilxijF2nX875PbbKzFaEFlMkqpsERu9CuOmzWNd/QpMM8081rR\nhf+fvTdrkivLrjO/vc9wB3ePCUBmVmVNLJESu0RRTdNT//+XNqPpocVuiSzVnFlIDIEYfLjTGXY/\nnChRtO6mmUS1aFWVxwwWsPAIOIDwu3zdtddeq/DN9JHFNsY4owjQs/eF4I/kurLUylYD/2J/z6DG\nUxrp5ILXzI+GB46l43674nvxmRt/5jGPPJUdBeV/238DUvn58opbdwEcdy7zNo3c55GfDI/sdebn\n2y3PqSNo5a929xyL8ov5ijfhiCFM1fP1ekXC8a/HdxQTfj6/whCgchtmHvKOj+s1N/FCqsKn7cBp\n60CE225jWgcelg51mVyVag7LA7UOOElMyVi2AAZeAzvnGa3jdEnMZKQogzp2LjK6DifKtm1Ma6VS\nue57DqGnU8+8rMxlI7cUSq76vuU21MJp25iWBBR23UDfO7xrVT8P27lFR4qjd45d35PcxjGvnJaV\nLJXooAuO4AJJFp5SBivUXBhCR9+B6zZyycxpQzUTFFSVQRzeTUi3sa6QN49XIYbE1ZiopTUnB7cx\nLW3RwIuy7y7cDBc+TAeWpC9h7onbaCzVYbLhXeF+2iNFMDxX48qVNz4uylaEZIoXx97dcrEZ71aq\nRR4XDyVQcXy+G3gV7/jp8zMlV1QjHYHP4i3TVrgeAzf+mo+nM9ehZ4gdP7n7jB9f3/G0LETned3v\niM5zs+vxIgwxcui7f74L+p9+/j9B99tQzX/CaWlfK2suPMwTh9gsQ3F13E8Ta858vJx5+/zMZpm7\nbkfwjn0c+M3xmfv5wnFbCKL0Gng1dGQ13p2O/Pz0wFYyLho3Q090ypxXfn3+RKozazUwR6Tn0MNl\nSzxtMzVVvr+b6FHm7FjyDgln9u6ZU+mYtpHvdCdu3czkIk8lcCo3/C+7j9zkhbfrLfd5YCcLO1Z+\ntdy1IdjhPX/aPZHNOJbKQ9rzg3jkPQPv0p7HFBlkRqTyd/NrghT+cv+Oa7fw8+WWxxSpCNdh5t1y\nYLaR67By5WY+suPjssNJ5TpMdFL4tO6Zk7KRQZU5eSAy+spK4byuFK2IOZwvjDpgizClwnO+YKU1\nP+ycY9f1aDY+nS+slvFV2fc9B+noxDPllfvlSEktzvC2G9mPnpQr95eJraxINaKLHMYBZ8ZlKxzL\niW3OiBeu4kAcHKkUnqYTR1kQE6IId35AvWfSC89lYdtaTVDvO/oRoLalGEtorlAFp3tiLEi3kurG\naXboEvGqBFcJ4nHSJIZ1VZZ1pHN7QljYDYmclaet45KVqXR4BMkek0ph5mnrWLYBE+h94m4oHJPj\nsjqWXHhaO1wJZAIWN5Az0+pYck8xR6fKF8MdD9vMw7zwtNyzbkK0oVUz7Tu+s9vxd+sTXz+fOHVw\n4yL/5s0XHNfE+8uJqJ6ajS+veopVolfGEGjhmX+4fO9b0P3vOL+TER6XmfO2EVS5n6eWv6pKrpW/\n+/Aep47olEM30MdWTvg4XfiwzDxME3PaeHPYMXatx+uXT488polzXeh9oPee6y6yyMz76cz9cqJg\n7Afjrg+IBbYy83460rsFp0atkVRGdj6xsfGwKi7t+LPDB3Zh42GrPGwjF/Vcu5mHbeCpDFy6Z651\n5pOO/HreEXTkL8cPXIeZr9YbPmyRToRBNn4131LwfBlnvh+OHIvnm02pduAH3QMHP/D1cs1j6ols\nlCr8PH2Gp/Kn4wd+NGZ+Md1xv41UU3Z+45yEUxroXMVrJl
W4pAFRGEKl08rTbBwNqmXECVoVqQEv\nxpxW5s1RKngcnRd26imb8XCaSC9abifKIfRE8Sw183SeMBxeWjVQ1MbC3x0ntm0jmCIqXO92xOC5\nlMRlXki54L1wO+wJQdBKC1i3hbRmnHPcDXsktHSxOS0stSJVGJyj9xCccqwbay3YljHx9NETRsNy\nZq6FfAFvCgW86whUiBO5bhyniKse9UoXK5FKRTllJa2OLQXU9oSwsT8sbKnyuA6c1shmnk4ryojV\nM1PdOCXHsvWA4rVyNyZOyTMlIRW4JEe0DqmRTOYxHzlvjnlzYMrgHT+6uuXTsvJxunDJBUnwnf6W\n67jDOUG8UhfjkjLPl43DGPjy+orTsnFcllYI6h2f7/f/3Jf5/2/nW3nhv/Fcto3305lcCh+mCz+8\nvqVzjrfHI+9OJ7w6Lmkl5QIIuxixWplz5mePj8zbinrhOvYUA5HKh23i0/nMwzyDg+s+cDOOnNLK\nV8dPPJUJDYVO20BpH5SFM5eysOSMmfHFYab3jjU7sq102jQ5q5WpDHy3PzH4mePWc8odnsJfXL3F\nq/HNckUvBecKt37iYRt4znv+7f4DBzfxt9MbHtPA4Cr/dvcNyYRfb3fcuBURYynw9XpNIvDv9l/R\n6cbfXD5nM4+a8UX/zFI73i5X3MUJscJv1zue0ohQ+Xy8IAZfTVdUlFqVIkLNynkb8L6Awad5oBYH\nCOoUX4Vp8WwVajXUPM48sXjAWpXNIihKRDj0nqiedUmUXMlmjBqIockS6pQ5bUzLimJchYFdP1As\nk2tlLoWSM+od0eDQD1QHx5qYLzMmlU49+2EkW0tdO2mrL5IKonAdOojGU02kXLFScSK4l804EFZ/\noVjGVrCqHGLEDzPVNlJp9fUKUI0gAa+FMFwwFS5nDyZ4D11ndFSWYkhI5BTYkiNqQFxmP87U6nie\nPKqQzdE5w0tHZiaGzCVF1jXg6EEK12NhywPHdXvpqVN2bqSToQ3cek9ZoFblEAe8KX/+6nOcBN5f\nTtzEHbsu8he3n/G633FJiTfjyBgiZsJV36EqfH7YE3+/2y2+1XT/KSfXynFdWEricV646XoM+NXz\nI53z7EPk7fHErx8fCc4xBMfgI513vDue+DBdmNJGJ56lFr53fUWi8vb5mV88PzYfr1S+2O8wjCVn\n3i7PHMvClleCBnYj7ELgebnwlDYyM33MDB6ont5ncBeKFVIRqPD9qyPRVc4psBUYdeP1eEGt8JxG\nvhufiS7xtI18Sjsihf/15muiVH613KAmbQgYzpxL4DEd+PPhnp2f+b8un/M+7dlJ4i+u3qLAr5cb\nvBRUhK0qD3lgyZE/278nSOY/Xb7gnEciietuRaXyfjoQfbOK3c87phwo5hi7jJrxMA8sRRFr228O\nWLdIqopR2OYALUsMT0QVtrlSNodWcN4xqBAKrLVl5lqGIE1jvQoRE2PZEjU3L8Dee9S7tqwhMKXE\nmheitHr0seu4sDGta6siohKdZ4gOlXanc5RM3loN0OAdY+hZXGLOma1WSjWcgSoMXkiuMmmGXLGt\nEvqIN5BYoWayW6hSKUtj8qMLaDdhvrBmwZLhnWAVorafm3bNA3yaIqA4dXTdxuAKp+Rp8C2k7e/D\nkYahZXg8XwIijlQcu2AMrudYVqBQqqckTycjicqurzg6nudElIATeDPu+f74hq9Oj6hXdm7gJnT8\neP+KbJXrYeC74zWndeVNP7KPHV9c7bkZW7hTHwJjDP98F/0//XwLuv8953e5CL89H0m1VeO8PT3z\no+s7Bhf4xcMDf/vpA7fDjiVtvB4GvI/UnPnmcmFeFh6WhVyNz/Y79jFynGc+zBNP88Qlb3QhIlK5\n6QcuJfFuOvLN+YRR0d54HQcqlSUn7vOJJBksEcVxGIzghPO2sOREdIXdsDH6QqkBbxveryBQsmCm\n/ODqgU4Tn5aRJUcGt/
H5eCJK4nHd8Vl3wUnmMe24X0e8GH95/Vt6zfxqvmW1iAPuwoliwsN24Mv+\ngV4S/3n+jPttT+cSP97d02nh6/mGZAKmFNrSw5w7Xg0XOk388vyKU+pwGF2sdJp5nEc2c6gYpzVS\nsqOap6jhpDJfAltxrWrdHF6gJKibggqsAhLocHhgqxWyoEVxtGn/EDylNAtZrYZWoReIock6G5W1\nVmoqdL6tWnuF4mDFSMXI60bwcL3b4Uw4kjitM2ZG0FYIGTzgHKlmLra1rjgxeqfsYuQoG1uBUjIm\n4GtrZvBqWKwsumAbUIzOBZwv6EuQ0GaJ4Gtj8y/bXS5MEBPr5im5Zd1aFvpYMHOYW5FgTOcA4nHi\nCH5l6DLHLVArLTgpO3YaWKngN4JTTpOg1mEI3gmv+4EP04ZJxmtAi+fL3TXnlNFQed3tmZfMzvfs\nfOTVuOcnbz7n60/PrCXz2XjFq2Hg33z2BXNOFIM3uwGvjle78Vum+8d0zIwP04WHZeKSNpaS+NHV\nHcWMnz18ZFoTow/cnydEIHglqMcM1i3x9fMzly0Rg+duGNhyZh8j35xP3E8z99PE6B19H3k9Djwt\nM2+PRz6sE8UVeqf0vmPXK0/Lhfdzi0XEMrGH6y5Qa+WcF6qsIIXoMornMCREKuuW2bKn9xtXu5Vr\nvzClgKfg1DABamMt398/0rmND9MVz6lndBufj2cOuvBpGzm4BtzH3POURzD4V1fvGV3il+dbLi+r\npoe4Esl82gZu4oyTwtfTHY9pxFvl9e7Czq18mPdcSkREyKaIGXPuUFcJUrmfdiypZRiYh6iVaXZM\nOeAUtlVBFC1CLi2JyjaoW2swUHPtJV9o1b0VXJbWKvxyHacMmCHJmv4bFJy23jkzKEIUGJxHvMOJ\nsFlls4JUJapw4wPJwWLGWlpuQ86JXh37vsdEONnCUhOWK94pvXhMKxsNOFcBV4Fa6VQIXjk5o1Ko\n2VqdT1bUGYpQu40sG5aNukGnPc4niO1zazWcN2xV1IToAhIWXLcxbZGyOoJzVCtE3xLmittwrnCZ\nAmoBJwGJC31XmJYGxKYNxPexrXNXZoJX5k3prSN6T5XMl7sDT8tKKpXRdxxix1+9+R6XtXDcJu76\nA73zfP9ww13XowhfXl+z5VbVedP3RO/58vrq25SxP4az5MwlbVy2ladt4aYbuKSNXzw98OX+ABV+\n8fDI18cjn+92zCXz3f0BqcbzuvLrT48YIE6oxfiTV7eUYvzi4z0fLhOIkSj84PqGUo1UMr85Hjnl\nmeO6sYuBcYjso+fddOZxmnjOC+Kh62DEI75wKYVLnrGacALdTth7I6VKtg20Aes+LtQS2PcbVitb\nFpbs6HziZti4iTOXzZMrKILzRrTMWjzf3Z0Ikng73fC0DYx+47PdhRu98JQHzAxEuJSeJXtqFb6z\ne2YfVn5zuuUp74iS6bvEXlee1wFxTYv8tIxcXtht7DODSxynjmPqwYQsghejFMdiimAskycXjwIV\nh0mlZiHPDicgBdCmIVvSZkvLoMXhTXG+yRFY+1oxac+vinVC3ipOwaogGfa90gX3YvFrwzsvincR\nj4A3nOhLpVEhiCOoso+BIsKxbljJL2WaxhCablpFONtKwZpf1mDwkWwbczUQYxNwxdBacc7RdY6L\nbqRSIBuqSpRIlSZHZE1ol7EMZRVC6FES2q9YMZbk8BHqJrgKXhwlZLTb2LZA2ZQgkWSV0BWogc0S\n6iAlh1aHcz2VmW4wllUoqXXClVo5dJ4gkWNa6b0i1fNZf+C2G3lcV17tOzQ5nHheDwNj7PjJqzdE\n9TxOCzf9yHUf+WJ/4DB0pFK56rvf98bib0H3HztmzUP7y+MjgvBxupBK5k9vX7OWwv/x7i0P08xt\n7PmUF/706hZ7GWh89fRMLZX784UheF7vD4zB8/F0odTKu/OZlAvBtRSqfQjkWvj6eOKb
5xOmRggt\n3MMU1rTy1emZU95YauYqeIYx0jvl/XTivGWybjgHQ18Y6Knadv+NDQW8Gv1oDFrYspFXcLGgznjV\nz2zJMcS2nrqkwFaE3hUOu4U7f+G4dkw1vAx4jL0uLNlzO8x4Kt8sB57Xkc4nboeFV+HM09oxl0AV\nZa0eLZCq4zAs9C5xf9nzsO0IUtBgXPmVy+ZZJWIVli2yZkUQTEB9Ja3KZelo9n4FB2QhFUetgm4G\n+e/ZkInDklAzxOoQVRwGTqiJtmK8tbxW5+xleASigiYjGIiHIEr1kCs4M1SUaELsO1QhGWwlUUsl\nijB2ESlKDcZSClmMmitdcI3des9WEqeSqSnhvMeL0Pnmxy65smqh5opIRRBGDSxS2KRChaoVVwUr\nhnqHN1j6jVIqrAXnIlEDhZVqhURFfUt0K7PQSURDpYZ217KuivcCWUCMoJGl5saYa6BuQqcDqyRC\nrFhVlmwMEbbi8ea47rrWHB1qu9NLwmfDFdlqW6PuDny4XBhdZHCBP7m55cfXr/nF4wNRPdd9x+fD\nni+vrrBqDF3kdmhbaLe74fe97udb0P1/O6kW3p6PnNPGh+nCm2Hkqut5Xhf+9tNHrkNkS5VvphO9\n+OYhFMPjWJbEp+nMOEDzmgAAIABJREFUeUuMoWPsPCVXXo873p+PfHi+8LzO7LqeMQS+e7jifr7w\n6TLzzfMzzglJjDfjyBADp2Xl548fuaSNpJWxi9yOPakUPs0njjmz1oQp7KNn7KGWwlOaXgzpBR+V\nfZ9w5ltSl1WsFrxUOjVCXxl9YlmVeXOEUHHeeNWfyCngfSahrGsgV6WTwn63ch1mjlvHKUUERbVy\n282sW6DzCQQe555T6gmuMnQbd3HiskWeU0clUA0EgyqN8WplWgPP64hioNCHTM3CeQtkFJKy1eZ1\n3UwoZlAUW9rFKAaIUCpIhlIc3gSXaQBtTcfGGrh0Kk0HVl7WM9r3OdrWFyqYFARtX1ShDwpSCQir\nGtWk1f9UpRPFqVK8YerZ8kaFtowRI5gx1cyUE6KKlULU1tAcnGezwoWErOXFLxsJKsx1owDppflZ\nzXAe+qIsrrB4GvLTmpDrZpgDqUodVqpV1ksluoFRPYssVMlsBXwwNCtpNTrft39mmDEnbAs48ahz\n5FqJvi3iqMuIenIy9i6SATz0Ksxbs3hFlOgDP9zd8H45vSTo9exD4M+vvmhNIwZv+hExuO0HDn7g\nauz44fUtny4XVIR913rlvndz/T8dD/4Hn29B978+l7SRa+XddKLU1hb786dP5Fr5cn/FcV759++/\nJufCq92ec0r8q+tblrVyTDM/vf/EVYictsRnu4HvHW5YUuJv7z+yrQkxZRPjBzfXGEYulbdPJ87r\nzNO88OowMHQ9u+j45unMMS18PJ+QoKgKN8MI2jrRvj49s1YQV9stb9dsRI/LzFYS5jLqhJ1TQudI\nJbGlBSuCDxkXlKshYVlJlsjVI1YJWhh8RYPRu41586w5IGKEWHkznEmpaadz1cYuixC10A+FvZ+4\nzD1PW/PSOmdcxZmcHdUZ1ZRpjcwp4qTgAuzjyrYJz+tAxoOB8wXJSjIoopSkLItHBUyE6goUJW0K\n2TWkNEWstv/bKlAV3QRXBXUNqAoOMUWzIqZEKohQnZCK4K2BbRRDQmPS2QS1iq+N3XbRUTAIYNXa\nMM4r0UWc1KZFlwKqOFE67xgR8KGxRmnV5lKNPnTsukBOhac8k7UixaFW6MRT/MuenjSgFmtSTDSP\nV+MilWrN5iYCUoFqdOpYXCWrQTbEIDpPzjRLIlB8G56VWZES6VVJbiW5RE6gHnrXMW0thL0C6IoF\noawOb47gha3CGGHLgEL0ihR40x0oCqdt4mboyUlakBAeBP7k7vVLYl7hTb/nuov81Xe+x5ILx3Xl\n87FJDq+GgV3fkUvl9X6H/1bT/cM47y4nPkxnDPjF
8wN/dvOKQ+h4fznz799/TaceJ8KcEj863IHB\nMS18Ok/kXHk4zby5Gng17PEOPhwv9K7JCXPeEHFtAusdrhpfPx15miZWKxziyNArvWvWsbdPRz5e\nzi3/tYvcjiNd5/jNpwfeT0dWDC+C74U3/Z6prDzOC5eUwQwXC0PniVGY18xam3brXAPTg6sUVbKl\npvdlI8aKC3DoMzk1h8ZmEbVMDJVeMy5URIx19cxbh9OCi5VX49RAsTpWC5QCghCo+JjpfWZaPcd1\nbOHkrrALTcYo5tgQ0hJYq8dTMS84n7EknOdIEcVZG/I5jK04cmlNvGwOMASjCFAdrJVaBIfgq4Ez\nigmlCq4EtAihbliEoorVgJjSZVDfKn+qglGp0ny/3krTdrvGinN2BAed0Mz9QKpGFWk/A3MMITbf\nrGtMeM0FU4hOCARGbeC9kClmDTitPT6EQM3wlC8kMVQMxTNWJUuhBMOKMWlFMyiVXgIuG88ksraL\n1Isiuf3eiZDU2LSijTLThciWK1kgqLLJigYoyaEFYhRWKlWbC9iAUZTZDC+gL6A7eNpHc4yxYy0b\n+95Ts1CrsI8dwTm+t9sTfc/b8xMH19M5z7+8e8Uh9jxOM3fDSNTW9PF6GAjO8XrcsR8a6L7ajX+w\noPt7rVT/t561ZD5MZ65i2+nex8gvnj/Ra+RhnlhyRsWxHzoEx5Q2ns4Tz2ljXlZe7/dc7QNXPjKG\nwON54t3zkVyM637gZhz43vUtp7n1mT1fZrxTBOXPXt8Rg+c0rfzn9x9Yc+FSNt7sR3Zjj1Z4N5/4\n6nHmcT2DKK/Gtn9eMN6fnplWo4g1rbATRiJrNY7nhVw2MCH0RnCeLginS3mphPEtBDsavWtD/adL\nu1BKUYYXRtyFwrp5tlXYLBIohC7Ta7sYn5aOdQkspbEvF4yrbqVmOC+BB0aktosvagFgKY6UHJcl\nUlHEWZMkgFQ8l6VDq5CLw0lb/ixOSFUpq4PSRmdobXJBVWxrK75aBN+2BKgKuQhShJgBTairJAeI\nYqniqxKkIrpSDYpGMIeY0FsCryDWfKirItrcFFKF7JSE/Refr0cZoiM4JVMpwJYbaHrvGFA67Vg1\nM5fyUk1e2YXQapUMNqsc1w0n4CQwmNA5j6pRRDivF2oytLaA97Eqi2TWUqi1QlB6FHJmjI4MXKSw\nmlENuuJxpclYNUE1gd7IVGRtzzeZUVxlVSOl5tZYX1j2f3GA+PZmIgUGF6Em1AlOhGqQU7Pq7WLH\nd8YD98uJb84XDiFz5Qe+v7sl18paMocXHrfkTIiOm6HnZhx5fzlTLieOeeOz8fee5f6j548KdOtL\n06iIUKpxHTr+5v4d+5chwj5Efri7Zc2Z3z4/8/NlJuB5Whb+6osvuA0DD8vMf/zwjoM7kkvm0A+8\nGQaGEHlaZ35zf8/DeWbaCm8OI1djD2I8HSd+u8w8ritOhcPY86PdLahyWWd++njPNK+Yd3yxv2Lo\nI6qOr073XC6J1cA5GAfleug5bwsPl0LJCQtG6Dydb46ApSysz/llf921zSTfrEpPk2K0avLohTg0\nvTKZ8HDqsOIoBYYuo74N5abVkXJkrYHgKj5WOiomlec1si6eLXuCq2io9DGTspAWz1KaWb4AnVRQ\nY6uOlJS8eqo1jVi0AZflFymhKFIElMZIRaAIkgxNoK4i3shFGtstEKuBGpiBVIoolh2SHNEKqhOm\nkBGqKa4Y/gVFzUN5GdipgFqlc/Zyp0DbrHCCD4IXEFfItbJVyCJorQzSEUMgOk8RY8ory1JwTtl1\nHaFWHI7N4GKwbhuYEWLHPjr81qx+Rweqlc4FdhawIPCi0W+AiRJUCFkYLHCx9gYrDiwIQxXW1Kre\nFVikKTOWhT43Z8iEsVAaEHtAGtBqU1PYaJpyNXC1LX0EB0Eda4ViLYoyquc27pi3hEnlkhKlCIfY\nIaJ03nM9dHxz
OvFxWigFboeBv3j9OXPJfLpccCLsQ+QQ48tCxB/cDfY/OH9U8kKplf/06QPfXE5U\nMb4+PXMXB65CT66V35yfGMTz4fnM47bw+bDjuhs555WUM6E6HqYzp3Wj94HvXh3wKogoXz88Ma0F\nJHHXHwDj88MVx2Xm0+nMV8cjvQiuj3x5e0BFePd45pdPn0iloAEGF3h9tWOpmbfHZx7O53bBe7jr\ne2IXOE0TzzmTc6WWSogePwjRCXPe2JaMpQQ9xB4cDjVrQ7jcGGM12A3gvVK1MM+C1AoKnTd8Z7hS\nyQo5e6xCKUIXG+CKZtbNk7OSzRO0ga2nUl+sXlvylKKoNKnDhUopSlqFrTicNugPAiLGVpWSpckI\n8kJsXdMv69oAVytABWsgLvwOcNuwSaQgAlaUiiBZCeRmhK1KNcG84mtoOrCtbeAmShEFiQRJqAk4\nBQTTNnkL2p4jBkeRtnbM1mxzXfAEp7S/mDVDhQleHQG4is2ZkquxlZY054IQUAbnkKQkX5lrppa2\nGjx0Hd4pmitzyVzqhjUizqCRlDaytOCczaC6iqJYMfYucN4SkwEC4mDnhctixNhY9EKFYNQidIB4\nz6VsrQrK2o5J7x3LWhmCx8SxlZV951ursgQ6DaCVL/YHLmtiKold7Nl5z59c3XLJieOSuOkjwZTv\nX9/g1GNW+GJ/Ta6Z3gVux4EhBL53fU3nPWvOfL7f/z5X9cC38kI7Ttu0WVTwCHvfUpuqGed15dP5\nQk2V22Fg08D3bq6wBFva+LsPD1zFiHPK9++uuetHTsvKLx8eWVJBTCi18JPvfs7oI799PvLXv/wN\n6oU1JX70+oax63AGv35+Yt0S35yOdOK5uR647XsWK3x6vvDV9EytmT5GhtgxRmWpxsfTkXWplCr4\nXtnvPc4Lp2PhuaxUMUQMFzu6qJgJW54p6WUIFB2dd6j3rGlj2YxaFFGo5jn4ijhIW+WYIq7UFkAS\nKl3XGOSyNdZbm3pA9G1LrBpMKZJWoTrFiSHeEG235PPsSdVhSVExTAqq1lZYq4ekjWZJ+1WdQJJ2\nW1sFDEwyAAUHGcSMkK2BqhlVXHull/bHeE0vK9bNjeBQdDOcm6iimFdqbkwvmDbGa5XqpUkSuYG/\nD4ajsb5LyogpTqGLzeLm1ailUFUoUnBJ6PqOTgNqpa0LW6F4iOLY+cBeI3hpj9VEWirRCyFEuuiQ\nIpzyxlo3RD1d9QTxzWucGtimImgAzcaNjmxrZo2FLWcyEL1QK3QVfA2IW5lcRUoL0YniubCRRZBa\n8AoBoUpzQgQLFElghtWNqIFBOiqpBTKp45gTz9tKKcarfs916LmUheO6gLVwn84Feue56Tsqwofz\nxv3lRO8jP757hQLHbePD+Ux0jje//4D7j54/Pqb7+IGr0LqXnpaJ//2br3EmnLeNOSd+cv0ZQ/T8\n6vjIN88nuqp8XCa+f3Vg9D27LvLxfMIJfH1/pGrbonm937OmjbVknueV52VBaiFo5O6w4xAc76eJ\ntw+PnFNi8IG7/Yh3ns47fvP0yIflwrwtRHXsho7PD3s+zTPfPD+yrrVNrXvhZhgQqTxNG3OqpJQR\nsdZSsBNShmXZKDljL9qrOofz2jad0kZNgtWCdkLfGRI865qb/aoKJoY6odOKacUUpjUgpUk00Rfo\nDCtCKVBqs3FVe5FGXRuIbcVhm1Ck3bZaqFhtskFJ1gC3CDhDqoFrTJXSNAmhvZFY0dYzZkbD3ta+\nILwAe36RH8zwtVKCoaVSXGgAXIWaDNdtmAniAlYVMY/LGR8zyRQrDpPYmKi25yu1It5h4nC1trxa\nr0RVtmIUA2hrxEPwGErvXLu1x8il0Imj04j3CrViVllKxdToYsRXGF3EpDCVSsVYtkzvFVMYVdnm\nwpHMkhecd0SNRBMSCanKbIW5ZHzLzeHK9Zxz4WILGjypFEInpNK8v10MTDVRfE
Vx1FS5jgMXa4Pa\nPgSWnNk7papSK1zHnkta29zAKasZn3cDl5LZh47r2PMwT1x1jaB8Z7ziy8MVvz2fXoo1I3fjwJdX\nN5SaUVFux5GcM/uhJ6jjs/2Ou+H3ujUCvmW67agIThwPy8RaCqdtpXOOuzjwncOBh2kmRMfzeWab\nMlqh7wI/6A5cdyPe4MPjkd+ejqgKb653rLXwo9tbztPG4zLx9unEoevoQuTHr2+pOfP+dOGn90c8\nxmaVV/sdb26uCQj/8cN7ni8zc1kYu5E///wLxITzsvLT9/dsZSWlSheV/TAQY+BpuTCdCnNJeOfZ\n7Ty9UzaMy7my1oxkQdUTg9JFJVllOW+IWYsy7A1xPfBie1ozdX25f42O3lcKwpSFnPXldrktEjgV\nsgl1FbbUFhLEK9EX1FdyVdIimLx4ZM0QUbJUtELO0twHL+Arrrx8mbbchPq7n5hgWPMsvHxOrTFh\noWIiWHFIri9tCO0NIhdt9jATXG7WMpOKRMXEUdUjWQi1UKkUX8jFY85wri1UqLRtsg2HqoMsBKmE\nAMUqKRlZClUcQQ1XHfLiaMPBJa0tF9cHbvodJSW8CtTMilJLMxMffEeoQlXlvGaqZdaS2e969jEy\nFCFROZfSUno249qPLfIyBqZpZU6Z4h3OCgcX6XBc8sy5rixWqAFEK7Eqo3acSCy6YWSKwRUdWYXk\nK+VliOb1ZVhZ24bb+rL9GEQQBOcdinCFctvtqMvEkhNBHIeu40dXtzwsM8dlRTlTS+XN1e5Fb2+J\na8d5w2vLmni1G/nu/gAi5FL/nxfvH9D5o2K6AG9Pz/z1+6/xzvFpvjD6wA93t+RS+Q8f3/E8z+xc\n5Lyt/Os3X7BTzylt/Ie3bxm95zRvdJ3nh/sbdsHxt48PPM8rqTRf7b94ffcygHJ8OF6Y0sLH44Xo\nA4dx4NWu4zJv3M8zH45naq2Yh+shsu8G1OBnj584TWeqGWM/cLMf6L3j4XLh4TiRqJgJvVfGMeB9\n5P58ZlsKRSqqwjg6ojqWmlnmQpbUbpVFcUPAibVG2smQUhEVpKsvxnijWsVKW2YwBLwQXSVhFFqI\nCtYuHnHtY2lUlJKbBxcRRNvtv5m8OBG0XXjw4m9qrKs9ZmC/+/giSCJYac6FxnppBtQCUhsgUwv6\n8phJmxqpFcSaHGFBXyQHeQHUimWhqgFKFY9XjxRQl6gIpQjqPDiPV3DVqO39A3OKWhu49SJEry3m\nYauINg9XHAIOobfW8mBO2XKhlsLgA7u+b8sipTVFTymjobH1XhzeQa6FtRhbLZTS7IFOhJg9c1mZ\npOBM2aywD4GUK97BlI0nmSEoXW6WuhKUKSecV0opmDdCjGxbYqTHfGWxRMSxUunM0YfAXDKj8yQV\nrGSuYseSCzd93xLKcuL1uOOyFcboGVxARfnR9RUPy9osZbHj0PX85O4N7+YLp2VhP/R06vjJm8/I\nL9kVr8c9YwjcjSM3/fA/GRn+h59vfbq/Oz97+kS1xp+mbeOnT/doFR7PM8d14c1w4LYfOKeFaVsh\nw/OyUErlpttxs++YU2LUwDfPTxznlaVuXHcjIQbeDCNvzyfeH088zgu9KFdjT+wdu9jz4enI1+cj\n25aIMbALkS+ur3m4nPjq+MTTNONE0KB8NuwJ0fO8rrw7PTfdsMK+j9weRs7LwnFuTDiVSghCpx6N\nUKjMayVNGQkvea2dQ4IjLYlUK1YyIuC8oiEgUthI2FpfbvnBR0E8pCxtyCPWAPR3wyZtEgbVUWsD\nTIGGrQIgWH0Jk4UXMG7LHpQXk+nvXp/1vwLj8pIEY03qIAvyMkSjOLDcQBhrQGsgm+AojW1KQXBt\nbVbb465UvBaKOUpQqI2tOauIFGpRDNe2u3CoRCTlpouItOGdaQNh1zblSm3M0NShqrjawsl3PiIO\n5i1jpaLiGfuAAr0LRHNsaiw5YSmjqrwed2
Qr5KX5b88lE6Wtz3Xq8KIs28ZCYckVVNn5FmwDyiUn\nZjZEFHFAcKSt0HtlpjJLRhQUZZcD2sNjSrT3JCOI4zYM3KepeYGdkGvhphs45ZUgnn0MzGviOvaI\na00dt2PHVipfdCPeO+7XmUOIRFG+2F+xj4HjurLrOvoQ8QrX/YhT+O7uwEZlSZlX4w4V+PHtLfvu\n97p+Hb6VF/7hUYFawYuwpOYovxp7cPDl1Y6S4bxUfvPxiZu+J1sza3/3cMV53fjm04klJ8Yu4Lzj\n3332fcyEx2nmb377niiwpsKPX93yajfiqvB/fnjPb9ORp3Wm954vX73CR8dxXfnZxw8vG2aF637g\ndjfig/Jwmfnq4xOrJYIIwziw00YrPx1PnFOhbgVR5TAErvcDl3nlOGfyVhEKsVcGHxCFtVaW00ZJ\n4BBC7PGdUsRIcyKX3BiveFCHqLTkqpQgv4CtU1Rb+lcDQ9/Cb5D2MlN5CXkR+B0Txb2Aa8tBwGhW\nrpeVXOTFjtByxf7+cyKN1VbAG6wOJCOdNcaZ2+KElDaUe0mQbN+Gb1KIGfoSm5iCkM1h1rILqIbW\nApJJeIpzqApSA5FErQvFGVYCCE16ENBaqcAiL3cBVYml4HvBqyenzJQX7KXaPXiPM6GrDvXCeW1/\ntgG74PGxAyssl41Zc9tiozHUq6FHpHBZN1YT5lLpnLKPjl49lmGqK1NdWZ2wOdhXZdDISmYOL9ts\nGENWrkLkMSc2LW2LTpXRKStgtdngelpcphMh1RcNXx2dBjrxzC8yDuI4xMBdGHmymVPe6Myz9x1f\n7A48zRNrTu06K5lYA+TEl1ct0vHd5cxvzs8cYs9n447vHA6YVVL9g+N6/+D80THdh3nir999zf/N\n3ns12ZUlWXqfb3HUFaEApKosXa3Ipg1tjE/kL+AbfzbNxobG6R52VXWJLGQmVKgrj9jK+bAv0P1A\nGp/YQ2bOMUMCiYhARNy4149v97W+lbWmj0458LdXnyEKfznseD8eGdTz5njkJ9srrrqeRoRv9s+Y\nrDzuR7DCddvxcrvmHCdO58ApLOzmmZVv6JuG7dARUvXdf/fhmUDGestNv6a9iOp/+/49T9PIEhKr\nrmM7dPTe8TydeTiOnOOMFYv1lq82GwrCm9Oe43mBnJHGshkaVnj2sWqDl/hxSSKsvQUrnGNmnhKl\nKA7BdwZrDSlVLWbOsXZjapHW1aOtSeRQKCXXuxQG3KWgymW5FrlIquqs9+LspRbgAhiUOnKAWmw/\niUEr2uai/boU2FKBMxU7WWVeZNBYoKkFWYRKqcm1eTYR0EJppGpNi0AASXXB9nGGWLiIVhF8KEiu\ns86MR9UgCtbURZxmkFzIzoERRCy2XORtJRONQYpiLqMHd5k9O1NQdRSxuEtTvm4czjZITlAKxXuk\nKK215Fy48Q0YeAozFsOSYeurCsM7IcRCNnXu7o2DImytpWhkVsOxRGYyrbW0KiRfU4ZzLiRbZ7SN\nr/2/LZbeGfY5gquL0rwor5oNB51JkumcZ4qZtXFgDXNK3HU9x5jondA6xxgTV11HTEpvLLerFftp\nxjeW3nk64/h6veHDPDKGpYZ/th1/fXfHm+ORSQurC+j/J1dXFBRnDJ+vNzhjue46rrr/2un+YK6s\nhZfDCoC57Xg3njjFmefzwofTiZASm6Hlq+stnbWoJu5PgefjGW8sX9xtCaXwxWbL0+nIh9PI/e7A\n9aqjazp+cr2BojwcRv709IC1Duvh8/aKL19cszuP/Pn5kedjFepfdT1fvtrQGMe785E/vX9gToGo\niZfDFbfrnjlnvj/uCSEzxUTvhX69ZbCGuWQex5HzHBGg7xx943AFzpoJp4WQCwah8YbBesRkzlEJ\noaBFseJwbX2hJ4UQYp3PFgHT1cZToOREoaCqqNQFSC2ktbTyETTzaYZwuWeLfBzPXswOfCrKlLrk\nkaZcuk
ZFiqmf21blAt1FSnb5mspHVYMTtKvVTSKwgHhFDKjL5ObC3cUgC0hUhFpsi9hLVy6IVJNE\nmatUrThH8Q6bHU7BaK5JxLlQRBBMlYlRAxRTyaQC0VwSd6kMg6RaObG5sGiuxXlOXLc9DYYoyvM8\nkUQq+NxBa4WWqiw5LIWghRSFtdSCWUQ5hUCgMNlEYwyt9XgRTDJoiYwS0cZgROii5dq07MrMZCLR\nOUwWhuwQYznbwFknCoLHMRjHQsEZCyIMztMZw2Lq921EGJzjzvXsNRBLYkwBZ4VXTU8q9WR0P55R\nlM83W1rnyFl5mmfUQIelc57WWdZtw26eK61scrTW8vkPOB8NfoRFN5TMpunonCPlzLvziT/snumM\nY9LI3738jNY4xjjzH9+8pVFDTIVV5/n1zQuMwDePO/7x7VtKTIwx87c/+aJCRAT+/P6REBNP5zPX\nQ8961XHV9jyfR3739i3vd8cq51k13K4GCobTMvP2+MjxeMY3DS+HLbebHs2Z5xB497yrPnox3A09\nd/1AyJnvT3tKLJQCXePYDh0rY3nOkefzSI4ZdYaNd3Tek1Q5xUJZlFgErwbTVc2mAktWUsioGoy3\neCOordvzUiKlSMUnVt8qlIoirDNe6gigUDvjS3EVcp2J1tiyqkLIVXsrLmF7RTyUjzNbapMrtuBX\ntaAWrdZbXQy4SkurGuDqssIIqSjWFHAgvVBK1brqDNiCuFJHEp2QTZ1BuEWRRRESwTikNaiRqnJR\nEI0VJalCaUAbg8umKigKiMnEpWDEIq5+XJGq543pUoh97Zq9CL1rKrh8mVHfcixzfdyK49pY+mKY\nRPkwTfUkQmbjWoqjFuKkPJWFRHWYteJoLvK+fTUik1C8OAb1qFVGkzgSSSit8fQ4jqZSw9CMMeCs\nJaaCpS7yvEqlp13kcopggd55vDpmIkup2vR107I2nlOJ/7L8NPX5UVJh4zxZ4JBnZKxUtF9c32K9\n4WmaeL17pmtaXvYbPluvyKUQS6b9AZemH+539n9zrX3Ld8c9c0oXaPnCb25e0DvHdddzP59w2XB/\nOrN1nk07cD00HOdQgyLvD4xhwYlytVrz+XWDwfDhdObxOHIKM531/OLzV3RNVRD8+f6xxn/HyPUw\n0LQNrzZr3u93vDsceZrONMaxHjp+tr0lW+XD8cC704iWhKHws/UNt33H/Tjy9rxnGWsgZesN19sa\nE7Q7nfg2TTDX2exV32EdWNNwjmfmCUquc9WrC9EsSd12E/UyOvWoq0d3BVKqCQdq6lxYqNOBrPmi\nmzXgLx2tkY/bs1olm5p8gAGNgAUxVT3gmoxtqhSslCrVygF8m2hWAbFCypBDLQBFQF2h30YwSimG\nEoQ8g3ro1xkGwFUWgBclL/Vz4hTppHbJRbDjRfvr6oKueIu6ixNuUSSUurT0nuwFjEMQpCjFKSYo\nKSnaGaStoByTQY0imliyEKnutybLBV+gzGFBjWHWjMkTvlQtbdJck39z5mwzWINQGMTVhVtJ3MeF\nYqpbbNu0iMvkIkRVJkkX+liFsptSaWlLqVyFYJQMDFh6dZxcIpKqplodG9OzmKmeYoxirb0wGCrZ\nzCAYK3RiMM4wzQVKwQisTMPat8QSGUvCY3jRD7xoez7MZ96dR3pvuesGXq5XPI4T3x33DG0DIjSu\nudyfP04y683rh3z96Ipuc3GgZVWsMbTW0TrDtATO88K3uz13bY/zlpthxY3tOcXA++ORecqsOkuH\n569fvmDRzOPpzB8+PDIYyzFFfnZ3zd3QM8bIb9++J5bEEjOd9fz6Jy+wWN6cj/zjt98zlsASlZ9e\n3zA0LUWUb08pQZASAAAgAElEQVRHzuPEaZnZtC1Xqyu6xnMOkT89PnNaJkyBtfdcrQesMZynwP1h\nT4wZJ0ozNKybDqvK7jyzS+eqY0VZNZ6md4RQGOdCvDjLrKkRNNbBohCK
VqnWRUcp8tFdVh1mdaBa\nancrXGazdSSASZj28oCb2t2KFaxk2j5hbMEYJQaDFos14FzE9ZmmKaRSwTYWZVZH2yysNhGMsGRL\nmGtEeLGC6YV+HcFCilLTgueqMe7XCYbqkEux6ipMqJ13UTArgw5VeeGmUm3STmrn3nrUCVIMQkJG\nd0FGKtmDbS0Ug+RMljq31gyzF6QVRAUXMuIdUmCUaiPWojSmTrSdE0yJLGRmrXNNKYYb41AjdXQU\nFmbJ1ayiYL2DrESFHTX40oilFccaQ7Gw08QsCSPQZstWWvYERs0k6tfRS0MQJWsilEijUlm6lx9b\nKnUCtPUtV21LmjKhKD5neuu4bnuew8ySImcnWBx3fc8UM/MS2IvFimUzeCwWK3VkZICgSlMK133P\nT6+ueH145v35QNbMyrd8tdn8G1eFf9vrR1d0Q87cdgMr31BU+YM88s9PDzg1vD2d+PnVLTdth4jy\nzw9PvE8HjnPFE359c8WmbTjGyLf7J/ZT5LjMvFz1rPqOX7Qtu/HMt897vn3c453BGs+vP7ulYBhj\n4tuHe07LAlK4Hnqub9d4K3w4nXi/27NowRvLl1fX3K1WFAp/eX7ieB7RyxjhZrVi23a8Px/YnWdK\nqTjC26Hjdr1mWgK7aSYsE3NUemdoth1WIYfE+ZiYtdbH1lGTdH1F9i3hwmE1gmJpL0Uz5EtIg1Q9\ngn4c2Sp8ooJLwrh69MeAZqnb+ybhXaq+/3J5f5TGVX7E1WbBmoqTTMmQgsOZQtsGtt1E0yhztmiB\nXiJFPI3L3FxNFAyhOE6jx4lQTMKsoOkK4oSchJQUG4U5KN0qI32lhuWo5CKYnEiujjx0XYuNqmID\n6FKjg7WNaLHQ1G8+qeLmKmnLAqkTYgYRg891JlKcsMREcYZsqmxPktJhSEnZmcLpMpGpayPFmopQ\nzAR2mrGmDsNvXYcIHAk8a6KUipnsjMG7ChgfNRNzuZyALOWiCklaH9vqEqzjoFYN2MQsQrZCUmiN\n4IzH2khjhcVBLlRdtlRYeUVYWlpT7b2idbJtjWUlDUlmZgpTCVgxvOoGnufIfpk4p8i2a/j11Q1J\nM4cl8MfdI5umpfeeF8MaqDe2H3Jh+iF/b/+XlzWWlAv7MrPkzGleuG57etewams0TC6JD/sTp3nG\nYPjiak3OhU3X8O5w4nkZeT7P3HYtt6uBF6s1uSgf9ge+Px6RrGxXDX3X8Gq14sPhxLf7Z47nmdYZ\nurblVy9vWZbM+/OBN4cjOSdA+fr6hpfbK3bnE3/aP7HMgSVGts3AF9s1ucDTMvLHx3vmOWFdjfpp\nnSWr8nA8cTxNSFG8c9yuLL3vWLQwzQsxAgk6D+1gMFiWkJlCXaJpASy0lzoayyXB4KMKQUyNlLlE\n50j12CIm1yoeK8VNpWCsYsn0TUSM0rhMzoYQhaGJrNpQ/04NS7aoQuuqLvZuM2JNRqlzjhg7CoZN\nN3PbjXhXmEpDSobWJnBKKMKrm4mkQsBzPHu0WJzNMCj9CqSruV+SFa+F6WRxfaHrCsl4UjLkRbGU\nT8u20ipiBdWCRIOMVFB7W8jZ/otqQ8CkGm2j9eHB9hf+Qa6HgiSwjwW8EOWTKY8GR0mZg4nsNSJa\nF2qNFXCWMSaUzGjKJS3Y1udScYSUmbSOlBCDF8O2eM6ucDaRyAIibLXFGcOjzowuk7TGuW+tr+hJ\nrTcbj6H1jhy1ks1yxuO4si1ihad5ZpdnjBpu+p7eep6XicdlxBjDq35F6xy7MHN/HMHCpqkuzTlF\njstCKvniNlSWknjVrFl5z5IT/w+Kqv/fXz+6ots7x5wTT9OIAPtl4W9u7qosCMP/fv+GpliKFm7X\nK77e3ACFPzw88J/ff0CyMqbIX798QesdWpT/4+17NGd2y8LN0HG7HVh1De92R373/oGnwx41lq/u\nrll1HYXC64c9j/OR07iw7Ru2qxtW
Q8MpzPzzu7c8jzNGlK5t+WKzwTnLuCTen47EJWEFrtcrbjc9\nphge5iOn44ymgnGem5VncC3nFNgtM3mOxFQhLZutw4hjKZE5ZsKil/ka4OsYNNYxL8lcomAEvKkz\nYRVqkoJTTHuRlF3gK+rrQmwzzDiX8UaZFgdqSAEaH+nbmXWTyAjeZnoiIfX0LvLFsEe3hqzCIXWU\nbOht5KqduGoXnMlEHI3JlLGmCW+bGd8lxMI5tsTo8ZK57TI7bbjdTICyaMM0OVKuVmbbgLku2E6J\nyWBzxpGYTh5apb0plGxIxVa+LqAX+DoNFA/kDMViRxBVUgslZUQq1awuE6ubLWnV0RYPThRfhAFY\nDOxy+iRJbqXC1aUokg37UmOAnFFEHKticM5wjoVFFsrl57PBU6ww5sDB1JuoK8LGtJxyIJChQINg\njGBL7YRLUVpr6JrqJouhzpyMCINv6I0jSapruiQ01rC2DccSiVnxrtq8h6ZlSgm01MfXWIpTLMLa\nt1x3Hd+OkXfHA33X8MV6w8th4O145PvjgWMM3HWVv/BDvn50RTeWzMY33HYDRQuDa/jueCDGzP00\nMuDYtC03w8D705EP44H3uxNjigyN47ptsb4K4N8fDnw4nkhZ8dbw11+8IqeEGMM/fvueKS2ErNxt\nNgytY9ut+O64Y3ca2Y0ja9/y+d0NL1c9qcB3z8887o+IQts5Xg5broeed+OR+92esmSUauT46mrL\nUhIf9mfmvBDnSN80XG+ucQ2MIfA0nzmdAw7w3nMzONRAToX9slQgtwitNbSuiuAjyjlXnqqgVZuq\ngArRF0Qv9tkLQcx5JS+XOWhWbJO4WZ3BCPPsGIaAawNT9ny+3nPlJwwwZcsh9YTZ0g+RL/s9W1k4\nxwbnC9d+JmdLsfDz/oGyskQsu9iRsqXVxE135jOzx6BM2tLajM2FgGXlZoYucLc2nHLHeWkRCtfd\nRE7CZog4m5mLZ4mWNLYYU3Be0euIaWpBL0FpcqqJwAL+qqZVFLWUqQLqDVUnrAo0kKxgFNysWBWS\nhcVqNZNcMt+KqXKzc4ZZ6/HdW2gVemDWytydJZG03vA6sRQRojFMOZBEEYTGOiRVbsaSMuky6ilY\nOqDRmo4cKDU0ExjUMUkimMxiKlCoyXUMwmUfWp8bnsYJmgyooVDYuIa7dmApR5aSIEJnLK/aljc5\n8xwXJgq9cXw5XLGbRx6nE4e8sPINL6+uiFIYY+BhFgbX0HnH2vmqBPmvi7Qf3iVSqfdKnbOdUmDr\nO27oGJxn5RoezyPvD0emUBgaR2MsP7+94RQCu2nkj49P9OKwYvjids26dRznwDe7PeMSa/igb/ir\nL25JOfLudOJ3H14jpWovf3F7x6prmUvmtw/3UOA0nrhdb3mx3mCs8jguPD49sj+OeDG8WPU0bUMp\nyreHA+fjBCL0DtZXa3qpgZQPTxNBKlOhbzybrho2DvPMaQp1eVag8Yb+IpzPWTkV/bRY6ww4Vy22\noVQplBbIti7PqsfBICoYVyiqVZ0gWhMooqsdczT0q8DPh3ucLdyfNrzoz7z0Mzob7roTX3Z7NAnG\nZN4s14yhwbrCV90z2zJxv1yRvPCqO2FK4Rnhl90jIsqonkNsSdHitXDbTWz8jBhhzA5nMi/kTFgs\n1hZu/MimXZhKwzG05GLY+oXcCc4XvMssuSEjHPdN5e66QrMJqDUk6g3GaK5z6wSlLdACOFiokjgg\nN0pKgrZUFCWVWWNi3TvOH7dWWuOAJEN2sA9KMFKVV0DjocuWUjITGTWJIsrKNLhUkyVGEksumNZh\nszKorws6B6oLWSNr0yMOxpyIVFVKYywtlkUCs1Fagc5aNqYl5pkxBwxVq3vTtYylEGLiMSw4U8cL\nFIilcEw1eHNlG6y3hCVxHidCyfjGVmdaqakhU4jEXFkVfeP5xXrDpumYYiRrXQ7+UK8fXdFtrUMV\n
/vj0iLOGbw87fnl1S28cIWd+9/jAU5k4jgvOOX51vWbwjv208Pv3j8xp4XlceLUeaH3LTe95vTvy\n9pR5u9vRec/tqufFeuAYAh/OB96ezkzTTO8c22HNdt1xDgt/OTzzsDvgnMOL5Ze3r7CNJ4TI9/c7\nolbt49W656bv6XzDh7GOEVKuR9i7Vc9q6Ilz5HE6fTI8rLxlvVnjVBg183A8MeeCBTbe0jpHLLl2\nW6EuuIyHxtXEgayV+jVTKEYJl0BEp2BVyEjd5BdwQ/609dZgoBi8zZguk7BM0XHTLEzZsnUzMRi6\nleG/v3nNxs98c37BSmZe+JmkR4q1/KTZkYNw1Y14k/k+3lKS8Flz5G/1A3+aXnA2LZ93O9bNQiyO\nz5sjGzuzLz3n7Dmmllzgxo18tXpGRJhK3dDfuZEYDMEo63ahawJTbjktDSFZOhdZd5ElK77NxFQX\ni9PZQ6Tmu/WJXEw1WUSBmlGJomSvVVvnBEl1oWiMVMWAKAWDuXB/MXUWHp2Q0DriUehsFYhEhWPO\nlDpipgWsdZSiLFJVF8XWfYW7wHqyhYoBqhhJ1WrztdljBJIoQZQBg7GCZEfrXL15FqWQsM7QGmVl\nbSXRYSglYQwYUWJU2uLIRjmnxJgjiPBZ15ERnmxiXwKtd7zqBrCGt+cjr497GjF8ub2ibxqew8h3\nxwPbduGmHarS4Qd8/eiKrlLnX5+vN3U5pJXmdD9HnsYT+2nEGsdX19uL+8vx5nBgv0ycUuCqa9kO\nKzaNYy7KHx93PBxPRIVXqxXWWV5uel4/H3h/OjHNC2vXsB1W/ORqwzlF3hx23J9ONQLFGb68vmbT\nd3zY73n3/Eiaa3Dhtlvx9XbDMQYe5onz445cMq0Vvr67xRrhHDIfDkfGccQYy+A9q6ZBrIeS+TCP\nlFx16601DN7hxXNOMykqsypahL6hdqqpzhVjgYU6z0Vq51ssmCgEVZLJFAdGCnkyn7riYpUpOl7e\nnFiyq9zYxTJGT+sDdih8GLcsU8PfDG9YmcBP20cOc8+E45erJ75udvx2/JxQhLUNfNXtsM7yqgmU\nkLlxI8PqDf8wf0mKhutu4n8c/sSfzy/5Rl/woj/wVTMSs8NL5otmz6NdMRfPvFiW4lg1gZfDkYxh\nyg2heLZuokSIWLyrC8AxNYyLJyaPM8rQB85Ul2H52NEuhhylpnJ0BbUXAEQEENSCyTU5AgtZBIwi\nqR6jc4HZ6idDSVsENVUxApAyuLaqKryxSIZYEkmFbBTvHTZB5yqAJ5hM0YhiGcRfNLfVxZb0Evdu\nPbHUhVZM1RbtStXzzgrFWETrArc1jlFn5hKAzLbpuG077tPErkwQhbX3vOxWPIwjH8Yzzlp677gd\nOk5L5MM00reea9/RthW+H3KmlYK3VZYYU8F2/1Wn+4O7iipODNu+++T5/o9vvkcQshZu1is+H9bE\nXPhuf+DD6Ywxwpwif/fqFWpgjoF/endP0cIUE9ernhfrNY01/PHpiX94d+Q8BVKBn7y4pbN1m/27\n+w+MKTLPgdu+Jp/2refD+cybw57DPNMZy1XbsNlUd873pwMP8wQJeif0pmfVeeYUOEyBaUk0BtrW\nc9W3rJqBwzIzhjNxDqDgfct106GiTDlwCBPp0mX1TjDeIkUJKJnyaX7pqMsULmiEc1FGo5W74MGX\n+lhqVJIxqM2IV1yTeTp3GAzGZKxXPoQV/2675zn0bIaZlCy70HPXnDBt4fV8y7vzNS83I1u38HfD\nW74r1xxLS2+U/+X2j/zn8QXv8optt3AtM2d2WOuRaNiaxH+3/p5/mL8gBIcbZv6++5b7+Yo/H+/w\nTeFVewRVbFI+c3sOpqsKCHUcoyEby00/0vjEXBrG5GlNwjolBYM1StMkioUQXOX2quCbXBeM9sKB\ndZUjYbTaedUquaEaSWK1PiOQjSIfzSQf68xlmRYu/ytVvAARvK
dykS9akhohb/EYAoljTrUDptBK\nS9RS8ZBFMCjOWzKpSuO04kfNRYOtIjXiyCit1sDLSQ2nJVAunfe179kvEzELY8rghK3vWEzGipBK\nwVhD6y3xUsxNMkwpYa1hjJF10/Gir8uzU6ngneu24xfbW1pnOceAXrIMf6jXj67oWhGsGL497AF4\nGM/cdCuu2hYjwpvjgd008/3xyJQin6/XDG1NTHiaRx6nmYfTCSOKx/KbuxdMJXKOgX98/0jIipbC\nq+ttDdlTvSzPzuQErfV89WKgN1Ue88eHR84hIqpctQ2rdmDbet6dTxznmZLqSKDvWr642nIYZ+7n\nhRATKSmr1nLbrDCmRsXvxz1zyrTWsGo7Vr6t8zsyD/OEKxXQYp2hcxbnqlwoaGDRSlQ0Qg0+FCEY\niLkwX2iP3gLUuG9KdWbRVa2m8bWjMwghWJKtGtPGFa7awOvTNXNuUMD7wvt8xd/7NzxHT98lBo08\npx4jysbN7Oj59nzDX/XPeHnm74cHNimyTy0Rw/988y1/Wq75p+mWYQh0NvGr8sgjVzQKXVF+M9zT\nucDr5RYKfNXsuM4j351eMIrly+FAS2LJjrWbgUJURyFxzg1BhVW3kEQIxTKFOp7ofaLEapXFKrar\nLrkcqB2uz2gjNQUDqshZBWzBpGoVVqESfqijCS7uvSCXN350U5f65kOp7mtspWpKriODuVT5l/UX\nB242WIVFlSgJNULJWscZzjFfEiaSZnrvq6U6mhrNTsYZhxFDMTXLTUXxVmiMqQ18DoxFcQo3Tcvz\nsjBrYhcnnBpeDStOIXCIgQ/xTOcsr7ZrxpB4DhP/vIs0Tvh8fYU1whQj3512XLcdG9/9oAsu/AiL\nrojgrfl0N22swV1oWB/GM+/PI8dl4qbrGbzjZb/ifp44xpk/Pz3ROce687wYVrTW8u505PXuwJIy\nrbF4p/zq7pZDiLw7HXh3OtIWR7GGn19f0bYN+2XkD7tHUsjkAtvW8/m2gnTej0fenhIxKU7g82GD\n8YYpJ17v9zV4sihD4+lWDU5gKZHdYSRbQ6PCpvF45+iM45SXutEuihcDDta2oUgBrdyHkEuFiAGt\nczgnLKEwkxlLocaIac3wUiEBcyk11kwUW6hzwVQz1+bJsd4sLMXU43SqKQ4heaIaFhxOlJ+uH/jL\nfMs+9kzZ09vEEz2IEtSRnOELu2OmBlx6Uxit8B/2P+XKz/xP7PhVu2fYJr6bWx7zwH+7eeKYT/z+\n/AIvmY1b+IIDszQ4MnkxvGzPdC7x+/EzQrSsfODr5pl345anNHDXnenbyJIcheqkc7ZQEJZiScXQ\nt4mmjYTkSAXKJYQTI2T3kUFQJVmSQVMGV+cFxXJBX9YoITURNVIJagpQ6Wa2KryqJKy2tgCXpAmt\nPwdT45GMQqvVJhxFOWhAqY1BYyxnSk0EzlqXvA5iqjZvNRZva6GTDDORsUSctdzYHmMNxzBfcviE\n26HDi+W0BO7niZQynXesmpbzUnceGaV1nrYxlKzMS4UlDb4qFcRU1UWRQk6QSuFpGXnR/7BhN/Aj\nLLoAsSg/u7pGRBjjin94944pRqYUKWR+cX1Law1P48R/ev+OVAoP48hPbrY0rqF3lj89P3EOgfeH\nY72Tr3tu+oF9WPjn5yeepoWQElvv6Zqebet4PI8cxz2780xvoWtaXqwG5hx5TDPvD0c0QWsM133L\nqm3JpfAwn5iXKhi3AnfrnrZxzDHxGAIl1hd7a4W71QrNmbkU7k8nIoqxhrW3iFZod5BKKwsZWjFY\nK7Q12IyMMsfEolTXlalypdY5FlGmnOryxlWOgjOCiR5VJZHRBqIIc7YV3phrPM/9ccWv7h44pYac\nLVkhZcMsjlQM+9yxz4bf9O95E1ccU8+YG3q/sCicCmwNvNMNP10/MhePAo3JWJf5D+PPEYX/Zv3M\nnZ/5d6t3/OF0w7uwYetH/n
37F76ZbjjEgZVbaGzkha6I6piCw6N8tT6QTjXJeO0C183IYVrzNK6w\nDoYmUBrDKUjV1Go9jptc5brq6xKMj4kbc62GeuE/VCeEfBolSImfqGyVdgb1HwAI5I9sYmcgV3g6\npRZV/Vcdr9XL5EKr4yxLwVlbRx65FloRSFR4jWiVsqEF6zyIEDWhBkoplzBJx7JkJhtwWqVxFdgT\nkFKj6I27gHGsVKSjOE4SiGSmErlzPVvxvMlnwlzAKHdDz6thxdvxyLf7ZzZ9x2038PXmGmsMsaR/\nw0rwX+b6URZdb4RjXIipMIZAYy2dc3zht+znGYPyzfMzT/NEUqVrHL/qbmmt5X468buHI+cQCKXw\n09vryyLK8c1uxzGOnENk41s2reem6znGhT+dnjhOC0aFVWtZdwODdTzOZ/ZLIKdqx111ls+3Gw5z\n4HE5MsZCqU5UXgxrGifslomHcSal+oIbGsuV7whkzjHxtCyYAq0zrIzDO7BSu+WUM7nGd9FYoXEO\nrcgs5hI4R62gbgtb68hGKFoYU6rglE+d1QXykoVMoWiGphoErIXx1KK2SqvckMlW2IWepOZiszW8\nPt3ys1fPNCVznzZEEZbkyMVQgFE97+cNX/sdz8XwPnmO2dLbiDWZ75Ph123ij+Gaz1d7UGFfPF+5\nheKVf1pecEotf7f9nq1f+Hn/zOsoPIYBFcMvhwcelhVv52uuhgkjhduu4TGuGaOjZM/KLxQKp9DS\nqMG7iI+OODtCdPgmIa6S6IpWR4JQQT7ykf9j+ARvJ5VqS/OKmn+VpvHpV7xA3LmwLP4FIhQ1YsSQ\nLXgxn2iZyiWG3QCU2vWKZRIh2+pWS0YZRFCnlKgX92GdUUQ+IjAvNmbncAIqSlElGaHFsLWOsy7s\nYyAUj7fC593AY5g5LoGpFAZjuel6DlE4hpmpFDZNy+3Qs4szxzkw5x1ODC9WKzCFMS/sw8TgGjZN\nyw/9+lEW3VXT8O1uhxjDcQk4a7hqW5aUOcwT9+czzlgaa/nVzR1TSSwp8L+9e0tKhZhrcunLfqAI\n/Pn5kafnM+Uye/v13R0qMMfM7w8PaM6cY+a6a1n7lsZZ3o8nPsyROVSf+U3naZxnKoHX5z1jKEiB\nlTU0naVpHHOYeQiZGCvTu7PC4A1d17CbJuZcKmEqg/OwcT3ZFnIpnGIgldr5DBfGrxdLUGXOSkgL\n2UDnpeZ/GUuRxFwi53KpG6aSpkQsSQuzau3quqq3NFqHwjlUGI4AtFBKHbs8jGtCrjSvrk2Ih7dh\nSy4Gg9Jo4c18xb/fvubazFgpGFGm4smXo7OVzG/nz1nLjFl9y+8Xy326cCOt8iF7vhLlQ1wzdIEr\nGRlxeJNxKI868Hp8wcZNWK9s2kDMR3ahR8XS+cQNI7ulATJdU8gGRm2YkyMEhzWK8ZmUKvQGq3UG\nmgXmOog1vuqaJQG5dvwI4C7VWC4LNdULukI/6scuvwtojSVCcsVpGqhhO/WPmcqH9mI+wd0+iq1C\nSSQtNE11CJZYwDVQQK1QfCHlGiLaXohylgrqCTEx2yohu2p6siiLZg4xYNSwaRuMCDElzimxpETj\nHcYIS86XSCdomouxglQ5yVRVhgoYZ7ntB/ZpYlwiT9PE0QY+G37YsBv4kRbdJWW+uowXQkq83u24\nP4+8Px05LoGhabhqWm7o+P544GmeeJzOrFyDbYVXqzUfxhOHMPGH4zM51+PczabHWcs5LLyZTkzL\nDICXhq9vNmTNzCXx3XlPyZUbe732eAzOGO7DSEkQY+0mB28ZGktImft5Jl8E9p0VVo2tnysFjvPM\nErViax0MzSXKu0wcF0XKZUdjYO0sxlvmnBk1kLPWBZ93OJGKWCzKOUaC1uj11kJjpAK+Vau4XgBb\nxw8u1I4s6SW+LDg226Vuw4G0GM7SsOln1ChRHXN2WJT7ecMxtszZ0zeJKzfyOtziqRrYtVnY
p46E\n4cpNKEojiaCeotU1du1O/K/nn2MUftk+8Idlxbs4ENRcwD2OoBZVIVjPV9tnDrHD2VxP+w7en65q\nokNfUIGbfuH+tGaOlqiGziViNBUN6bUmZnglJUu58CyMK2gjtZMt/0qVoJcYCTF1HoCgF2ANUtGJ\nai9xSJ92SFrJ7B9h75dCXC7/TZLI6hBTKFJPD0YqPS0JGFOVFFrq4lhMIZDRS9ovtDi5UPcQirEV\nHXmhmTljCRpRCrlu/Orndoa1b1CqWuGQJlSFa9+iVnleCg9TPSG+8D2dtbwbz0Q9Yozh5WpN33ru\nzyf+fHhi7RvuhhU/3VwBQiz5/90X//8Hrh9l0f2oA9TyEZqtxJL5YnPFtl1Y+4Y345GH6cS3xz1W\nhOuuY+sbMMI3h6eaOJEC26YFVT5fb3laTtxPR+7HI95YjPO87DuMGHbzmecUKKkusAYLLzZr5pw4\nxJkxVBi5Efhs06Io55y5z7lGUmfonWVlTQ2d1Mw4x4orVGHl6wt3sI5dCIhCyZdtt4e1eIqBkCtv\nIV+e216gbw2t98RS2MeFWP7FPdW0DU4MY5yZSxXUI9VAYQ1ApWuVj5peFNco46nBXKyu4uAQWq42\nEz5XRmtKlufzwPZqYrCBiYZjblmbmUPqeA49D3FNYzMrF/g2bNmYlnPp2NqZWR1vwoZf9Y98N1/R\nSCZjCKU6z9Z24n38mnnx/KZ/z/u0YUyeQ2oZXMTawlQ8WzvzHAderI+cY4eK4F1BC0ylIUdhKr7e\nfHzCxkJMdaNVtBa3nOSSnnEJ0axijqr7sgptnX9r+vg2xV8AQZq5xAVd7Lel1I8RqMX2X8YLVpSs\noXaKRjAlVxPyRakWqO9ubI1KH4tiUaIpKBVKpEUo/0p9Up8gSs6pNt1GGGxDo4ajVs4IUnP2blzP\n/XjmYR7prKOzjs/6FU/zyG6ZaL1jcA3rruEcAueUSKWw6T1d03IMM0/TSJstiqE1dUNY4Tc13Mn8\nsIULwI+06K6ahte7XXX0pHoM+8nVFQK8PST+vH9iLpHDNPOrmzuMEzyG/3T/llMKPJ1HjDV83q/Z\nti27ZdhaDkgAACAASURBVOL3+/ecYmTJieuhp/WeDsu7cGSJkWNKdOLoW8vQNMw58iEcGFO6QFEc\na181juccOZdYF1m42qH2NRRxJjPmjJR6RHUGVhYoEAQecyRSG6OtlUun6jjnSM519ufLx066q0qB\nEnhaJjK1KWsErLe0RhhF2YUFqK/9FmisQ61hiYlsM8nXrs0axSDVzpqkbt1dtQY3TeHdYYuUy5jA\nV1XEV2JxRnG5IAr344qf94+s3cJTXrPPAxZlLC3vw5r3YU1WwZnMXhu+Wa75Pl6zNjMZwzfLLf/D\n5jVTdjSSq744txQVnC3EYvnD+YaNm3AW9rHnKQ54U1MUotav9xwbmiaRiiHEyjzImPorG1IWcrKY\nJmMdpIWaVBE/dqul5r3li0niMkKh1E6SS+KDcpEeAGhGrENNdX+hl3SMS1CjmArXwdUbau2eE2KU\nIvV9RRs0a30OSDVmZJNqGsenOA8lSyFJHT05e6GW4VhKBdtMriaHbLueJVVH26xVBXPlG4oKpZQa\nRZQL3gpRC7ZAT8ORCVVlEU+Tla1riDlxSgESrFvHL27ueHs+sI8z358OtMbxxXr7b1YH/ktdP8qi\nm0vhbhhQrX++H0emULkJh2Wuri/v+avbO5ZceDsdeTceOadARvmbl3eV/q/Cb4/3LDEwxUhvGzbd\nwMpa9nHhfRyZYn3hX3eO1lmcdeziyFwyWQvOGJx1rBvHGDL7PJJV64bYWAZjcBamlAlaJWZGLM4I\n20YIRVms1uKd65a5s0Lr+LTpPleCN8UIvULrLaDMLERV5qUuvxpgaD1FhKVEnkKhSC3gTsA1FqPK\nVBLzx1e9aNX+WiFHQ6Ru8ptVvHB5FXJNeVCpcrOaSmFY
eeUvxxtCdChC4+u/+ZxWXPkRq4WByP20\npqxg7SJdXPgu3tGXQFTHn5dbHlPPLq3wJnPjR/60vGAfOqxkvM+cSoM3GUtBBFZ2ZswtWc8UDL2L\nvD7f4STRN4nzrCyxYYyOxinGlgq40XoUN74gF5lhLg7RgqYL8jKZOrdsDSwXl0mpHAxjQWwtoprq\nIV+kFmLxtdtD02VOUDtQb2ONjDcfY4aArDijNTXXFyTXzthaXxM6JIN1dREmVTdtCpDqKMKaXOPt\nczVWaKrKCZFE0UIugmj9uE4sQQo555r22xhu2hVLWjjGwHOYMAJ33YpQEvu48CEciEm5Ww3kUjik\nwOvTntY6vlpvsdawD4Fvnh9YtR3rduDL1RqVGjDwQ79+lEW3qOJMTeQtpfA0wrvzyLatXcLtumMX\nF47zxD/cf6jWyZK5bltu+xXHuPDutOM51PA9RPjN7UsWIksMfDM+oihJlXXr6F2LENmFiZgyIRes\ncdz4FjFCKJG9ziRTJT0NDtt4LBDywrnkS/GzNMbhMLhGOJVQj/sFDA7jYG2ERHUtRT6qC4TBSI3V\nNkIgM2UlpfoE6BuhdRYRw5gDc6wdsQVaJ7RiSZKZU664R6mdMgasWDSVuj2nHlfVQFkMYoWSIBeh\nxJ6bqzMiUotXFpbJYdpMypZsLTEbrpuJt8sVfzy8YMkefynEf1nu+OVwTyyeazPytKw4R8+mWdjm\nib8sLzBZ+Vl54t2y4ZhaPsQNXgpbP/E63CIF5uxpbCFLJKth5Srpau1mjrHBl4pk9C4zTSvmTA2l\nVMFQSFlqTP2Fi1AUSqzLtI9zXs1Sz/rZfIonAkXU1aRhVTClNqrkWmw+5iNJPWZz+XvVi3+iKEYT\nOTusr52ndZUeVoyphT/H2j3bjBpT6zc1kkkdZJtRzfUGkuq4AgulVPjNQk1T2fiGBsNSKhY0URUI\nG+s5LpHnqbo0WwybtmeOMzEmcELvHK335ByZYuT/ZO/deuzIkivNz2zv7X7OiWCQTGaySpdSV7ca\n3Whg/v8PGMzjADNAo0etaUl1UVXeeYs4F3ff28zmwTyoh3mXgKQcSIDJIIMRfuKY2zZb61tFK7NW\nDrVwGYNbN4p7OvUorDb46nhkro3heU9+6dcXWXTnWvn+8YnVcub0aVn564cHKHA/H/jh8RM/rxe+\nv5z5+ngiCnw1zfzx/JHfn3/mT9e0k8618XKaUFE+rBc++Jnb2GiaEJFf3T2w2I3hVz6NK1EUMeHV\nIZdAIZ2rLfuYT1FVDrVSS9BtY41BaBBDaF4y3kWMLVKuls1QpUpJETrGEhtbODbS9IEKLyu4OiOE\nJxu57fKdJFayMFx8MHryAQThJDDPlQAuPlLADqgGB01KGxFs7my79lMb6A7GGZviKnnkrQJqnG+H\n7A8DBsrqjRd3K7V2rgYRwtP1yAvdGF7YUK69cd8WPo47/vefXnIZR1oZaAl+t3zD/zZ9y3kceF0v\nPPYDPy73/PbuHUiw3BrnKFQxHuXA4o0P/YQaFHU+9uA8Gu+XO0oNjm3QXbmfOsumzMXprqyjoBr7\nUk4Ym2Ke/xXJhODYqXWxa7jEM1WiaJ4wcN3HSJ6KhX02Hs+pGxFUcYYXijhSlYFn6sRWUBlUsSzs\nxREqw6CJJ7dX876HjVQ97GoJKelu8wiKKO679ThSa7xFzxEWmfXWw+kS+yJOOdbKZQwKhSBtyKcC\n23AepomDwuOATTfCgvs68UoOrFNnsUHVZFh/c7ij3y68X84c55lX84G/eXjJt9cnvr885ey3zvzm\nxat/s7rwr3V9kUVXgFZSulQjeDEZgvLt4yMflisfrjc2nN8+vAKUJTb+eHnHT8uZj/3G2+M9LoPX\n85Fvlw88Lmc+bp1aMnPtVTsQYjz2D9xiAzKCpxTlvswsccM81QGiSnHl7rnb8ZU1ep5KKZQovJgb\nIQP3QRcnLI+EszZA
qOJsccU8MCtoFKrmoiKKc90jwi2guqIKpynvheFcxq51C6VpjiZElKdtVynA\nbn1VCmBuyWnYDQJaMmaGkYoMM4WWqgkpgYRBpDXYDUoBWtCa89PlDg3PpGEtjA5vT0/czyvbpp8L\n8VYvmBdc4EM/caTTpfJ/vP9bntYD4Ewt+GG84Bs/s3jlpJ3Vg3friW/mCyIdJPjUT/t0Nhg+c/XG\ntlRaSYnap1W5Lo3uqRqompq5z3t1DXzsqEvNKhvPN2qTbIFboD27ZPWssaqDEMnjv9ccoBeSByFG\nU57/ICNyMTepZfdboDTFLAMo3R2JdMFFAA2CkvFDbgnd0UCmnYU78hswD+rkhBqxDppnUNAeOsG2\nBdWNyYVWCzPKFsF5XbGpMFf46nDiae1cxka/Dooorw5HLr1zHYPv1keqFB7mxs2c27byz+HctcpX\npwekKN0G356fuJsbTStfHU6oJP/kX5LafpnXF1l0PYJDq7ys+w8cwj/8/DM9jK0bb+/vCc0f7P/7\n5++59pVLX5Gq/Oe7txxq4dvLB/7h6XsGK+dw3t4dqSWlNj/3D3TbMmxQCqU07kvhZp0lHjOzqhpz\nKIXUxbreCAbbEIoX1AtHESwM0Q3z7ILcNUE0U6FIxyJYJfZCDMfPqiNjMLAINg/UGy2UeVZEglsE\ny0iamZgwq3KcC4axhLNuKWEoIhmkKDk+uO2owV24wCyJLdwiGJFRPbjnAl6yGLkXhECPlr+nZLdn\nQl8rvUvOO5tTJ+fb60sag26VEGEdyqfDkW9OF9Zbpeng02VmrZVhBRHng93RNuMvT53//vhXPK0T\ni00c55EjpH7ivqxUgruy8GE9EfONqTnzGJzHkWXAVI0Rycy99ErTHI9EJL9gdE2Ae3m+z9k9pus3\n5WZOSrg8hEIWSeGZpxCE7L5ehSaGi1AKWGShzCFT3uMm+6igJnw8pFLEKTijQhRHPMc1qvkQr9Uh\nBhJCeN2BGoZpztvNndA9iFKSKOclZ/UUuNeKUFhs49PYuIXzohUOUlnG4LZ2eu/UIhStWAzYYTlT\nKWhRbn1w7+lUlJrpxo9bqn1u3dgsITi3Ifz64YGX05HNU0HxS7++yKKrImzmnNcrZs772403xxOu\nwa/vX/DpeuPPt4/84+M7PAwR4T+/fpMcA9v4v959xxoDcE7TzKvDnHBte2S1czJMtfDVoVHUEIJV\nPqLidNM0ekpjLmC2EToIdE8PiP3IbyALYCy9AiU1uASdQZGNTuC9EK40Ap0qKgOjs1k6jtyEu5rx\n6RHOEh0PIUdwhSbKdBCC4BydbdsTEhQaQkuTFVvsGWqkY20CiH0xvwc1IinyFxViEwLf2QXZUetQ\ntDgxBBvK5s7hxZoMWkktaKxKp3K+TjvxO2jN+HG959N6YBstwepW+NP5gf/06gNjO9Dc2EblaZmZ\ndKAaXEfjtjReTAt/vr3iulUufeY4jexo+4FXLKzWOGjnbI1hylSNHkKgLCb0HcGouoc09gKRYxLB\n0bFnpqGIpeNMIxIyU5xoAiMIK5mwTCCaXbXm0ywtxftQ9zhvjO1Zh1wTGq9BCfDIguuqIE6twdpB\nSdJXPEfeWwUbRHiqIpqDllyoGZlsXDqjVmIo1dIMZOTHIhL88Dy7PZUZkcJK59wXtgjelBPHUvhp\nGbxbr0QID9PESSqr5lJNq/BQZ+7miR9uV769PXFola+Pd7w6Hnl3u/Dt5ZHL2Hg9H2n6y+5y4Qst\nulUViWDtmbYaERznmQ/bhfe3G9+fzzyOhV8fX6CqSIGfbhd+7J/48XamqDBL8BfH11y5ctvOfLdc\nKPtc7a5OvJSajq64UcrKRLAy8bppbolZQFbGEPpQGi0lYC2IWIgwujQgo3SObRBhiBizBNetoaMg\nFY4BQzpIZ4sgtkZYLkrmWkA3LFJPm/mXhbuWXY4S3HYtsNuzdEhoVegenCMwDwRl4t
nxlOVh850a\nIDBlaA3AZxMHoWglS28RbC1gJeEuJbAQtm2itoEPpS/KNoQmRpmC7pmeMLbC0Mr5esB3HsE0D65x\n4O8/fsM2Co5ioWxb5b+8+REsu8QeyqfrgdfHGyJZND8uRyYdGMqH25HLluoGKZnEWyIYts9rTTM6\nvD43qpqEsJ5LrkjvbA68gfCS89BdrhZRYDjhkvKuIrToWDQgF5NaAyJQHKQQUVLjLIGK0a3Q1CE8\n72fZWb5RMtetjnyoquBrzoTxQVTf1RKaI4aSc3hKYGrYiAwS3b8FJOijc50Kh0i1zalNfFpuPK1X\nWm3cqfLVfOLT2LiMjaULE8KLMnPuC5exsjFoKhxapYfzuCz0yJifqSmihdWcS984TTOTJtehqP7i\nCWPwhRZdi6CVyl+8nDOuulZ+9+E9H7cr765XDlLQ44FX08w/fvrAu/OZH5czg85fHl4y1eSSfr+8\nZ+UTF+/czxlH8qLes3BG4sxTD3TfCM+l0BBEVkI2TmXL42RMHBtUveIBiLOaMnpFItNwWzHQwAzW\nqIQrhzKoNZdjok4NWHpFe4EaOZdlMOhsppjVzOfSPMpKGXTzPFqGoFq400zkDSIVE4C70CRRgVWg\nh9Aj0YSFxETWCLqAeMakK0Fo5qu5Z6Ea5vvoQTKiXQTEGVuln9uespDl/Gk5cDps4IX1ovStMsWF\nNjmLKVRh2yoHNc7rCQtNlkQLpHb+4fEN0YUtlBBlmPLyuDLXgW6BYjzeJl4dl70Td679gA5DC9y2\nyui5OJJKznMD1LJoScSu/YLPbrFQZDhiTnnueqvss9QEK+TyMU9BYXkLmjjbEA6T5bJNHRPJ2WyB\nWQcueULIcYFQYrdYM9BiCJJBmLtJQ9UR9jGJKOF7QY+SI4p9vFOpVAqbG6M6myQ78khBJdjMWfpG\n3waH0yGpYGZJMhu2n8iSk3GoyuIFK8EQw4fw1WHi43ojVNh8sEXwN8eXfBgLPVJffGyVv334FS/m\nicUGHvHvEPNf4vX8kq5bp+Pc1o1Zla8OJ74+3nPbNj7ajf/x8Qc+rTc+bStfHQ+IZDLqP17fcR43\nFr9wbPCy3jMXxVm5xXvmduE6Ci+mQhXnUBroE0TnaZvzzVPSfNCKIbog4kw62LyCT5zmQdXOHmfG\n1hsRymHX6tZqKI5YqgBGFObinDRS4qaOmWKjUEygwkxgkfKzdQCmQGHWtACn+iFSeO/p8b+vuXUf\nDhf39O4jTPuWXhU2E8buimqRH5fd+Rq2F26UaLmQirGLf00zQXgIMWXBNQliFD79fIeWLBBago/L\ngTtZwZW+Ceva0AGHY2cMCK2sIygo63qgj4IJ1OYUcf58eUF1Z+vpzBsOq1VO02DrlaqDZam0ydNp\nVpxbb+iaDyVIzoQNQVqOY54NDOGSi7m2W3StoWqpVpANEc2xj7WMrg9BxCnqVLKAh5SU3InS2EDa\nPnZQtPm+gBMwqGqpoii5yLMtX6umaWnUZjAqpUNsCa3R2XHJzlZNCBO8DnrZGAbTEKQ4UwgVTWKZ\ne4autsqpzWDw2Dt1uzJGwoFECx9vKz9uC+KDh3qiVOGRlffbjVaEh+mIFuHDeuNPl0/MpfH27p5D\na5y3le/On3jcZt4cj/lg+oVfX2TRLaqowI/XKyrC++uV+8PMJoNl63x3eeK93bjTSp3v+O2Lr9ii\n8+f1Z/7u/B2bJ7ru7eGBVlcknD/fPhHamcrGAeFOJxrKkAtT/cipbXxcjzzMnUmy6zuUG+HO+/UE\nCC6Vg0Z+TpJ7cBJjsQpVOFRh0nSOrV5ZtoaKcNTB4FkPCm6FZeTyZ9KgHQaO5RvJhN5TFVHq8xss\nAeb9udhGZao5SNgwxsjVTqF+Bl/heWJYDcRzYfQMxwqB4TnXRHZ6oYCMpG+pSpo8nqtyejVyGbQV\nooJaan5dPEcYVvj08wkpz24N43FtmWIhwuiwbQ
0T4f6hJ+gnFLNc5LkIt9HwUKKme+6yCdd1YvSU\ntqHK6FDn5FE0jYSSP3OBI+Vd1hU89pN6oJamg5zLFnb1bDbAooBjkl+reBZNR6iaD6SoeyesyTdo\nEogOvEl22V2ZS08MQ5HPzjn2U0WRnHsryrB8jfPBEbRi9JH3NlTBDdlhRh4p08vU51Q53MbCZWTU\nz6FMHEvjyVYutwVtlVMpnLRxKZ1zHxRN6eKEcHFYxkaRioRSVejmnPvCkYm7VpnqhBboZtzNE3Ot\n6ZwTT0bEvxfdX+YV+4r01y/uceBuavx4vvDjcuHbp0dGBEbwNw+v+fF24eN643eXd6xx5q4ceXko\nHIpyHRc2PnGxC6VVFOdlPTHrFZONj5tR1JhJhcQsKQFS3XjRzjR1flxf8OZ4o5DF41g3Rhd+Wu/A\nBK2FYzEmMVSc1VPcT4BOeVzNjlg598a2JxQctBCkWsAB74VhefyfizDXYMPZLDkMYUqTgpSKSC7r\nhgdiAqEUCm2f+627ISQc1CVhOiL7KTrDKsuuatCdIxCRBaZanrjLfjqPkXE/ospuzkLc8ZYdt6DE\nbYfJuGCAtmevvvL06S55MhpIgxtCLMIkeVxft+xG7+82igbDQTxTiqNIfu+xQ8h3jOK4gfdUHSCF\nyLKPakrb1GPv5PM188i5aSF2pUCnVCE88JELUvWgSaob8uFYsP3M5SHc1c7YKmCMKFjNmbK6IbGr\nPzRfb/Z8thIwiWNqyLR34WP/zO5YCaRVYttHPvtMKB9qSTmbvXLF8eJ0scztE8XccIebGbelcziU\nPeoqF25PvXMzo4lxbBMv5yMG3MJQ61SEB73jZ7uyxMBGMGvjm1PjQ99YxoqocD9P/IeH1xxa5db7\nv3Yp+De5vsii+/w0lR3VLwJ9DKoWfvvqNds22IrxD08/8/31iR9vZ+aa3es3hxf8sD7x7nbmk3/g\nxbxS5MDLUjnVoMcTd9M7NoP7mBGBAxMH3WjTE++3mRLCfVXUnRODSQdVnRf1glP4rr/km8OVqoaF\ncmiD61p5t9yl9VQLhzKY3XCcm8+AM6nRSuAViq5sJlzHlMd3VeaI9OmTYxXrBbOKuuZ8rkD3jS3A\nTFBPu3ElhfqLO4wEtGgkbls0MhUiUnTvIanO8HzQ+G5nZUAV2QlWzti5LhaC7oFs6e5LFGQszxHk\nOd+MYcgxZWw+CixZPCDB21JyISoqLNfGbUxZNFugEVzXlrNxF8bIYzoHT8ctWWx9yzSRiATeODUN\nKAKMoIz8VmKHiFvkg0L2Qqo2cJQwZexzXiSTkmfdGNagGENK4h9JCPvqqUCRktjHUMd6pVZHFUoJ\nojlbr9goHHqnUihTJjUnUjLvRUhQakKFcKHfUo1SdO+ee7Ixhhmiyd0YHozRkFmoTZlK5akvmAS3\n0dFauWsHFtu49DUTiz14czzSNVi2wY9cKAr3bUbEuQ3n/bhxXyvT1JKP3Fe+v55p2nh7d0etytqN\nHy5n7lrj1eH4r1cE/g2vL7LoAhxa5c8fHkHgaV2YWuOrUiki/HH7wE/nC+d1pQT811dfExJsPPG7\npz/zoXduvnBfJu6qM3Pih+2RZVuZ9IZiWJx4MSmNzlwu/Gr+wMftgLXnbb9wVzcm6Xy/PKTTq1Ya\ng4eyMGlKAE660Kk88pKvDxeqOoZSSnBeK++XewTP0UTpqAbmwtUnIlIJMBXBqiDSsxBvEz5y230s\nqfVMw0MWEutp/Q1AJVi0J4owBCmFum/qOymqD0sXG5FSMtuXbD0yMIEQSiStwAUYKZsLS4kcAWZB\nFNl1wyCbIlOOMFxz1sktO+jwdFh5CfSU6QthOb6QCHzn03rNrjYkgyR7L7nlLzkKWXqlFgcTJDy7\nenVomarr5IgkZ9A5uhEPBhURy1FBDMQErYGV3fDgBXHL0UPJwmlRcz4tSWczh7l2zApFAy852kDg\nxOAiFVXbnW
4gI5eZFKPVTO0doTueV5hqz846GoOa/ASCWnOpZbLPh+N5Lp+CCyvs9DPBNmO1wUUW\nTIWTFg4lO9DrWNkMVCuzKDc1vAibDQZO8+BxW2l3JSVnkvf43AcPUihNmWvj1CZQZdLCJBXTvgem\ndt6WXz5LF77gohsBr04zHolM/LRsvFvP/P7xI9sYfLSF3zy8zHwnN/7+6UfO4z09gmNR3p5eA0bR\nK1f7I04DMSadqMzcaeGTdVbJDrTgeFQe6mDWjXu98vXhwrfLA6+nc44fMF62laN0fn/9KgMiZ6Ni\nvCpXqhguyqyda8y4zHy9d8RrpJHgtlXO47hv1oWDDlAYYdxswiKXd4pi4Xg1uguj11ywRDqRwh0v\neRzHEj1YyK27lXSwZRxPjhdSeLQf3/fF2262+ix/CviXWS6yF95gKOBCMaE4DNJAMUZWbyULqvaC\nK3iN/CTq+O2ZNasEnqf7I58VBdqzUFJ8nx0GESlvE1dsy2WWAyaa9t2eigzz1MSGB14rA99HLb7X\nYNsttc7oLTtiU4psmbKrRngu9GRnNxQNaqQjrXuGUoY4GoPKhHtgTSnFiZJdeOqw90VYLUSBvqXl\nuhajI7l364JuOUaQAJk8Z8rDEUs34FDDasnwShfoCe4J75imrjt2lcYQYeudZWwpPQznUGdaE5YN\nPiwLCty1ibvWcik9NpCS9t9S+Xm78VSCJol9fGgzn/rKT5cLL09HHubGX90/JCMifvksXfiCi66F\nc2hpo+2l8PPlynUbvD2ceOobX7cTP29n/nz9wO+fPu46SOc3p7csvnGzle/W93x1fMKiclcLd60g\nYZzqBeVnrstLFCd9Z4U77XwYE0vM3JcL0+7keTMtNBk8lBsvpo3f22u+mc+fj5/3beFQOr+7vEEI\nXsxCxXhdrhR1tqjMsnGxmZCJV9NCVWfxwhaNtcO2npBwSlSKBNI6bsawlgshdgH+lCGBFnkM1/Ev\nJdXoWexckRSo5g9QgIkRttOsPOE62WslM1YAHIrvkd8ae6csnyHrVYQtyIJZ8s/qPpNFwGrsFKo8\n3ooUpHvyKeoun6iCdygju/TQHSCje/sdggynheNq6N5JumSXKiil9z1DbOCSEcduUD3SFszATbB9\neZbUmKCEfVZ0SDihdVdwBFPJ17pqAn6kZDeYA+lGk55JvG1XJJhQQ5lZcRSdA1sr7pILx4B58mQa\nr8IYiZykBodiXHvJGS+FETDXgUhBKYTtDN4IHGMbg3mecomnwiTK6p2lb2griAj3tfHUNz6tC4s3\nvBtvX75g9c5qHR8JYz+WmZUOoZz7xovpwN2xsZmxbIOf5UIR4esXD6kR72lOOtbK6+PpX+nd/297\nfbFF91Aq3z09QcDS82j25v7EqTa+f3zkx+XM99czj1vnr+4ecHWO9cB3y8/8tF146gutBCqV19Mr\nPm4Ll36j6plJb1xs5uVkzDKosvH19JEHVS7+KyBSNA+80gvf95cEB16UhSk6Crydzqg49yX/nd9t\nX/N2ekQQSnHmMrhp5Q+XNxDQJmdi8LpdCYLVJ46S7q8hE/eHjhB037h6w7omXMUTIylqOS8eilET\nLRmpqQ1LPbCKZNfru3yBYDz75I3PMjEhM7iQ+JylVnd2FhK5QMrmM7vgkYuwUf4FRSh9L9SSErSx\nd6MqQlkUIVMOvGbhwCXlZz0tuMKGT4KgDElod9kMKdmlFskCO0ryJNSNySwZwOQSzDRJb27KxGBy\nZ+CYtiyYO/cCgVoDH+B1tz1Loj+bph67aNCt0D0LMcCpDLoXnMGg5nhFnSkG3StNDIuC1sB3SqSM\nJKD16nhJi3KiLBzFMQ16KUjPnUFgVC1Q9wfEkmoDz1qPl7zfjEwYXsy4RM7876UyaXJwr2OwuVFL\nTTNFKxSBbQQjOpNWusPLNpM+ymw2zAYvfc5xQwlqzXTiFzUXiZtb6oS34Fd3
5V/lvf9vfX2xRVcE\nJs3NNFGxCGIM/ufjT1y2jffrmVdt5nWbmWrl95f3/Hi78tEuBJaJEAgv22Ab3+0yLnihhYiG+h2r\nB0Pgdbtyx8YHf+Dr6cIkg4Ms/MX0yCSDn/sLYvf7K8E39cwfltdYKP/p1Jm4McngYdpAhEPdcKl8\nWl/y9vAIO+e1FGexwp/Pr3lG/lWBh7riwMUnFKV4Op8O1aDlKODWWx4to6AjcdcDx57xgpFzPwmH\nkguksLSzPkO6IQulk42l7LImCRj7eCFk/6EzYSKjiZ6tqxrQnoHdCEP3hdVe12tPWE/2jJl8C3l0\nTlYagwAAIABJREFUpwt15EA4NIsJVpBqyA4Or4xEXUpjkAu7YhtNswopgobT60SQyoEZI8IxyTj2\nQXab+cYJZu94KCNqAnH2e6UITTaGV7R5zoFrWsKrGOY1Z6wlmRaBY15zaiKRWtyStLitF+ay4Xjq\nbWsQC9gGrcAojk6GW2EsleJGD2eecqA9ujAWpRpYC6Z8QVAH647tzIgOhOSS04bj1Vm9c+1GKx0P\n426qHKTw6IMfrldCg4c6M5fKlcGnsdLFOZXGLMLj6LxfL9TaeGiNhzpztc6PlxsPhwOv5pm/uHuR\nM+bnCIxf+PXFFl2L4H6eMqrancuHj7xfrvvxOPjLFy9ZovNhOfP/vPuRNTZWLryqr6lzbtb/vPxA\n2860ulLLxP0EjcZBghfzn7ncvqKHYF7o0ajROHtlRTm1lQPGxWb+en5k0o2jbPzldMajpBRol6Q2\ncb5pZ/7p9jUWwl8cPzHrxqkMJukJphZnYebcZ94crgBsUXERuhXeXe6x/fcUyTeRClerBNntFUCK\nYzVlX3SFnh0hlsd8AxBHLUE9bP+iq5QdQykRqUh49g/s97uSErPc+AdrJApRCPR51uuCK+lwi1xk\nqWWHGxaMAjEZIamDFVeKZ7HInDF/lhNQwyjdKdKz29Vk+Qo5Yy8Cs2TY4qqNzQWkUq1TJdMbJDRn\ntvsMXIFZVjQCQxlSGZ6Qm8ogXJjLyO9Zc84ekcm9Gs9AHUUqdEqiFgWOOghJJ9igYqq7HC+Rj0V3\n9yAZXe8izKWnthXYvNA3pWrOjXlGcXrao+vOUzAHU88HDjsVjmCI09SZvIKkSO7qxsGEKsLUCpfF\n+XTrXFvycE+HY6ZMmKXcTvJFrGiqY6JwP83cT42bG0+9o1IQDe4OLf+NzbhuG60U6heQBAxfcNGt\nqjz1hb6ubCOdN69OR35V73m6rXy43fjD9QPfXT7RarZOf3V8w3Wc+bmf+Wl9AjFezRsSb5klcB/c\n+IjIQgg8tJWmufC4KysPx2/5P8+/ZcSEeWFEFuifrfHkM79pj7kZ9gN/c/jIJMaxrLxtFxabOGhn\nRMa+HIrxVT3zT9c3OMqLeaOGcdRO0RVTZYRwiRO33rg/bBDJWehUzOC8HMBLaj5lXzapY6NkrI6T\n8BoDf9YGm+D2XCBjh6vnGz9sn5/avnHXz25fJt8LZzbMu4siIFKKFaR+FwSznc1LpJKBlKR5yxlt\niQKWPAjCsZIGizSONcSMYp6Gh5q2WqIkxEWCQ+TDalCwmBiePNxJO7hSZaWEslAxKQQFicEdPccA\nAA5D5bO5YNKVFqlv3axCcXxA005RKGps1tiK4qOQQxxniOJeWKPuy0FgHwvVXbRhYgzRjBIahVYH\nHUNaGlpymZmLOmrOy12BTRm7euP5nS4eRC9IQC+OiSHDsS2Zm4t3Zq8My9BRjSSiLd3ou6lHPChS\nuTvMXC4bqyz0LSV6Lw/3XMaabJNdwdH0mGkopD57KhNfHY/cxqD74LwX3V/f/7t64Rd91VLY3DGL\nz8uPQyn8cL3w/nbhx/M5fxBOL3g5HXg/rnxcP/JDf891DO5qxVV5VV/S7crZOzcPDpob51t/w2IV\ni85dfeSoC4924Jv2RJVgFuOhLMzz4A/r
awYwamF45V6Mb/s9C4X/KB9ooSx+4G+On5jEqNK5qyvL\naJxqpuY2jEPrSAR/vH2VJoKSRelUNoiNTQqigyc7MaxyLIaX9NBfR8203CUZDWq6+/TTdMCe9ZWj\nhj0ZWMhEg+7I2CupfJ40pHPNSG2vgEk8m7tyr7Uvq32XzbI70aqB7L7i8fz5dpNFQ/CxT451pJKB\n/JrEgkK6H1Q6UbIShxQkghYbkf0dg7YDdSxJWjE4xmBDWWJK2hbClHGPVBmJwUXpUVN/68oxbhgl\n0x/M89coanCohkQg4vRdXWImNAYjKkVSI2Y1QQzugrgwi1NxqDB68hM0cj5eq1FrgKUdGpQoCQCy\nbWKY0m0fLanRLFh9YJ4deUqSBxGVYvqcJkQUYbNOicDc6BhzCItvLGPsyEq4nxqH2riOzre3RySE\nu9KoWukYN1sZ5py0oiKcx8J0S4nYi9ORuzazjMG7y5XTPPPmeMevTvcpDfwSuI58wUXX3LmbJlrZ\ndYVPwZ8eP/JhvfK4rLw8zKzFECb+18efeBwLj/aBud5xdyjct4kfl4+8X8+8OHxAfeahCRKVRqPp\nJ85+4mnk/y924LG/4OYzTYy5DI4KP40Dv5k/ohoc2XhdV/Az/3N9i7snPYvCg2x8t73gSuU303sq\nN7oX/nr+mGoEdVwKP9gdx7rtEJgMrjQ1frg90F3oUhGHg3ZqG3Qa69hFCV3TQab7Zn2kQiG28sx1\nJNiNDOaAomNXZ4njlH0RRp5t9xFd6N5t7c4slYzxich5bmQNpW6BSjx7KbLwVT6nH+hwjNglY9nC\nqZSUUvlIOLruhaxUwpSpdGRf+hGRhY8KBMXhpPnxHsqVKY/kkqjExuAghhNc7JjZZkWYZNuXiYnk\nNFF6b0ix5GPIkrFEImzRUE+pn7pzaCM1tO50WmqQiTRPxBEVZ0hNKFAEImm20HSCEwrbbjSexPBw\nvCgbivsOJiK77TjkvJ59LJExQaShwvdF4rB84LdBoaGuyQ22YGupRUaEeSp8WlZunqkRbsbsBxYz\nDlOltSB6sIUzMIYXDiLcz0fuppnFjdu6UXdjkorgljyKvo9tvpTriy26IoJ7cLWN7sZ1dA5T5bfH\n16xb57IN/tftJ/50fsfjWFjC+KvjS1w6N+/849PPdNt4fbyy9Dc0LUzAKk9cHd7OF+68cdSOolQJ\n/nL6mZ+uv+GT3TMRLH6k+5E1Doh17qrTKHyMI387vwOCg2y8qSs2Xfi75S3PioHhlZe68u32wDUm\n3kxnTmXDXXk7n1EhrcEceNwZs1UKYqm/JYJLP9Jd2awiBlXHjk5MB1F2t3uL8wxQ2HaZ0wB6ZBSN\naNrTBjl/3VVdz0uw8nxq3qHproHtn08DxGK32uaI2ITP8TgNcMul2ZAsn7k+l33eGmQYWG7zZbR0\nduEUMUpxaANbKtSKuFLpSe8ikEizictM7HPeg1wpwCqVxSsWqVGddOQizDJGaeHI8NgfboPJA9PA\nvSLF6J7zYwQOujKspmONtEdrRDr2NLW4CSEPwoLNpv2BYYlzrM7oOV5I1rLvi7mKr9nJm2cnnFMX\nx5ccDVFyjqpbznDdoEfSy1qteAcfTvhg7A+44QnZb5ownZt1hg/W7mirTCinqbH6SNB579QqfM2R\nxx3QvobRrKACc0k9MArHOvPN8Y6LbVz6xnS7cWj1i2DpwhdcdJsqmw2ua7Zq6xi8mA983BbOY+Of\nnx7p4jy0I78+vsA0OPcbf1x+x6exEBhagtfTAyU6H63z01iz/ujGD8tb1phpMhBJxcNUnJf1xj0b\nFWcW42+mT/zj+oonP3GUJzY/4D6x+Ixj3DWjUvloJ/7LMQMvZ1a+ni54wN8vbyGC7gVT5UVZ+HF7\nwcUnZk2rJxG8nG6EKJsXPo4ja8+FSZOdd1sdcyWiYIPPkUDojmIMSUYCe8x4yGfsIWu6uqplF+MS\nOVvY
u6zyea3PXmzzKu4JtnGIKYttPDMbENSdEWSyzfPfqaT4H/YAxl0e4akJlpJzy1I7SMG9YKvm\nXJUt44MYSDXGVnEtuBTmsTKXFSkFN2WEM2xGpBPAqXSad1ZRNpkyvyycCcMCpponhU6hPyuzNWjD\nkh2M4k1ZyVsjAQfduO5yqi3KzoJQZtlLs5ArPxHMauposSy6VukUxtiBNTi2q03GYC/i+TnCLe+F\nZMtrksYQRJ5TmjLqKUaegorgREKC+oqHYTrjETy0A1EzbufdcttHPkGdGmHG9bO4L9jGoDNoQzlI\n5dXpyKlNbL3zfl24myZeThOvTyc8Aov4IjreL7boDk9zxN004RG0onz7+MSfbx95d10oFWbSefbP\nT594d73wQ39P1cZ9VV61I+exct4+cjd/l1rOKiSp4I4qxjDj0QunMhMof16/4uZHJumghsogYuJN\nvfGKhNrMavx2euJ/LG84u3CnH+g+IxQWn+k4U+lowKMd+dvDT9n1YdxPGxLB7+0rJGCj0MQ5lo0P\n64lbzDvDNzfid20wMEpTHrcDNnK7X1SwCEozbH9z0uXZCQuTgEUukTpZT32fRmieIrSTSeDstXf3\nEOwMbaoll7YTRN1dbYVkHCTCjLXu5goRptjnxZvhmuAG35mzIhUiM8+0ZOSMa9lFqLb7IgZVAmfg\n1nJxFUqrCxKDMg8wWFzwKCCNRudUbjjKRmPzwurpiisuTKUzhXMT5daTs9FdmUlNtErQJb3IjtLE\nEbG0MVO4MhHFQRPVud0qpzkLZanpJDQTvBcKmYbB5GwkkjN6UFwxcXpVvJMPxJH3Kp5vvu3yPpcc\n15BsCfMc14yAus9gK4LHhmuw2pZjH8nw02us3KLjSyprWlHGcB7u7kGEXjOmJ0ir9bE1RJW70riN\nwWpGrBsln+V0s+epD/L/f4v+Yq8vtug+Q288IhNQLTumbw53vG4nuhs/rI/806d3fH974uobX00H\n0EKTwne3R852pbUz2l8gobysE1vc+Hk78Lf3f+C2vAafwAsRlaN0Rjjv7J43euPmMz9sRwZKlYGL\nUzGWOPCX7YxFZ5bBQQe/nc/89+tbbuTS5OYzIsLNZzaU+3JDIrh646/nj7nwyXcvj31i9YYSCVeR\noKix9MLmjc1yZts0GatOwYviSzqt3DTZLTnGJFtTQXbPfqRpLdtRB+1JJwsBa3mcFcmOt+wSqVF3\nTm2OhvfwxsAsGC1ygCm7s02ymxtYxuVoDoHLHnHu1nPZB4xRKDVy+yZOqUGlEyGYKrFVFAN11J2q\ngzBY7AAjO/m5rTRJs4BFZQuhm6IqFHde6ZXNKxsTt1AscmxQwpmKUT0pdau3vG8I1dMuMBenUzKy\nHbBRko3sQdVgaJ4iPFLKWAmsWp5EtoZthbZ313OzXJqNivdnJkcmHY+uRIfPQ/ad+4vl4uz5NTTJ\naZFLJoSMkkV5eKQShEBbYNEJgSfvTChNK3OtbL7yod9AKrPCi9PM1rddJihpmUapWuh9cDpN3NXG\nm8OJc9849w27wcN8oP77eOGXfVVVhjsfLjdaVd5drjycDty8s3XjHz584nGsRMCvDw/UVoDBn25/\n4sf1wjlWBsZX9cRcZhjw/XJDFF63jR+Xr9nsyFEDKxsfxsx/OPzED3aHk0suCXhVFv5xe817v+Ov\n64VHv/LzaIzI46lhFDEuduI38yMjjKMY92Xjr6ZP/N3116wh/HqKxEYCm1eWaBT1PXqm8Hq6Elo4\n+MrFjmyWW/hn5GI2kskWcM/RQuxFLUkr+3xQ2Oe8OZ8LsX1msEsQUpWFTTmSEHZXmoN7HlnRlJpV\nYpcfgZU8XkbTz6oHrZFZXh5sxWlFQDMBlwio0Bn55zwIlXSzBYQmJjCPzSXZDiVB3WU2NAZmyjom\nIgpuzlx7yprCcYKrHfepSGHSwSwrqzYWn/BSUlqnRuC8KCsRyjUaV2
mfxwMtjBqGiWIurNS81w5T\n2VjkgDts3vCS329FCS9IDiNAhY4SLplYjKO7/buPkg67nc3rxTB/1s8B+33hORzTdv00QtScqY9I\nrkOIPgvPMAX3KyMqCytbnUDgpcz5+oXxODaq6F5WHbRm6CXgJbjYwhyVa+k0Lbw+HZlro3uGVN61\nxt088zBlNPW+T/3FX19s0X2OBfn6/rT/Gn4+3/jj7ZF3tzPuwYjgv756y8+3Kx/7hT9cPtIFSmn8\nup6ICNwWrPyBm8eeEBBMcmAdM4sXPnSnFEXU+d3ylsVnTpIR2LcIXtbOvW4YCbYuBG/qlb9b3vCh\nn/iP05lHW3g3JsxBNHPSFONiL/jN/C61u2K8rAsV45+u37BFRvpMnkxdQ1h7Icty8mDnOhLTKI77\nxOgluzZLkpdo4sQKiu8R3lns9mOhAd7yjWIOMmDKDodRkb3Y4fnwyDTLjFVvlkyGIIE3z7VgUkml\nREaLESU7tXmH7TwX3KGB60hrspRM3BUj6m7EEM0l1ciOFzGmloJgt7J/jYqIZYLDvDGJZeSMzZik\nWuLQOh7ZHceATSc8UsM8FeMgK1upbGPChM86akkwJAKYCEN0f7c9EynApaD71zyAPiqnnR8pO5y9\nR1rwxJN/QU362VjTFRe7/c/FMC3EttvxJHZJSlqa8+SR44XncEq1jA4NB6nPfGRHI5d96y4vU8+U\n54x22lifZ+qWi89ftwcsklS3WCZlRARNGrMKteT+ZBmDPpxDq+DOYoPXcqCofmZcfwnXF1t0IY+t\nzx2vSGFx48U0cZKXrD5YGXx7/cQfzh/42BfmAqIzv55e87Ff+dQvnP3CqzLjYrwuEyOCD6vxH+//\niW07sTEnJ8AnltG4+MSP1pi1sx6df9ru2Sg8aKYCO52jVl6WFciGsojzq3bmf9ze8OgP/Gb+wKOf\n+DhmRiQkO1kCztVmfj1/yoVOFKZqrKPw3fUlA6VHYbcVUHEW1yxAsZsR1NEWoPuxdUvACrEv0djb\nqkFWSus5L0zCTZ5aXZAwJEZK2UpqwyRqfkMZnIbV3fe//9UiYGNkFDFGlJaa25DdiGEpL6vkHHvs\nxaJGqgQ0Hy5VhREF1SBKUDRF/cMKMWRPH85xQDt0JIQeE9ee34BoUN2Z7pNZfDZlWMMUxJ1jGXl0\nLso6lDVqjmQiMY4HW5BauEXFSPawiCHm+XpIw132ZGPAjbkkN8HJZaZLJGbTI79eZQeuC27pm25l\nMAa4KH2UnQ5PqjlitwI+y0g0TxVh+8cBj5bjGsmZbwAUWMWxAFHNB13JjLVg8LQX5bs2cajCdTif\nths9jNN8SKKdRwaDNuiSiRhEpZtznCce5pmX88yld85b52rG16fTLz4b7fn6YouuSkaDf/94phbh\n43XhvlXuakUCfv/+A4/ryod15agTr14cOdXCj/09357f88GurN75am5M8sDEzPvlArISrfPT9kCP\nxoMWrh58t7ziv734PU/XX3HTTK/tPrF65dEqj3bioax8ig986jOd4FVdQAbKRpUDr+tK8+xuKsY3\n7cz/e3vLJRpft0cmNZ7sQI/UjlbJ5dIWldeHG90KF6uYVHool+XEcOFmyV9VDSaMmzQcwS32YSyg\nTuzJuGypOBD33earWTUHO6zckg/w/7H3Zk2WJMeV5qdqi/tdInKpKhQBkN0civT//zktMxxhNwmA\nBCorKzMj4i7uZqY6D2qR4DxQZh66SQqT/oJC3YqIG4urq6me8505toDo1KIYjIhxUUArrzQ/nR+O\nBMwmxMCxzXZxeglDRqT7vMa6GCOF48804Ny4BiqxOJItEnkDsEsjilJNneoNUsJsRvh0IEVHWNKG\nAJe+8jI1zCKhPtEyrbUI955Dn+xKlQ1HaTh3XcDtFfkStmMRhmT2r4XUSWmw7ZlcovOW18LboI/E\nop2ksWiM54EEvhEwMTYBU2FsIC
n6a3Rqa+fs1ofgSoB57BVPFI5B1YGS46EpG2aFZoki8RScwhIQ\n5yYbDqy60Bx26+T5uYd0xDM65jKUWHb+st8558q9d2rKvF0PLKpcW2cvg0MpHEvmWOs3MVZ4vb7Z\novt6PawVd+PtceH5HpzPn24XtnvjyTf+25vvuLSde2/8z5fPfBkbuxjv6hEVDfus/Ykv+0c2QBjU\nVLm1t9xG5eadpI1Lgd9tP7D7wqMYd1c+9CN/s37gj/uRZ6l0grHaHJ5G5lN/zw/lxof+xKdeGQwe\n0wZ0knSUzLtyYfFAVGYdvCkX/v72nptnVu2s2bh1ZVg45Uqex3pLlNxQT9Nslmhd2EeJYtlj5owY\nmgTb07wD/c+ap7k8EzyI5SNyxVzAg6kY5okBnicW0jW0oCZID0nboGOSpv52wc0QCR7D0FngVWCX\nrx5/V8cWx+tER3qJ+bNZwMsVxi6oxqgk43DoARQ3GJKRIZG5NgICnlOnJ2X0wpAUTwGPr33QEYyO\nfgg5Vli7qMlINCAxBmETlqBZJAl/2hDok6v7GuvjZuH6Vf4sGRsxnkoSHfpA6BbdOSmWc8ukmfVR\nkH0EvjGBiAWPYihuGl05TpIA4bxWNSO4IQyZj4gxO+kS0XMyiMk4pDwYXekDSgpdO9rZ24iHZnc0\nCQ/5REcY3uZJClDnKIlDTpE43DvNjdveeFOXcEy6c3w1KP3r3vb/ptc3XXRfbY2GkLTx08sVN+Mh\nLegC51S52sbvXj7x8/0WNwPC3yx/QaPxue388faR4xJWzTcpIRxovfKw/IH72LH5I1apfNhWPvUl\nRpa5czHl7+7f0Um81zt34GbKr5dn/uf2NqJnLNM8Yxif+sLn8cgP+ZkP/cLHdsJcOOoORF7VRuZt\nvrJ45e6ZkkKN8NPLI/cxrahEVtokJFAEdpcYH2h0YcNj5GJN8aaT5DhhMkQH7DadYh6LnUglD62p\nzPlgFOmYmdsMdnztkjXFZk5EYtnjKTi0NTgWpgnVUFZYB0+KDMMrUykh0GZBjvYfzz4Lq0S22DRZ\nGD2WTT3FfBNBbJDzoCQYU59sHooNF8h5J7vRVXnZl0jE8CiMiJCm6fdCwbzGGAWh0CJmiMQmxMzZ\nBRnGqjsbC07m7iEdS3PZtlngF4MQ54ykjJbQESMWlYiVD2VDzP/ldS829CvzghHhk6Jj/u1ZfE8e\n6EgVY8zll6Mz267HEk8N0z1cbH4KkH5hMpDh1jcQoUql1ETrg+d+Zx/GoVZMxqTMZarGmCTleLDd\n9p2HZeHturKWwt6NL7cbSRO/Op/+1e77f+vrmy66NWU+vDxjDs/3DRV4fzyxlsIfvnzhw/3C759f\naOb86nimqLB75U+3Cx/bC0/txrkoKgd+yD/w8XalSaNxJacF1HhL4mrKp/s7/uv5d1z9zH0cGFYw\nL/zhfuajnWkj3ESf7QOf7484wnu9TDet8aty4W/vHZXKzReaZQTnl37gi628y1eyOp/bSiOT1Diy\nxY1J4pQbSZ2XVknZGMB9W+gjs7nSRkKmi6uTyCli2t2mQiHSJ5EWabOMcHNF/RFUlEG8liwMCCTD\ns+KS8EkjU4kjruXEKJOb68AwXMa0HwueZjfZY77sseDGRnTgOlGRM70y7K3DYZYSUcKZluJB4pqw\nPUAvmhS1huSYU997QgYMTRRxat3mnDexbYpLpEKoGZ4G1QZ7UXor3Lsy5kLsNT5cTOkZ+mwvw6Bl\nmApNSpw0Uqg6hoX7Lp5fE+7ewkFSNJCPWp3mkeKsu8cMXh0v4HvMXQMNCppHMOBtdrjTXKJihGsi\nEoxFYlwjFr+vJBtIpSeJ52r8lujWyC6MVOK9ljJPCGFBdoHNd0QLWTXGDBOu/OneeLNW9tGoWnhb\nF2rKPO2d4yIsSVlKYcmRoPGtXN900RVANcwCx1qQBs994x+ev/Cy7Xzer/z29MBmB9SFf7w983lv
\nvPQbixZ+PGSqKCIXPtyeebYbXTz0uv0N5gc+9YanDefOYzvjJN7o4LlX/ri942+O/8jlWrjImTGU\nu1V+2U/8aZy4j8yadn4aC6OfAOF9ikKsarzXK+XeKRILtN1ugPKpH7iMhUNqLN65t0KbJodjbSFP\nkhSBhwy8xTy3uzCsMIYwWmyzJY3oRIdGsbVIydV5fB2EaWJMnW3ygROdq+s8otuIY7NaLInSjKgx\nMNVgAgiTsiVYC2aE9860/sOoYbwgmMPoq+EiFqHZJWaeBB8hS0Sk207Mid0ijUI8XGlF8L3Q5nya\nYiw2yTymtJFh13gIDNC6QzIalZsp456ic0w+Z+HGNqI47RQcIyukbqQUnIXuBPtXQEeoJgYZkyDC\nOaG4iImK0C2ihVqfNmkTtBjewno89qkoCbgxSoIWnbKbzrggD+vw9HK7GUPjtLJYw/Cgyc+FWkDb\nIWUjpQtYZniZ6R6CMrh/BedsKMKpnPARoJw+kXGbdZaSOeSKm9Jl0HCe7zd+XE/srbGWzLlWksaD\n8lu5vumiO9x4XFeSCN2Mv/3wM8/3HetG98FfHM6QlM/Pn/j90xcQ4bnv/NXxV4g6bRh/f/kZsrOz\nc9CFpMrqAnzgblfuYog5WQof7r/iZay0Mchp52O+8dDegCTeauNLX/i4P/Db9Wc+Xg7cWZEBl7Hy\nqR34qZ24WuaYdv6pPdBNg1madppFvHtOd+rW2bVw7Us4hVx5GZVbLyE8cI14FwNXoRbjOmKmGG1n\njFFIRHqsJXxmjSWZN+3MJdMB4hagnNfti8xi0EIbmjGoTh85jvrDv6IchAazG35FWWmOBZm4zEVY\nmVAdonODr5D0kDklKDO2RyP51oeAFNxGjBNWJyVndJAt7MEyOtQYPzBRmC6ZbQt/c9JO0Y7lhHvC\nLSMWtmFckTTIMri7hg2bFCGX6qQpgbKU6JJ43UkqU0ebE811wmc0dCGmjKl5NrVp802oxwlElHiP\nnsJ19qpMyDPFuMXPS4bF78+Dc+tjfC3EmpxKo0uOQMxZ2MNL0qgaECQQmseJKZcRkkF1+gDXhCZB\nVTFzttFp1jimiiGIZAqJkhSTwVoO7NbZeudcF94ej2QJhOeX244q/PCf44Vv4yop8fn5QrPBvXcG\nztvDyl++eeTD5cLH+4XfPX3mp9uFtSyULLyRyvP9zscer2uKLKtfl9/w0nZ6H/zUv/BIQtLOoxY2\nc/btkXX5iPudPcUNap75u5ff8mUsDHOW0vlTO6HSScBb2fjSF26j8pv1Mz+3E7sfkOFc+srn/cBH\nO3O7Z45lJ+9GM6VZoYqheaMmo6rx3MLxdGkLpJmB5oWtxSZ+9LiBouqliOlpgQfzTsxGZzQNHXwE\nPEeJzpcpffKo6jDi63oOXa7NDlVaxJ6LgJXoKt0UH/4V3ygiqIcUzIcgNCzgCXj4WGOWC5O8Hsdr\nEUcauM4lIAYl5ryiiWFhDkhi+HA0pegqTVBzvK+x9UuR4msiDI/fk2+vacVRAF1iJLC1gpVc63bX\nAAAgAElEQVTM1O2BONWNDUF6igQLglcs7nMsLoGVLBZ5bpMjjIQG91XnnHzCK8QjwaNpOPWm0kKF\nIOS1whRRA0GIkxHzd0MRgUUamwvZYjGqbnSJbh2c7J2u8RDAnWGJAjQbZDpde6g89BiByQncQ6a4\n+R0ZibIWlIrgdOu89A3VE9vYOWjlTVnJSbj0nR/WM4ZTc8g2vw0vWlzfdNFNEhlNTHNEkrhx/uny\nws+XC5/uV86lUtNbjjnztG98uLzwoV0ZZjwuB5I4NVc+Xzee7MaOUTXRxpnDKNzMab5z0RtlqSQd\nPLpzG5mn6yPfnz6Rh3LXQh+FRuW/v/wlX/ohkorrxj+1B26eyOKcZOd5VLorP6xf+OXlQGPh0itV\njJdWebaFey+subP5oI3EHnc/h7LTUbQmtmvEymwt0yDsp0PD
7z9AeswUJccsFWTGqwvIIE3IF6+z\nV0tggyyOVI/iPUHnaQS71hfBZgSQz3QJ0RYJxB5zULEIZ5RZ7F2JI7BpFGyTwAtM15sboeUdxDIp\ng5Tp6bd4HQMhukZPguYZ++6hDghIi5A0oamHS88S9xYYSElGxSJMk1hm+S5QNEDr2uP9a+JOsIlx\nR5tg1fAhdA3WLkQQqPQwamiOUM8x/yajzZ9LSElx9H5lXIRDBphSPQvlgtiU6GWbYZtTc2wx/9U5\n+kGIYEtViu800oSsRzQQnjBC2ZJxTDLNlKFKkkQSp9PYR573TOQNQoIRy0+fZo1DKjxIJG43CfD6\ndTcOWrn3nWNdOC2RFvGK9fgWrm+66HaPzjarYu5ceuen5xee242XtvP+eKRo4mm783dfPuE2+Ol+\n4ft65LBUsiv/cPnCfdu42EZCwtrIgW5XXsbg7jsN45QSn2/vwQt7G6CDe1bO/YIm52yDp37g0+0N\n7w+fuVrmyVZaW7jmlT9e3/NkFTPhUHZ+amcOtqLiHKRzGRXqjfeHK8/PC47wvFVKhm0o20jsI5FT\nIADHiLhtUyeXjlkOoXybqoRBWHaBV8uYWix/YKBTqfAKKQ+iVQu4mORYZllHJluAuZwxnev2qKxk\niZga0zgSa4/5sjNiFDFZBHydzUaAZkqRs4zxNWLJVfH62nULIF/jyN1A65wbDoVm6OTZqsSIQdQx\nVdhrpOX2qG9eomi1kXESd4tuMGlIBPesMFLI5Ia+OlqmRtkQF/pXmQHgieEj5t5E1xvBborpBA1r\ndJvRwP6zjw3fSHS9U97F6wKSMK24CygkF5LanHWHEEx1hAnFIHuLblfCcZKyUTyQlyoDtCAuJI2O\n2W1n6MApoRAhRjxjODe7kvSAIGStZF1ZFDzDmhPdjH0MDqXy7rjSDbp1nm4bSYTT4/l/893+7+f6\npotuFglrohl77+y987hWfng48bxtfLnf+dP1mX+6PAV8JCf+jzdvubfBy7bxx8szjmHJ+KvDr6Jr\nNuen+yWaPu0ccqYY1JGAK7c22DVCEItkfn/9kT5CPpPrwNw5lCsqzonGp37kaTtxKlde9sRNVnqD\nQzrxh0vlMgvxWgef+5HUO68O+kFmkc6xdLZxxFEud2VIaFcHibYJmiUIYz2OuGmmQliXuRmHQvjz\n8VnPhryefINZK5PbQMI8BPpZgGXEEVj1a36a95gNa5ngc1ekzygYGZBmckXyr6mUsfTb8FeuwIiE\nX3QWRmK8gEmYJ2YHiArk6CCHGQkNQtoIO7SWaWmeyRfiaQqpgBpF1UXxVugGKoqYI3lMG3gCDalX\nGBxAW1DQhihOibHNlLhp6X/+ORKnB5pOOls8HGKOMpd/Y9KALLreGLfESEeI4s9EajJn4EG40ylx\nnHrp+fuMwdZ0oaWM9HiYwgjVXTF2h9ELKoUkO9Vjx0FSnESWGIlI1HO6byQpVKmAkoZjvvHL7iya\nSCq8yQfeHg6IC/fReShhpy5JySr/aQP+Vq6syrU1+hjBWrDBqRb20Xna7ny8XSmivF0PvKsHmjlf\n9hu/u3+k2aCkRJLMUhPe4Wm7custsIO+8CAPeDfuffBiO8dVEe2sEtpI205outFN2FXYu1K08LdP\nv2XvEa9Sa+MnzvwmNwrOgcaXtnLVxjE3Xlphk0rfQ7t53SubhcNpzZ1LC13p61EedRLByO07oMq2\nSQC33RFXuk3Nbo+tVdKICE8WCzZ9hXPjc6MFo4dsTLzHGibFTDEcvDLz06LA6TJbLclYF1QbySeg\nvATVDOH/Fekein6iMo2AvKAEPH3mqIVKLZZSr5QyJVxZIpA9hylBBVafTIgcwZZmyEiQlJQN6/Fw\niJzyWA7qzBdS91gPduLor/PrmDOGoSkSNHgdeUR9jxmr5FiC6XyN1wH31A5LnlZdAtKeYKLKsOFR\nIE34Sot/JcXLfDpO1UZy
me68hPigpij2iQDUxJwGEkZJLR5qXtl0hS6xRJTOZjlgOySqQpHMpnd6\nTziFUjNFc8CM3Cgp0dTpOAdNHNIB3GneufYdG85DqezqvFn+jFb9dkruN150hztvp3oh8pmM3z89\n8dx2PtyuPJTKWhKPXvm7z1/YR+fj7coxVR4PlVOu/OPLM9ftzof9md2MtSQOudLGzr45dzq33tAM\nez+xcI4NsBtffOPdeQPtVJNwo10TKW8MC5hXG5maMv/np19z7xlxoSw7v/QT79MzlSA+Pe8LBWPJ\nO7e20j1z2RVR2Fs4m4zoGIdPEteQSM/VSGFwIhpGRtzg7hL2Uc/znp5JD9JIc7mDEcmy0sH3gKJL\nBDZGVwZgMX6YGlmXKHbaYnk4CN2qIHSPj5H5vjCLNN2pwVXAvIcRIi0owW/wuYV3dTS/ijAUt0E2\nfW3XcQ33U2ohJxADHx5YwcVwc3qPryQd4LXwGd7C9us2sYnCBLwD5qjG7LMT0jo0FB/B6NFJcvtz\nV/p1ZuD6tTj7q0ZOAPevDwzmvw77nYT4dtg/e5BG2m8mXjMNeRe9gQZ+8pWTLJIDa5kGQxXTDDNh\nuHhIB3NyRHXO72s8qNwwuSJeSSmHHryDF6VLw3yjjcFJThzTGmOgbFQp9B4jq2MuvDms9O687A1J\nG1mE81L/N9/t/36ub7roQsDMd6D1zqUP1lJ4PKy8WReurfG0NT5cnrn0nYzym8dHth7AkN99+czz\n3hgYvz2+Z7NGduXpfqNbYueOTAF4NcV7Z2vOLhGLQ0p8uTxSRGitR9aWdt7UHTCKKdue2PoBUixq\n7gJ7ywjGp+uPbL1E6OLSeOqVh7JRCBzirSWWEv79ZpnhSm8xK/Th2CR3kQlQjTjK+IpfFEuBmBwe\nndN0QIU+N2RLAqhH+kYRmazWWJIF+cspYvTZGcb8NTirqVh8cZkLOZfZTfd5wHdKUUYnIt1VMJ+8\nh7HEH69FMXAPS2rK0/X22i1KCs1sFQZOIhxnMgtaSnP0YYa0gMRk18BExjc30ZcpzAwTAPxqOWY4\nPU99sWrItHqYMkb2qYkVIJHGHNHY7MRfwRSvy8h/PioIIWyMp6doRAREynQHZrA2z/gxkx6+EE/0\nGA8JQk6v2MaKDWFJO906ksCkxmmEiJJXa9R0p/tK9xgFmA+qRubZ1+9DM0uJB/EwGGxkr6y54r5Q\nUUbvNHFuI+QZ3y0nHuuKA/e9cVoWzOJhl74Rju7r9U0X3aLKvXeetg0V53nbeH884jjX1vjlep+i\nfuU3pwcOubD3xt+9XLn2xu5GycLbspIsxgO/bDd8Znm9refoIEd8bslKl1sQ98XIvWI6sBbjhZDA\nKx+fH1EPofwQo6fGm+MzkpxqsO2V3eY9K86GMFpoT58upzBCYORi3DosOaRdas4+whrrQhwJSbBP\nJYH1KZGKrjHSD2K84F/zyFLMVUeQzZgWYFcwC7CB2GvET+ykLefQ72rgBFOy+FSSGD1GFK8x7qYg\nHlBEwaELqlFozRUzSGkJmdeIj7WZevkVWWgShaY7qShSFUhEGEMwa0XitaEWYwQ0lnUT4o1BE516\nW4nIIRNyDlfbsIjgEZlmjXmFdTncea/z5TGpava1ldX588h8zXiD+B4kxiQeXgbcHE2OjDn0HfEw\ncyy4ERY/t5ihr+A3Sh14U1QSQypiPWq4KiQhz+QKl1cTxz26c810eYzOW0OeF8+vEakScgxYkTut\nxahpLbCWBemZ6lFUt6m9Pjk86BqcXnX20Wk2eLMsbH3w/njgWGtohf8V7vd/L9c3XXSHO2vOnGrB\ngSzK75+fuWwbH66Xr46aH44n/uH5iU+3Gz9fLjjC98cT79cDH64vXLadT9udl32jFo0cqJEZfdCt\n8bSPmUeeeZseGC3+yC5bI6+Nm2xklOaD2hfMYkO/ieEqdFN+fnoMLaonhhhbT7w53hFxss
O2LUju\nmIdEq08Ai3jmdg1DRJqfT4YhoXWPm5/B8BxuMQ8UoY34vG6GE4VYPDpsVSNr+PeTWQj6W4xnROIY\n/zrSDAtw5KElBc3OSIJbxj1UEJhDsYgbh1kUBd+dtNgExQSbt0jGxyCJ0kVJc2yxaojtpYGWKJo5\n6ZRZKdLDOSbDoQqSCmM42jTes8dCNKd4AA2B1IykKbr8zCS2yQQBKcmjoCaLMQgKpLDXMvdfqcXx\n+zWWJ2Ycc77uncBy5Tl2cNKIgi2MP8vczMNm3XyOIBKiQvIyRzFCks42tgknqpRi2KiIFdALKXVs\ndzwtQSQb0PvAVUlSME30oXQ7YlxZU2cIiCVEHiJ8OTmKBsingDbhIGcWq7TUAr0p8JhPnLWioqwl\ngFL30TA33qxHjrnQgevepj49keQ/xwvfxBXb8vjnbUSGUxJ4PB5Za2bvg2aD3z898en6AiS+OxzY\nzChJ+fDywud7sMW+qwcea6VI4t4arXX6nEPWDFkTfbTQhHoL7F8NotUxZdo+YsnhA13u3Dyss21A\ntRKyHzHu7pHs6plPzyfMFDVhJGMYrMse3dkID3+eInbH2D3mtViOWHAPB5pJzBcEpTAmlWvQbZoj\nTMMeamEMgGC9RocShTirISnmr+F+khkJ3imBr5pmg5jPulkkY7ggWRhzueMGWZyhg3SY81MLNoIq\nWDdMDJVMdoVuVOaRXWfRFmERQVwYA7Q7kp0kGqp+AXYnezjjZAa3qSu7GTKgasxgVR2bMq0u0fEG\n+J3oRJkkNNEIyTTI8+30BGOGaiaf8q88pwMCSTPDnCzzNvRwScYoIZMIuHdyQcUxMZKEtlo9fifN\nYpSQJFNSpliQ1WyAamM3o2gh2Yme5ollZJK+UBLcOzR9A0MouoWWWhZcFpLEIq7mjHuPxWNy1BOP\n6cTuSnOneqKos9ZCsoV3+UzVwm6NNgYpK79Z3/B2WadhUVjnjZeTxO/lG7q+6aKbVRnu/OH5iaKZ\nP11eeLseOJbMtSX+9pefwcNAcV4OvFsOKPD3nz/z5b5xGXEOPNc1RgZt8OV+YwygC2/WBVAex+C+\nN4zKp3ZHs9AxVhLNIbVQDCTgpgO/J5JEMVWPUYEtGy0OtPhIU5YDCNw9ICbDE+NyDO2qx1G9WSMX\nnbPUaRZgLp1sypY8OqrkAxuRJNFJc97nZCJ6PGmwY0OukGLsYDa1tjo/Z4whBCPLICXHPI7RNmKe\nqik64i5R1G28HsMjnTgWdIqNhORwjYHH0V+d4uv8mNjev26/kwcjOU87Me7kFMVJZxEVA+8h1kpZ\niWCeTI4WmzJCAZCzxuzTweb4QW3Gp+d4AOxTsmYCeZo4XCDtwihOAUiQp3ohq9DnQXo2wzN3Lf65\nAYuE/Cw5QX+bf6tuUNH4Hj0WpAmLxZhDFZl25pizIlCkhITNM7XAGJWszjag6JmU4FW6HP9TEIQi\nyppCOtcx2hxPnOuKkBieSBzJGjuLmpxVT7xNZ3YGO40HWVjSwpoWziXzeFhQUVrr3LuxpsQPjweW\nnOljTJXKt3HJ/4c+7j/0qMXd+f3TF9wjmud52/jjyzO31vn5doHhHEvlzbrwp8sLt9Z5vt652M5R\nFx4Pleu28bLdue3G83ZHVDkvlTGMvTcE4dPLRs6x2LA0GHMud9k7gvEydro6d+uIGGawIGxudJyr\nNIRYfuQhUx0w6LXjBK8VC92xd8FSn0dUpvQquli1Hk4vIaDfbtiIBAWbizQfKY73r8uu2HDF1/GC\njQDVZBkETrwzesT8+lR2uTk5RaS7Ttg4TSPFgY6lFHPvETNJH7F9t6BtT/h5+fPDIxPLNIkjq0h0\n2WKvLrKIWM8zw80nlCejkWgR38KUNkVgUX41QxABkKHOiGJr3mmq080W828dRES6eGAldkg5zAd7\njmUbZXCXCY2Z45tqSldn11lkZS4TJd53m/NjIVQYKx
KnGaCIsHksKFO03VQz7sR7mMHBOFC1MLzj\n2bAJhj+K0sQncjIK96HCbULjjsvCfWykBLvFg+RYFra+s9YSsUTeqCVcDt+vj7g4t75zKBUx+P70\nEOxgVb6vJ25j45AK79czb5aF3xwf+dLuJOC7wyNvamGtCypwWoKlKwin/3jqhX/xGfJNd7pxCUWV\n8OxELPRSlO9ZubfBooWP1ysfrzdaHxyKImPhoS7s2+Byb+y7cy6ZRY9IUsSVp/sV63EMPqyx6GkG\n8krQGnDUzH00Dqlwb42zrvTRaSJcrTGy02xwFMVHpXri5h3VwcWE1DOKcADuYx51yxYKp7mYirFA\nFJLdExA63dwG7hpFUCKr25lJsySqNPrUzu4jR3y5O1UJA4OHGiAMrUEd0ymeTdkZPWZ5PiK9IYIR\nO0NyuN1IgUrUKJCeFScWbkqKbs0sIuPNyRyik+sgNkEuGouur/AbDZtsEkdTjnRjmJrXyFArmUn6\nigdFnWsw01isNnGKZ1SUboMlJVyj0PkcVFeHUmOMoClj0pAUxaxUKC50n/99miNrJqVLlJE8Ov9Y\nY7KgIcJwkJwpe2z8swjqzkELXQSzTk6ZGmmirKlMWdfgah0X4SAFSaEDX5eMj46whS5c4ZAOZIyN\nQbd4B2uKXLclFw5led1/IigPyyPnsnDdd5obSRKnkjmUwiqVt+UBFefS7qg679OJ3xwf50JS0aQ8\nsPJQKqdaOa0Lj8vK832jj0FNmUMt/2p3+7+H65suuiJCTcr/+PyZosqfLs8spfD9esDd+e+ffuZp\nv3PZGyrwVw+PHGrhj88v/HK9YN14aRvvlgOnUtm98/nlQpYSMPTpwHkgcetha31ud3C4y6CNgeCs\nlll0JvMmJdlOBmrPvNEDCeeWOi/tBiWQf+esSIvl2Da1mh0nt0pSWF3ZDSzdMQmpl4pQLUWf5IqZ\nYZpDbiWZ3AfJE0l3OjmWSyRqCivr4p0xRfzDYmQw3COl1xsmoS6I1y1muWWEZUwG3gvDolCKG0Pn\n2EMlXFkSBdZfjRVFYVSyxXsfIwq/EprYOhTNTt8tHGgKkpUiMscwkHOkC3tJLCnNJZjjbXxVEyyl\n0m3gEkf9SmKIsKyFbMaenIO9CtkiGw2FW9sD9zDZuJKVgnLyzAuNphFlnnOiuATYXCTUEUQMKQjH\nXNit03CSCZ4zay4kU4YP1pzZ3NibkGthNcg5I8m57h00c64ZG865LAw3tt6wNpCknNMRkcxtNJa0\n0sfGSaMlT1Z4zAt3Hfjo9GGUXDjmSvPgKjzkGnpnzfPznXi/nNj6htugLpmFE+d84HFZ+PH8gOF8\nut8Zw3m/HvjuFBSxMCEZp7V+xTp+a9c3PV4A+KeXJ563DXO4tp0/XV7o3fjlfuW+7RzLwvvjgS/b\nlVsf3O+dT7cbR8k8rgubDZ7uN2QIf/ryxFoLh7xAgtvWqJr4dLmy1EKi0x0ufQRtabuhPdETdDov\nrTN8YDhvUmUjpEE329lp3EdkXVWFLInPfeNmLfi0RM9ZPYre5k7XPQTq6hQHJAo2NLoEXQsSq3ZG\nU5I2WhcsBz83i4D3GFt4Y0gOLoMlNDRfKHtEnDPhMR6dpErYZEXj3wthMxYMSdGWOsboC8ln5KUw\nEyRAkuBtieSIyaowKZFtp84M9o1u2oQlh1yrAIUYWQyNRQ1SOOUYm4wRqMk1L1zHxqqZrmH6UJNI\n9TAjHzL3+xaaZYMqAXIZhMR2c2MM2FKMSwRjz4NNPDTFKRQNhYwkoWssN13ie10sXGtdPbhcAtrh\nkAs3N1ZJaIKtNw4oSGYIHD3+XgpKx7nvO8ccutlmRhZn85j31yTc9o3DsrJo4jYMT41bN9ZSKCke\nZA/nA5fbncYgJcha+e3hxBfbue6NNcX38OvTmTacmhLfH058vm+ccuZ8WHhXDny/nvjS7hxT4XE9\n8FgKx7Kw5MSb9Y
CqsvfOw1KppfxHD6L8z/HCv3S5O+e6fI1k//3zF0B4XOLI86YsbL3z8eXGvXXW\nVDjkxLFGZ3rfdtyFqsL3DydElQy87I3kFpvedWW3hpNp1qhAGsr79YE+BjfbkZH41alwvzcOWmh0\nhjqfbhuSYJGFH0vluTVsDF7kThanSGJVoU7zwqcxgpsArFqxLlTAvdE8XnOpHLzjltBktJ7QtNM9\nBdHLM0dtU73g7EPQlBk9U5JjZpTsDItxwRgSzisJJ5fqxhgFiBkxzOO9dtxzGAa6kMlojjfrOD4X\na/hK6oRiAIuRjSp1FkDmMnCoU1xJRag5wxgzp2yQVakqJDJVMjJGbPxT/Jxc4ZyOIM5ozillSg5e\nclkK+xgULQHKybHo0yw8t859anSXmhASOx2V+G9bCnQlAqrR4XaJGW1KsbBrHuoITYmbtUikHlCy\nkEtmHYMiStEorLUs5BQ65KqJz/uNjnHMC8c1cRCNKCMZ4Tq7N9aaOS+Vp5wR86ligJJWknXOuZBr\n4mnb2fdG0sxjXsk1cdl3bmZkV94uK0suqAvHtGLqNAatD96uK+/KSs356+4D4JQqp7zw7njk7WHl\nZdvZxuCQlHenY7j/vuHrmy+6ay78358+IgLPWxz9f/PwwJIT/9fPP/PLdmN04z46vz6/4ZQTz23j\nn768kFF+uW28P6w81oUjzufrLQqvdLQUcsos7lwNvA90FMgJZNDc2M2oKKe8kBI8HBbuY0RH4Z3f\nHs6YG1Uz995JufMyOjUlFgq/LoVrNzbZuI6dNQtbh7dFSCb0YnzegZQxMRapmCkZoXPBhrA7ZFtZ\nxMLqq85uikin+1xceeKge3ARknEfilAZJuQ5nUypM9wYHpQzJdiuRHMMlFjytEIAEy1GFKKgCbUD\nKmFsGOJoYaILCSlUCifZmMaF4wTzighjH6F1TR7FTjJLrkgPidsOLAoZhZynYiCoZQtO95htl5Km\nQS9RF2XVxEvfgxSGUVKOOXSPhVWtykXAvdOys6BUSww3dvWYpYqwklgcPCWGDdqcMZ9S6MSfmpHd\nWFXpvZNEySQeqvKQM9vUAR9KntllzqkcuNsNFWUbG8kL35WVp0kNIwlHL+SUuNiOifCwrpQ8WHyw\njzB7nOrCdW+8W1ZyqSE3I9Qeb+uJcyncfLBoormTNPNQYjzwl6c37Gb8st152na+O5748XzGzEkS\nZo9DKZyXSk1ppjh/29c3X3RFhGOpmBtWVpo7f7q98HzfuGx3Fim8O66c1oXr2Om98/Fy4ZAzb+rC\nm1Pl6b6xjcaHlytVlFozSz3w3DZygs8vG4sIS17RB7g1ow8DS/zFqdL3Rl4K99HYzWlt53EtJF85\nl8LnfefuAeY5S+F0XFglbuZt29nlxr07h1qpXsir0HrnzkbblTXFKOCcDthwmsLnbZBnCutZC8My\nxe50dnp3uiVKKqxElLj7YKQoJkYJAIwrS2psI9b7+ygIHrraFN2raI+gSg8+bSZIkY0AugSwPMFW\nkcxcMAEokqYUyqJADwggTYJFlTSm7M/itZTgXFbMnexCaztqGc3OkpU1LZFY0KOAbb3xUBKqmT5s\nkh0c6cKhSKTXiqESmK7umVOJ93yjI+Zs3ubIQxCNmXYwGITkjZwSiysDo0shi3NQpVgCjcUtnjhM\neE1CyVo45Eqay6iHmhlbY++DIolTVersFsUSp5xRAkB+KAWXM8ONe2+0bpwOa7gEPYJY++i0JGRJ\nvJPCj8cTP6Url9FRN8514bvDkee2cVwyx1xoLTjRqwrnsvC4LGwjct+WVPjLXHh7OLLmwnlZcHMu\n+4Y7PK4LJaV/i9v73+X1zRfd4cbbdUVF2Hrn59sLt9aoSXhx4WGpFM18vN745XbhIVUEONRCWjL3\nSyfPSO7v6hGbndatbVRNaCegOppJCTYz1kyk9h4yp1zZ1s7H28aaCkk6P7x7R9sH
ZHjeGjlDvxm/\nPjyQVFmS8uFypUsI8L/PJ1jgoVZufeOydT6N6LjXIhwkYSL03rhj2HBOJY7G56xszbkrvIxEGhlN\nwjEVbBRKunL3HhHbLbPkUEXg0M1pEsBtpyCuVAFLMVPsAtIi+gUp4a5jDwb3eNVWhPToFRuAR0eb\niWwvpmpCNF4rOcY3OuJ3N2yQciZ55pgjJr54MCC6dY6lcFhWeh/0HqrcZIpXWCQhOaMJTpZpu9HF\nWatyU6M3J/cUMi8SjylUJQH6sq8EtTdpwcW42eBFd7YJ1SleOXqh58gPK+qRb+eJ86Gy9Y5lJSdl\noOgIfayI8LjEyOt5u9EscS6VlJ3zsvC8tQCMZaF25W09smjj6juNCIw8pxLg/bZzWgq6Cy5O74ao\n8sNyYLMeKRMSv++RMqqxbDzkWHI16/iAU6mspbJo5lcPJx7Lws+XF/owzsvCX5xPHJaF276z907N\nmR8fHljKt6VM+P9zffNFd02Z3z19YZhx6zubdX5zesN5qfyP9Imfry+Uofxyv/HD8cSburKPcKm1\nu/FLe+FNrhzqEVng6X6jaGI3oWThVBeSZJ62mAkfyRzWSuuGqNAZiCpvDwsYPC4PFFE+p41f2p13\n5wNuzm/Oj1g3bgye9i3wgaPz18f3IZ/ywU+XG5YSpp3fHs84mUfNfGobzXae+qAsYbo4auFuho3O\nZoImOFCoS+GQYkTxQufFMglIWqmag33AlY1wk/lIVC24Z6wYbYzgAgjoyIiuISWTHk9k/54AACAA\nSURBVDKqlnBbAhwOETU+Z6CoUF95AuZ4GsGGyRFHmSYOccRImJTD3PEgCyLQmvGaaFFLJctClYK3\nTkFoDO5uHPLKUjK+9zAXNAtdboVEQlIikzhKALpjFACbREGTABmEaiPBDcPcGVlYvZCGkpOwi7FP\nO8SqicVCT9w15FqSM2eE1YWmwqIZzRnpnUNStuGsWnhIhV1elRMAgyVXSspc8qBj5Jx5dCGnzK2/\nYClTSuatwjEt3HOPuXFJnEX41Xriye582je+7BvNnf96fuDug92MRYXdPMZmdSEh/Hh64Doan28b\nRTLfnx94tx6+QmvOtYI7p7qwlvyf3e2/cH3zRTdpHBdFI6plSYXndufn64VP9wvVNJ7yj5n7GFy7\n8fP1hVUSj+vC46nw+Xql++DpdsOas0rhx/MDt9Eim2s03iwrp9w51YpIImXhw8sVVDlb4bAmrn0H\nEa4W2+XvsnDIiYygufCHly/QnMey8uvTmd7C6/5539jM0aIUd/7b4/eIC8/7xtN+J6mzmfJfHt/i\nPVEPytNtQ9l4MVhy2BzqkthssDXj7lBSRh2WVCkKNwvDxo2FBOSUXqmHIBs2CJsvB7KCTWF+m443\nmZjE17me24ziafF7eN1mu4CXMAvE3PXPGWNNX+lZ4bQ6kGjW0Tk71hQ25CoZx9iaITlIXEutlOEU\n1a/FU0ZoVhdJYec12JuTMuzmERzpGkB2hLXEsuje29zpOXcb1AzDBqopJGeTa6OqsfQbjmQN3e+I\nn22zwXDwmjh05bhk1pK59catDZaSWY8HTmVhaxvuxnBYUuZtWUI+KAtF4aXtHFLmVDP3sfKwLvHe\nhtFolGlAWGvik9/51G84wrtl5ZQLX8ZGV6ESg3RJwpGFYy3UrCxaeH9YyZti2TiVxHld+e3jI/c+\nuOwbrXce15VT/Q9ndPhfen3zRbeNwZtlZU2JNgaXtvPh5YWqmdsY/Ho981grH19u/HK/8jatuBnL\nUlmWzL43JIU29DGvpGPilAsDp0rhtt04puXrCOPSd1wMeuLHw5mBcS4Jc+G4VP50eeGxLiyiqAq/\n7FckRWTQ94cjt9Iim82gZ+cfbp851sSxZ/76zVsutzsAH+47qrH1f9QTvz6vFIef71deWsdkUCTz\n27WSveIYn1voLjfgkAKovRyU676zm4fmVAW1
xDFHPuJusDsgC6JQPNgJTjjJ3JgIxTCFDA1q1YTL\nhsNsFv1XalpAZFJAeUbYbE0ESWGPlQR1KCUl3JzNApxzqgvdjOqJZpHCoTPBNhUhSSKXwb73yd0w\nUlRLQvSlKJ2kiVuLiKGUnTZiVqJVGd3xBJojD2648kZyfI1cuPoWHXGChcRqiifhoo0e0ckUMtVl\nugKVOhddOpxUhFMunGvBReg9cHLJnEUyb5aFLxY6b/XCQ104pfn9xkGAh0PlIInPfacLnHM8TN/W\nhVQzL2NQcuKlNx70wMNS6XssvQSoOVE081ASj+uRkpXPtxs/Xy88Hg78eHpknfP94U7NifPywMOy\n/EeXgf0vub75oltT4rJv/Nwb3YznfeO3DwHneLyu/OPLF26Xxh9fnnm3rLxfjrxZCv94e4aL8+l2\nY82J02lhPRUu2xYmBReOOfO+vqOksHwOdUrPmBuHnHk8Vm49IOclF6oof/3mPUMj5qQPQ7Lzy37n\nv7x5ZC2Fp3vjpd1pGoD03zw8srfGu4cj3WGl8vdfPvF4WJBR+avjA89jo6L87nol50yywV8sDxzz\ngT6cP96fpvDfeawH3hHsh+47zy26q47wUHIc/VN0/F0t9KATYVtzBFdu7uzdySKUmfggCQbBZxVX\n3GL5VDTwlUNePb9C0RSEs/i/SHLGEKqE1Tl16Gq0YUgWlloiEsiFnGYsuIywmuYF6yMobm1nI95v\nt+Ae/D/tvdmSZNl5pff9ezyDu4dHRGbWhAKKGLoJs1b3Bc10oRfQjR65H0C66Au1JLJBcEKhUENW\nZmQM7mfYoy62Z6KJFimym0gCVeczq4uqSCtL83Bfvs+/17/WqBSmWKaU0UCinc6rwGAUS6r4CsVU\n1pTbwki5dIxRcEYRL6vEiRZapEUwtdnagrTsha5qOmk2ukihGIvNgtFCpzXnALEKJrUT7t43q1Wp\nBVEFZzROmndajDBqxykE1gL7w8C+H1qRqTfczxOTUmhR3HjDbT9QauWuLJjVMGjNJ7sjX80nssCq\nWh7vtXGsFIw27DvLGhJONZvabTdw6DxOa/bOcfCeu3mm1sroHIO1mzPhn8j3XnS1Ur8trxYuYlP4\n5nTi9SV/wRrFx1dXlEs+wtOyIEnTWc1nVwfWWnHWEHMm14xk4Tjs2HmDaM2bZYbaTj6HfbOEeaep\nBbxrAuSUovMObxV308oqBWMUz7oDhzLiFFA1RgxZCrFkrm3H3nm+XU4sIRAuNeY/ubm+2MwsS0qQ\nDN8+nXjhe3rrYag8zjNihNfrUwv4SZUfXR2QLIQS+WY+t+QvKtdubC+QUkQVmJKiOVUVvdFtpCCJ\npbwNYBG8FoyodwWGgUwu7RJLibS9XYEg9VI7wyXpSr1rGS+1IrplKLhLlkHLS2ve2ZQq/aWJQqXK\nSqRSscZgRZNVC6m3WhNLYb30r4lyUDK6KELIFNXEO9JGGFaaGE5ra0KOF4dEMW2pYSVjL80ZicKZ\niFLCWjNOdFt80NI8uFqDbtm+TqnWNaZqS/iiXTiWXBmsYdSKzlkelsCS00X4Lb11RAU5FaZc6Kph\n1B56xTkGklRiLex1R46JUoVrb5hypFYIOWG0ZqimzclLJZDZe3txT0gLGtKCRXPdd7wY93x9emLK\nCWsNL3Yjn+yveAorc470xfBsHLjy3Sa2/0y+96KbS+HgHN0wtuCTCn91/4qdeN6sM8/7kRs7MMXA\nrx4e2RnLmjPeaJ4d9hhleD2fOC8ro/H86OqWXHNrStUtmq9XiloLx9HjrWcJgYew0jvHoBTPdgPz\nGuk7Q0yVq779mcE5vLPkWnj5eMbb1o7ww8MViYpVwhwCzxh5WWZ6DzvTVor/7uGhWdBq66r67Poa\nXSraWcK6ck6BU4h8vN/himeVwLwmIoWHEHBKkxE+Ga9YQmGVyN26tC8mKsduQEoFpTinlZWW9mVE\nYU1bb81k
Qr0EblfBqraWqy7OhlQrtbToQkOrTkcXAq3gscSCqc2RoC856EUB1CZuqpBLbfnDzSBB\nKu1kGaRiVlhUolJQWrBGo6RdEFmtWHOi6tYfN9rW1eWyJqT2umkN1EwpgjMOSiJcSjeraZVEmcpA\nq7DvL5nIq2qXo4NorOjWtybNuYIIvbGMyrHEuZVF6FbpY51F1eYe8GjEcFll1tis0cbgjOU+zZxJ\nWA07DDvlOZM41ZXOGvaq58aP1HjmKay8Tgs5FX58dctcI2/WucWElsLBGKxzZGM4Dj0xtRHb3TJz\n3XWMzuK0be0nxrCXFlN67LrNd/vfyfdedJUoQs68WWbWnPn6/MRH/RU3Q8/V0PPF4xvu88zXp0es\nthy944PdwMOycAqBEiOj9ZiucBwGrnxPoPDqfCKkzGg1N1dXhJRaQlVOaKXolKVXmtH7lhMrrSV1\ncI6j6Yi5I+TacgMKHDtHFcFpgzOKtWRePpzoOos3ih/3HUuOOHGcw8pHuwPfTCeO1uKqwajC3z2e\niDEylcIHu4HrVBjFMtXMEuBUVhKFj8cdVhtCLjyGlaIL05oZjCFVzYddx1Qqc1w45diWQWphcF3L\nNlOVNScimVpbOLwRRZLc/LFtBe3SIgEg7UQspT1x5Iqu7Ua8UKhGWACj23y1y5qi8iVWsV46zNpq\nrBMIOWOkMOuKs4olF/bOUwqkUFBEFlUQ05ptlVWEWtACa02AUKVtEwYRnLR5fyptNKAN5CJtPba2\nBgWpckmFb2MUWyuhtAs/o6Rl/Frd/g4pManEYCxGGXpteUozOTZHy2AtvfOUtBBUZcmJJQSe9SOd\nMzxFhTeKmBMirQnEVBi1xRvDfV64jxO5FkbnGa3l9TrzJjdnzWA9vThmlSi6PY1opfmgH3kMK6cY\n0UpQWvPD4y0KeDW3tpSdtRy7/nuZmfAvxfdedN/eZC8pNYO7tOy98xIu2biVJJXb3Y7eGY7GE2PF\nKUMOhZu+49mwo1CanWadqQVu+4EYC9djjzEGYyN3pxPWenZWcxxH5hTbBZFSDGgktX15qzVWO+K6\nUGqmd47b/ci0RkKK7UKowu3QIwjONGtOXCpv5jNeG4bO8OP+hhAKSsPTsnA79pzDyu3QXwKzK58/\nvEFQBCn8YHdkKYmdNjylglKRN/OEsYrn3YjVmikH5ppJNRFUxV1aDK615lQzkcSc2zzVKI1ViipC\nNS0EO1EhtaxYtGqP9qq0R/vS/LVatZhJbQBaeHfJLQfOiiLqls6VVEvuyhRs1RRpGYnFNP+yojVo\nWBRLTJfS3YJSmiqV7nKaD7FtgNVa0Fq/+yIotCSxJWdqrlRXMMURibgKMUeybo/qlTab7sTg0ZR6\nGamo3OLhC3RaM5cMSmOVsErFCXglTAhKNNa02qdY2shixDAoQ7CWtSQkw1XXsTeeVzm1k3UtaKPa\ne04qEgOJwlwS13Zg8B1ryZQqBFrwkjKVo+1xynBwjjfrxG9Oj+y857PjdUvRKy3EfzSWF8OO5+PY\nEvm20+3/EN970U214I3lR1ddS+kX4RdvXtFVy6vlzF57Phh3OKP4zeMT364Ttmquuo6rQ0fvWnBH\nLYo4BaporgbH6D1zTpxjxOaMiHAzjiDSfI+6ie15XXFaM46WXW95XNZLCHVkMIaQobctW9YbwxJb\nE+++67ndjTzNc5uNpsroDZq3p5BmiZrjwhIjzmgOfc+69m2uWIX7sHA77ggx8kKNpJpYquWLpyeU\nwBITPzwemJc2g34TAw7F47RgrOZGD4iDqQRmSZeUqkuugIDXiolEkfZYbkRhaWKalSC6kHO7QCtR\nYS+XTlkqRhViLahqLlm47VSrdL3k4FZAXcouhfDuBNriGhMFWzRZRYDmOFBcejCbwK3pElKuW4Rk\nEUWvK2spSBKCKq3MUgTnFSkJmILLwvq2Hx7BiCGQ0LWNVJYqaC34i9/XiL
DUxLQErLM4q/DescwL\nsSaWrLDGsneWoiprSThtmOcFrRXOGkyKWG1BKilmgsl4ZXBUDp0nlMTL5dzGHcbwQb/ni+WRhQRp\nASM8Mx3n2hpSlFbEFLl1PU4bdtaz79p8dvSOj3Z7Xk1nSil4Yzj47nufmfAvxfdedBVCKYXXcWVN\niW/nM0fnubEjz8eBU4gYozivEa8MhjaD3dsekcyr04yTtil1tRtR0qxfNVdG5YgxYYxicBarDdMa\nmdYACrxpzcNSKxqDtdCnRKnQGY81iiUkptTSykotjNaQq+BN+/fOWJ7WgDPC6EbMAK+fzljnmeaV\n66HjNLdH1lAqXWf49hQQ2iXNbdfxtCYKiTUbYly56QZKKTzvDHMK4BNfzmcMbaPqo/HAOSUGb3i1\nzlgN01pwzuJN23xapa0h59yaiZ1qweXOQCitDyFS0aqFsrfdkIwxqYXKJ4VCU1NGq0pRBWOEREIV\n35rIgWor1HypXXpb0dNOzUFyqw5XFVcU0YCJiigJye1CzhoB1frXpBSWRFtlNqBLs3VpFKG0RY2U\naTGVSuGVZq0JIwZXDblGAuB0wSh9CVMHtGo+cG1ay0V+e9moMFXTK8NUE0tO6KIZRGGVxlpLIjOn\njFOao3WcUiDmSFWOqAqjsuhLGaZV7TJwLZkghZ13LLF5mJecWKTVt7/oej7YHfjN+ZH7tLIXuBoG\nfnq84RQjS4rMMXHlO14Mu23J4V+Y773ovjXlv1lmOm1a9JzzDM4Rcub1PPFyDVx3Pf/2+IxKOyEI\nhTW2SyqlFFd9T2c1ISZe3890vl163I4jVaCzrXnVO8W8FHpj8d5Sa2FaI3NcURk665pPVQSlFL23\npNraDgbrUFoxrZElRKxW7VbaNEHrjGaOic571pjYjR4nioPveVpmeiPMMfJ86JiWNj8+pRWthftT\nwhnFTjmeXw18O59oseOGECvXxlFFcWVgru2E+PU6tVvvqrjp+9Yx54THtCDSog874yi0RoxAJutM\nja1lQl/mv72BJAGpFyFGU1TFSqKYilUXm1URKJpaM6IrtRasabNglR2lAAjVVpTklmzWSn2IqTkk\nVhJaNNUUbFWsulX0FAqF1kunjaEiXCrkiLVQYqIauSSNg6USVaGkSiiBVJuLwqtm7eJS156kkFOL\nQ9xby2OK5JyIsc27ndVorUkx0neWNRdCbvGeg9HEUvFa8RAX7mPbztu7nr3xhFp5ipH5nKFUXuyv\nOOXA62ViInBKgWvfoy8LI6NvX7xTSTyuCwfbY41gjWFnHN4YvDbM0XDb93TGbqfb3wPfe9Gtlw/R\nT483pFI5es+X5xNvlolc4Fk3YKrmduzx2nKKgTfLhFWGG99zfX1gCgFF827Wqhg7gzOG3hu0EqYl\n8vA00TmDMZrdIGixON3yuboMMSf8JZUsxkRMuW1aaWGwvl2aqLY51ztDKQVn2v+v84anKbCmgqrC\nrnNAK38U0y6pRFrp495ZtPR0NjKFRO89tQSe7XvWJeE6w11YQBselwmrFSOWW9/zzXJCjGIOid5q\nUtUY3QJ1znUlEznFgrpcHO2cJ9RM7zSPOWKlsOaCde0CSqnQ3AiSkHRp2a0KdGKwgqgARbHWln1Q\nasFryDrjdSbVSi76Ul75VojbOCCrgq6WeHlNiqloSe+KIaF1haksLDliRFFMK1mcSXTiyCVTq1BL\nxDtD0e3OLqfWFZdiourSam2IaK1IuVBra2GAAsVc+tYy51pa4aXSaNVKJamVNbXMYitC1ZqpNBvb\nvCb21qJFETNc+RaLmVMm1Fab1BtL3xke1pnX64QywtF1eDRPVKYc8NrSO8NH45E34czDurLUhNPw\n8+sP8Nry5fmRU4zsreVHx2u8/t5Lw++N7/0rKyIY1SxOndF0xnAOEYxwcJ7RemLO3E0TTyGiEG6G\ngVIqvXesl7nttK7snOOm6/HOE2JGFM
RUMEoh1mCsxjkLtTLNgRAS1mm0VljrW0mhojXW0nyqRim0\nURAVIWbIrZyy85YUC84KFcXQWdY543uNCBhteJym1lsGXA8dd+cJEQuKdustEVWEg+9QUnmQwDmt\n9EajqYjrCbkgVni1nDFOc79M7JxnWROHsePLcMYZRQmX1daaWgOxViRZkJJ5ShVT2+l1MJosEava\nAoNXlZgqyrTKXafbeq1SiZTUZdFEUVVi7Es75VZhLYp6qYR3KlO0wptMLJCSQUSIa/vSQgq6KooB\nFQ25tAUNpSpIRlmhXpYx1pJQRVhya3lWBkzRLJKxRRNKasVOktoihgiuqhbkXgtBmh830zbJOq+p\nRQgpgUkkaTnCTjXvcK2C1c1zPMU2Ox+tx2ZDqZFzDK3yxll22rLoyl1ccTExx8CLcd9CcEIADWtu\nT0lH33EtHrRisI5TWPn14z2ds3xyOHCwnlgzIWcOvuODYcdHuz29tpsz4ffM9150AV70I1+en1rq\nk8DH+ytKqc0SJILVinPULdrOGqzSLCnxtK5UClYMx75vd9hGOMX2+BZj4WboOHRdC2S5CGYsBWtV\nS7zS0gz0RZjWGUFjrOCtJdWENe1XJNZSast11ZcZoZTKNK9tVkjB+2Zt6qzBaGHXdYSUGFxb3DiO\nisdlQWqr+blyPa/nM1YZkgLjNSor9saSlWHn4KvTCXQrEQwxsXc9a44or3iZJ6yGU164dh1TjgzW\n8rqc6KwilorCYEtE29RCtu2CREWoFcmaYgveVqgr3hQiGS+wpPYFpHXCq0yqgtKtrXiJFlUrVSKd\nL4hK7UScdcvlVW1MoDrBmlYRE7KBqkmt4RJ9iVIsBkyyhFJQSKuKry2DwfDbxQ4SVMmtMl5VnPLN\nCVGFmBNJQNVWIGlM+zLPUphTxijBOY0zjikltCqsKZFLwWiFRXNSmc5aQspQM6b3mKywCJ0znNfC\nY4qUDDvjOHYdQRKPcWFdM6NxXBvPXVwJkplLItTKh92Id21MsfOWUlulz58cby+Jeq3h49P9FZ3Z\nEsHeB5vo0oLMP91fNVEUuVxOrTyuK1qEQuWj/YE5xXdvzAKoGBltu/1VIkwhcDe3IPRBW+Ilyi9T\nOIfIaVmxKK6HgbHvWg5sbpctuZbWaiu6eSRV2w6a53bpZo2ms5ZYMlabS1pimzvWWvC6NVnEVFjX\nTFaXRlmtWFNqiWfGIVRiLHSu3XgfVMc8J6QIpmpu+o6vTqc20iBx1TseQ+TGDzzJQq8KX84RLTAq\nR5JEzoZFItVk7vOKRhNrYKieRSUGazjLzGAyIdMsU0RsF5Ca6F0gJsUahYxBTKI3BSUBp1rKl1MQ\nctvi61xg0IlYoUjb6FsvjcSaiHegTaBiCUlTqkZUxlDIvlX4pFpJyQCqRRwWhUhplfeuItG1L4bS\n7tmqaltyrchSUUqmFJCayJe6Hrl8nHSVdxe0xRSkajLtycVqQaRlYqhaSSW2nF7jMAjnnKiVlu8r\nwpXr8WJ4XRY6b4mxhep4bxiLY730rU91pasWa4TBeG7HkTKdeDNPDNVxM3T85HjLU1gJOTOnyOg8\nP7s+0G9i+17ZRPeCVRqrfntLu/ceb1ppYxNDoS4wp3hJIVQ8343EXN75ZHUWaqoch2bbqhTOc+D1\nNKOkMnhHyS0zQWfhcV1ZQsApy/XQMzrfRFjRGhVSxejmi3x74RdjYgoL+pLD6rQh1xZSkkpm6C3T\nGrDS6mFyKaxpJSUh1NBGGKaVS/ba4QqYOpNSvXhOFVdjx7S2yztVFaMzfHOe8M6w1syt63mMK8/9\nwMt4YuccL/MT9hJCI/WSLNYFbElMuYl61gUjFkymN4UkCweTmUtrInY24dSKqMreraxJOK+ejMZI\nYTAJbRasVAKCrcKaLbEovFvpdCIX1ZooqExzm5lqaV86xgVKcsSskKJaoA6JbNvTTC6FXAw1aSqJ\nmh
RKVyqtH80UIUnbUFM061lSpflypT15lNrsZBbBWk13uVSrtdnPQop44+isuThbLEUqIWSy8XT6\n0nWmWpTlKbVGkEFprn3PvVqYYuR+XVgFnncDRcFcEtoJa25r06bA7lJrblUrutw5z8F5vl0mRus4\n+m473f4rsInuP4L7HavMsesYcnuTWt2KDl9PU8s3oDU3GK0us9kW0l0v9dZ755qft8LjujKFgFWW\nnR9Y4spaMzkWnkIg5sRgHFddjzG2mfdVu6TRps3b6iWhS2khh8h5avNLYxTeWCq1lRWmQu8Kc4hY\nZZqgU5nm5t18ihm0Yo6Bzlp6U+nsjlflTEyZaiq1KvrOUgCdKlFVlBW+XB/pnGGuiWdq5LFEnmnL\nt+XEaBVTndqFoQKkoiUz9DMxVeaskGpIKmGrRkxitAEI7ExiLYpcOzoXsKrl8452JWXF49oRa8um\nHU1Em/ZEslZDroo1t942ZyNeZwoQi6FQCYnLiTi1sHQXydmQcmtfyAWUSZQMuteU0pwWVCFKhaQQ\nXZvjoLaktlIrqYBTFSmVzGXx4m0ym27hOqoWRDSUwppym/FbS82Fpbb5clEwqHYRW0MAIzhrWdbA\nmtsTjZT2JFRC5qwCCNx0PR+NB+7CxMPaVn+dsfz8+hmdMfzm9MjjujAYx8+ubjn47j19ijZ+l010\n/xmICN6Y//o/8GwcW2gITaRDzryeJlJpCVzHzjf/5WWTp1zal50x9Na2D6P2PKwzqghOO3qj2ygj\nR6YMUwjknNldrGzeKmJKaK3JubScWklNhEW1DbRp4TwHvNNoaT7dlhmsSDpRsiHEiDUao5qlaAor\nUoSnNSBGWFJi7DxVwZgtXz89Yo3nZZywRuNKEyoj7QRYTeTbtGC1QXRml3aca+TgHBNPXPlAlhUt\nPV63S0mjM0f3xJKFU3DU6uh1QlfY2ZmdDhQp9KYwZ80Sm/NiT6FKZbCBVIQ3046CUDX0JuBMpLSg\nRlLR5NTsZtYmnCotASw7UEKMAqUJseiKmEzNhkyF1FZ3qRWlW19bqa0JQitNKgny24YLg0izh+Uq\n5FworgLNq2ykdeOJUpdTcWVJiZ11HJTCK0MpmSSVx3WmINx0A9oozimRVWZOid5Yrn1HuoyarFes\nIXE3T2it+HC3Y3RtfovA3nV8MBY+HPbsnN+cCf/KbK/+/yDqd4TYG8OL3Y5cCiJyiY4MPKzrJbOr\ncjsOPK0rWl96rpRcbro97jJOEA1vlpaN2+nmODingHPm0t8WEKlcdR2dt1inyamiVOtHM7oFsihp\nlejGKO7PMxXahpVxoGAwhpZ9UMkpc1ojRrX6713fMYeWVPVmmRElnPLCh+OetSQOruPzpzs6a3lZ\nzjgjrLQWWymC8is6RU4lopVFsUDeEcgcHCBnbv0ZJLHkkc6ktnmnEtfdmSkZ7peOLBavZpxkDnZm\nUIEiYHVmyo4leXoXcLUQRehMa6+4m8ZLw7Ew6ERnzyRjWMWQsiKRqFlhdUabTC2FVBxFWjtvraBV\nplLBFXQ25NK63WqCWMplbg6xNq+vYMi1Ahll2u8jp7bIEWttiyBaY53mPAWOrmVvUFsojlJCERhc\nT1wWzim20HLfPNdzbaE659IuWT8cd1QLv8mPKF1Za+ajfsePr275enriKQR2LvHJ7shxO93+QbCJ\n7u8Bo9TfM5WPrjWhlst82KiWPnaOoQkswu04MseIUm/HB+3Dt/e+iSSaKUTupjM5Q2+aPespRowy\nzDnxNC+oAsdhYBw6Sqmk3Gaz0xraKbdkjHVQC14ZvjnNFApGFJ23KNvaDZZSMGHlaV4JtTUJdNaS\nS2WKkVgKr9MJZTSBhU/7G84lcOMHvpjv2DmYOdNpR5W2misCoz+hUiFmQDkMZ9bcoQWO3YKWhVs3\nAZUpW7xJCAmvItduYomab5YdqWqO3UonCe0ynkRRbdZ6yh3naHEm
MRIJorAXjjDoPwAAFCVJREFU\ne9g5OFIVMuBUZnAzMTtCVYSkoGQkV5RptfFVMtQWllNKoSSF1gWUQmzFFNVq2kurQSo1U7TGFWki\nW6E3uiWt6UqJmVwUKVd661rVErUFlpdCqpmD7bjqHae8tpLLS1j4YHuybtVFV77jfp35Zj7TF83H\nuz0f7w6cLu+pWAoH3/HZYeDou80G9gfEJrrvid9dpTz2fctUoIm0XD50c4xAfRefl2p9J8RKXerV\nxx6o2GKaEC8TOVe81aQE5xQRBU8xcF5WdFUcfKsWKhVybsHk05xQqlJzszQBmKL45nQi1kKumZtx\nYK6Zg3PchwUXNX+1vG7uCSWM1pKo5JpIKfJYZ4wohMjR3LDmyt46HtMb9iZizCOFgVJbuhVSuXVP\nnLJwjp5aez5wJ9ZsURQ+GSY0K0e7kkRzjjucSngSTgJXfiFE4evlilgNOxfwrHgbwMnlsV/zkDtC\nVBhdcNI6zapolC7ECKWoFqAuGdcHSrLMVVFzy1eWKpf4yHKJqmyXorlCLQUlLe0C3VaQEXnXwJtV\nRlSLeUyqrR93RjOFQgS6IhhpX87kRCiJKUcO1rNzjrlGYo6ghFLgxnr2zjHlgDMWpSpeG571Izfd\nwG9Oj1Qqz/sdV85vATV/YGyi+6/I35sPA8+GgXCZD9vLyfjVdGZJ7RG/Mwaj5N3SBJomKgWuhh6l\nhJQzpzXweopAm+WeYyRe7FDnGHicF4xSrXSw7wgpND9qSq3rimbaH3zXatRz5TfnEzFHomQ+Hg9E\nMp/Za75dT4Ss+Ovla4zVpFjY+Q5jYotXTAunslKweB2p9QZVNc+6lcITz+yJzpyJ64GFjBaDAM/d\nIxbPQ+gI9cBz9zWlCEYlPhvP6JoZbWDF8ioNTSwv/+z8Sqnw9XRgKQ5rCx2BwcJeR6qqpGJ4yD0p\nK5CCM+mybmxAKuFSM9Q25wpGF1QxLJfE+1oKtQiQUaJbtY+GnIWqW+CQsaoFy+SCrgIGrDEM2rQG\nYAVOWi5lRej7jrzOrR7osrLcXxYp+t5w3fewwimtpLUyWs9Pb2/JpfAmzJxTxGrNz2+fc9uP7/Ot\nvPHPYBPdPyB+96JOifB8GEktVODdRd2r6Uwqbd300HUsb1eEpS1aiAjaakbbHBPOWp7CwilEqJXe\nGeaUqFooUphr4WlaMEoYvOezoWsnbhHezDOhVNb0djOqAyqxZr6enzjHQFArPxxekIl8trvh5fKE\nkYH7+jnOWdbYzPtGHEYySa2sJRGLox8SIR0oteejfgF54AP7QK8XpnRDQeikhcvf2jO6Jl6FPXOx\n/Fv3DaUoLIXPdndoCkYyQRxvgkeAUa9N0GxCE/l2vmJKDkThSexsplctYjMmw1MWStJUKloKIpf8\nBqmspWX9tiJTsKbiqpBKJWWB0kYF6NpWd0tplUQ0D7XWgmTVYh8vVsPON8tfjZmYWz/baA175zhN\nic46qrQnk0K7gHW2rfW+vTO46sdW2+48R98x2q0Y8g+ZTXT/wNGXeuu3eGP4YLcnlYzQRHoKoV10\nXTJgj33HdAmiFml+URUV1ktzTCD0IXKKgRAzVSp9Z1hixveOlDMlwf08A5Vj3/FsN3AOEa2Fvz3d\nM6+Rx7DSG8uua8E7ISleTSeWnOm7MzfygliFDz28jqfWKqx/RZKRJWg6cUi1DDoTypklC6EM7Lpv\nmHNHwPOj/gEtiQ/dI14HzsmxFo29ZNge9RltMy/XPU/F8yfjHblkHIUf7u7aQgTCIh2noMkoOolk\n0WhdMZKYg2UuhoxpvWc20mlHEiFnzTlmapbLebS2C7LcYiRTFiS3tQiRjHWmNSFnLiEPBW1ca4yw\nlse84rVDa0VOBW00oyhW1ZYpYil4A6EWno0dz3c9sVS+OZ+ZYqAq+NnNM267nt+cHnhYVoxSfLw7\n8KwftlHCHwGb6P4R8rsXdYNz
2Es7q5LWLquXhaewAApq5flu5BQC7mIXqg5Uihx6T+csApzWlSUm\nntaFcjkRx5K53Y0sKbHkzN0ycdCe/d7yyfHAFCNKVf5y+pqcE/dpZuc8RztgteMUE/dhJVTw+kQp\nH2Jw/LDPPKUFpxTX9gtSueZN2uFFU3D0qnKlzkzZstSOn/ffMhfHUhw/618jqvDCPuF1YsmGqTj8\nZSxy0AuOyDfrnrs8cNvNmJLoRPikv0dLJVTDU21z3gRYCkoiVVmszqwThKooRWGoaNvWh1fXRDbW\nVpYJpY0XFOiiW0debX5tUa3BWNXm254vYelD19FJa6vwql3EaW04Dh1384J3bZ5/XpvrxRnDh/sd\nV71nTZmcM94YbvuRvfccfU9nto/yHwvbb+o7gtUay28v6666js6Yd44Jq1rWwzkEhFZ//uHhwFMI\n70YaQ63EuvJsN9Jb04LO54mYC6+niULLYADhOHTMsfmHXy0zt3qP+MInuyNLqnT6yOfr31IQ7uOJ\nvekZzIBmR1rh21CoJPb6gTk9J3DFj/qVc0pYUXxgv2Itz/g67PAi5GrxKG70wkPqmbLmp90dSzHM\nteOnwyuUFA5mwenEy3XHOTVXRK1CpxIdgW/DnqfsMUYwNaNU5kW3gMrE7HgolZw1SRQCGDJZKURV\nclaUCKpWVAVrK90lAUwKbdab2p99G7o+WtesfFXonCbEArWw67t22q1Cyone2XeXp7f9yM57Tilg\nTfMLP+87fnK84SEEXk4nTmHl2PXb6faPkE10v8P87kXddde/K2A0l7FFeOeYaMb/D3c7phjfOStK\n7XhYFj49HhmsIZXKt9OZmguvpjMVGL1lMCO73rKmxDfnJ14vKzv1IbpbedE9Y02ane455b9AxPMY\nM1d2h9M7RI4QO75ajhj1mmfmgad05FSe8ZN+YirtC+UD95op3/LXyy1OCrk6DIprtXKfes7V8LF/\nZCmWOXd8NrbxglYZbQqPq+ec/GVJBQwVZxfuV8cUHUUMplaciigvZAGS5jFbStGt0QMQKa3hgoJU\nixSaLxmwTuOsRqEZsEgGXdto5zgMvD7PjNYhVVCiOPaWmNvs/dB1PIWVx7CQKHw07nix37NcGi5i\nbYsX/+b6lmfDuNnA/kjZRPd7xH+zUQc8H4bWilvbaVmAUNrmE7XijObjw4FweaSFljlwN0/87PpF\na6TIia+nJ0yFX5/fILq16952N3QOYkl88fTA19OB3v6Ug5v5U+MIZUeSipH/G6Ouuc+egxyBPaKO\n5LrnV+GGQd3zoV2Z856HcsOf9vdM2aLRfGDvWIvwi+U5Tiqp+OZjVTNv8sCb4vm0f2AtirV0fNQ/\nYiWRRJPEELNwTi2BLEsLqulMIJeuLVAUg6otSN24RKyCLpaQoRZNjlATiK302rMQm6hmRc60te3U\nts0smufjvgW/U1EIB+NZpDB4z6HvmVPEG423rWnkxdAqnr56empJcV3Hsetbjf3GHyWb6H7PkUuq\n2n/N82Ek5Ezlt/kT355P707EozOM7opSf/tzLYq7+cz/dP0JO98aKV7OT1jgL6ev8EaTxfOB/znO\ntljDX53f8IvTj3jhKr2Z+fGw41RveWLloP4cL4a7dCTrGwo7kuzI5cgXyTOq13xkA3PZcZ9v+Hl/\nz7lYFIqP3R2q7Pkv84cYgbWs5Ko46Jm7suMpe0bbyi3PpePGB4xEYnWcsZSiWJIiF0Mpuq0664Ch\nI6dKQaErKFWxXSVmhUFjq6YGoWZa+4XNfLAbeZwyg9Jopcgp4zrLqNsTx80wcoqROQbu54nrfuDT\nw4GqhPu5OU68MXx6dcXzYdxGCd8BNtHd+G/QStH/zqPr83HHmhLQ/MK5Vr45nVhqodYWuHLT9e9C\n4WuteGX5dnriPxw/Y+88T+nMq+WEE+Ev58/plaFWx2D+F6yNHKrmb56+5n9/+nf8aafQJvBpd+
As\nH3NXFm7VX9NLx8s4cNDPKbwmsiOWa74Mnp1+zcduYikdb8otP+0fmIrFAh/6b7kPPb+cXiC1rTnX\nCju9kIthKYYqBlMUqVg6A0ov1GJ5ShWpljUbJAsKTRYDOrETz1wKUg3UirFt0cQmxU4ZPJ4bX6hC\n2/SLiY+GPUZ0K9e85GHsnOfYtQr7zrVEMKkwWMuh69g5twnud4RNdDf+SRilMO63/k9Nm/8uKSG0\nxuJUCt+cT6TSanM+6g+88DuUbm6LUm/5G3nJq/WJf3f1M46+5z7c8xhmnMAvz5+z0z0rmVn+V6zN\nPHPC6fFr/uPTv+fPBgUm84nveawf8brM3Khf0UvPl3HPVfmQzB0re1K95ss4MKhXPLevCMXxmG/5\ntJuZkqGKZeceSVnz+fkGqqWokQp4M5FLZS4geDotlKIpolAu40U4xYquHpLBVoUi04tmqpln/a5d\nuJWWUkYVBq/onaPTlutxRFfh9TS3dWBj+PR45Nh1fHs+87QGtNLcDAPPx3EbJXzH2ER3478bq/Xf\nW2/WSvHhbs+cWoLYYNvF2rfTRFZt9faz3TN+MNxidbvIy/UD/urhN3y7PPDz/U+58SN38Q1TChhV\n+JvT5+z0QAS+4n/D2Mq1gcfTl/zH03/gf+4rooVPveW+/oCHMnGl/g5Pzxfrkb3+iH8znFnrFbEe\nuYs9Xn3LtZ0I1TLlIy86WHNiqgXUhKXyKmtK7hG5ggreTq1+PmaM8uydbbPc2jbKnLKYYjjgGPEY\n0+rn953lzXnlhbvi+bhjiZk1RgbbgsWfjyNrail0nbVcDwNe63fuk+10+91D6iVq8B/gH/3hxsY/\nhTUl5hgRJeysY46R1/OE0bpV1ighl+a2UCKkkvgvT3/Hq/kOJZXbbs/r8IpcK0bBXz39CimKlYnn\n3S3eGFIu/Hr+Bs3En/X/iSnDL5dbTvVH/Nnur+nlb/n18oL/4+E5e1342f7EWgbOccd9cqj6EmM8\n51yAPWtyxCysYeREYmeEee1ISTGoq2b1Ugu7uuMhBJ65kQFPipXBWVTR7DvPB8OOnOGDw8itH/nq\n8ZGboccozYe7HR9dHbhfVx7nmRe7HdYYng3Ddrr94+cf/AVuorvxr8IcI3OKbZ55EeK7ZcaqtuSh\nRQg50xnbGjNK5C8ef8k30x1GVW77A6/Da0pVeKX5y6fPoVSCrNy6I6PdEXLi8+kbNIGf+v9ErYm/\nnJ+x8hk/Hh6p9VfcrUf+/HHHqCvPek+sFqk9S1KcyyssI/ep4GVAV8+aoSTHOWV2xjDUgSXCrR4Y\ntOOBmU+GI/McuOlHfnZ8zpRXcqnsnaezhh/srphSorOWDw87Ys5oEa76vsViboL7XWAT3Y0/fKYQ\nmFNCi7DznnMIvFkWvFakUi+n4N8KcSiB/+fhL3g131NJXPdH3sQ7VLH0tuMXj19QciaQObo9N/6K\nJUV+NX+DULnSf4Ej8atpD3zMs95xKl+xJs/nD4bBCEd/JBTopYOqeBnu2ame+zXTiecD1wRUZ4Gg\nGJzlJ9e3pAy9NfzgsOerxxPXvqfXlqux40fX18RSeHk6cRx6BmM59lv84neMTXQ3/jg5rStTShjV\nquLPIXC/LK2doRQQIeVIbz1KhLWs/Of7P+fNciLUyG135CGe0Fh2rucX91+SSiLWzMGOfDjccoqR\nz88vUaJZy6/pdOHrydDJDbf9kbv1DTkr7qe2tvu8u2Yumb32HLTni/MTB+uoQTMow79//jFTXFlL\nYe8dB9fz0W6kiJBT4dm+Wb+0CNeXPr2N7xyb6G58N6i18hQCUwwYUVx1Ldznbp7pL1a2VAu5ZHpt\nMVqz5sD/ef9LnsKZc145uh1PacGK4aa74hf3X7GkRK7QW8+f7G54DJEvpjdo0bxe7/BGcZrBq55P\n+hvuwpmQMjWAV5bP9tfEWNBa8cOra+5OZ3rnuHYeby2f3d
6gRHh1PrN3jsG7FlC/Ce53lU10N767\n1Fp5DCvnsKJQXPc9S0rczTODtZRaWFKiUOiMxSnDmgP/+eFvOa0rT3Fm70amnNAoPh6O/PLhFee4\nAgqnDD+5umVOid+cHrFK8XKecEqw2WJE+On+GTFmzjly1XUcnOfT4zUa4XFZOfYd1hqMUtwO/Ta3\n/e6zie7G94taKw/LwikEtBKu+541Ze6Xmd600sY5RXLNOG3ojSPkxP91/wVTXHmzruysp1QuVrcb\nfv34wF2YcGiMVvzk6hm1wDfzGaeEEiog/PBwRS6VH95c0VvH/TTTOcvgLOOlFXrjO88muhsbb4X4\nKQQEuOl7Yi3cTTOjtVTgFAKVjJIW2B5L5hcPL5nCyuPaEtksmjUXfnK8ZgmJl6czXmu8svzweMXO\nO96cZ6zR7LzDaM31sOUlfM/YRHdj4/+LWiv3y8zjGgC47nsU8GqaGC5Ja08hUGumIFx3PTFnfn26\n5xwjJVWMKK5dxxwTHx8PXHXtws+q5tV1Rm+C+/1jE92NjX+Mt58DEXknxA/rCsCV9zhteHk+M1qL\niPC0rigBhbB3ngI8zgtrStwMA1qr7XT7/WYT3Y2Nfy7l8tl4K5xv5umdEA+mBdE8TDP+sq67xoRS\nwqHzGL2dbr/nbKK7sfEvQS6FCu/qkk5r4BwCVDBabUsOG2/ZRHdj4/fFWyHWIpsVbOMtm+hubGxs\nvEf+QdHdnoM2NjY23iOb6G5sbGy8RzbR3djY2HiPbKK7sbGx8R7ZRHdjY2PjPbKJ7sbGxsZ7ZBPd\njY2NjffIJrobGxsb75FNdDc2NjbeI5vobmxsbLxHNtHd2NjYeI9soruxsbHxHtlEd2NjY+M9sonu\nxsbGxntkE92NjY2N98gmuhsbGxvvkU10NzY2Nt4jm+hubGxsvEc20d3Y2Nh4j2yiu7GxsfEe2UR3\nY2Nj4z2yie7GxsbGe8T8//z8H6wR3tjY2Nj457OddDc2NjbeI5vobmxsbLxHNtHd2NjYeI9sorux\nsbHxHtlEd2NjY+M9sonuxsbGxnvk/wXIU1EV667/1QAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Simple fixed 3D kernel - HWDIO layout\n",
|
|
"kernel = np.array([\n",
|
|
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
|
|
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
|
|
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
|
|
" dtype=jnp.float32)[:, :, :, np.newaxis, np.newaxis]\n",
|
|
"\n",
|
|
"# 3D data - NHWDC layout\n",
|
|
"data = np.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)\n",
|
|
"x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n",
|
|
"data += (np.sin(2*x*jnp.pi)*np.cos(2*y*jnp.pi)*np.cos(2*z*jnp.pi))[None,:,:,:,None]\n",
|
|
"\n",
|
|
"print(\"in shapes:\", data.shape, kernel.shape)\n",
|
|
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n",
|
|
" ('NHWDC', 'HWDIO', 'NHWDC'))\n",
|
|
"print(dn)\n",
|
|
"\n",
|
|
"out = lax.conv_general_dilated(data, # lhs = image tensor\n",
|
|
" kernel, # rhs = conv kernel tensor\n",
|
|
" (1,1,1), # window strides\n",
|
|
" 'SAME', # padding mode\n",
|
|
" (1,1,1), # lhs/image dilation\n",
|
|
" (1,1,1), # rhs/kernel dilation\n",
|
|
" dn) # dimension_numbers\n",
|
|
"print(\"out shape: \", out.shape)\n",
|
|
"\n",
|
|
"# Make some simple 3d density plots:\n",
|
|
"from mpl_toolkits.mplot3d import Axes3D\n",
|
|
"def make_alpha(cmap):\n",
|
|
" my_cmap = cmap(jnp.arange(cmap.N))\n",
|
|
" my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n",
|
|
" return mpl.colors.ListedColormap(my_cmap)\n",
|
|
"my_cmap = make_alpha(plt.cm.viridis)\n",
|
|
"fig = plt.figure()\n",
|
|
"ax = fig.add_subplot(111, projection='3d')\n",
|
|
"ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)\n",
|
|
"ax.axis('off')\n",
|
|
"ax.set_title('input')\n",
|
|
"fig = plt.figure()\n",
|
|
"ax = fig.add_subplot(111, projection='3d')\n",
|
|
"ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)\n",
|
|
"ax.axis('off')\n",
|
|
"ax.set_title('3D conv output');"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "DKTMw6tRZyK2"
|
|
},
|
|
"source": [
|
|
"## 🔪 NaNs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "ncS0NI4jZrwy"
|
|
},
|
|
"source": [
|
|
"### Debugging NaNs\n",
|
|
"\n",
|
|
"If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:\n",
|
|
"\n",
|
|
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
|
|
"\n",
|
|
"* adding `from jax.config import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
|
"\n",
|
|
"* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then setting the option using a command-line flag like `--jax_debug_nans=True`.\n",
|
|
"\n",
|
|
"This will cause computations to error out immediately on production of a NaN. Switching this option on adds a NaN check to every floating-point value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked, and if a NaN is present, the function is re-run in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
|
|
"\n",
|
|
"Tricky situations can arise, such as NaNs that only occur under a `@jit` but are not produced in de-optimized mode. In that case, a warning message is printed, but your code continues to execute.\n",
|
|
"\n",
|
|
"If the NaNs are being produced in the backward pass of a gradient evaluation, then when an exception is raised, several frames up in the stack trace you will be in the `backward_pass` function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an IPython REPL with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"```\n",
|
|
"In [1]: import jax.numpy as jnp\n",
|
|
"\n",
|
|
"In [2]: jnp.divide(0., 0.)\n",
|
|
"---------------------------------------------------------------------------\n",
|
|
"FloatingPointError Traceback (most recent call last)\n",
|
|
"<ipython-input-2-f2e2c413b437> in <module>()\n",
|
|
"----> 1 jnp.divide(0., 0.)\n",
|
|
"\n",
|
|
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
|
|
" 343 return floor_divide(x1, x2)\n",
|
|
" 344 else:\n",
|
|
"--> 345 return true_divide(x1, x2)\n",
|
|
" 346\n",
|
|
" 347\n",
|
|
"\n",
|
|
".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n",
|
|
" 332 x1, x2 = _promote_shapes(x1, x2)\n",
|
|
" 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n",
|
|
"--> 334 lax.convert_element_type(x2, result_dtype))\n",
|
|
" 335\n",
|
|
" 336\n",
|
|
"\n",
|
|
".../jax/jax/lax.pyc in div(x, y)\n",
|
|
" 244 def div(x, y):\n",
|
|
" 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n",
|
|
"--> 246 return div_p.bind(x, y)\n",
|
|
" 247\n",
|
|
" 248 def rem(x, y):\n",
|
|
"\n",
|
|
"... stack trace ...\n",
|
|
"\n",
|
|
".../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)\n",
|
|
" 103 py_val = device_buffer.to_py()\n",
|
|
" 104 if np.any(np.isnan(py_val)):\n",
|
|
"--> 105 raise FloatingPointError(\"invalid value\")\n",
|
|
" 106 else:\n",
|
|
" 107 return DeviceArray(device_buffer, *result_shape)\n",
|
|
"\n",
|
|
"FloatingPointError: invalid value\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The generated NaN was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"```\n",
|
|
"In [4]: from jax import jit\n",
|
|
"\n",
|
|
"In [5]: @jit\n",
|
|
" ...: def f(x, y):\n",
|
|
" ...: a = x * y\n",
|
|
" ...: b = (x + y) / (x - y)\n",
|
|
" ...: c = a + 2\n",
|
|
" ...: return a + b * c\n",
|
|
" ...:\n",
|
|
"\n",
|
|
"In [6]: x = jnp.array([2., 0.])\n",
|
|
"\n",
|
|
"In [7]: y = jnp.array([3., 0.])\n",
|
|
"\n",
|
|
"In [8]: f(x, y)\n",
|
|
"Invalid value encountered in the output of a jit function. Calling the de-optimized version.\n",
|
|
"---------------------------------------------------------------------------\n",
|
|
"FloatingPointError Traceback (most recent call last)\n",
|
|
"<ipython-input-8-811b7ddb3300> in <module>()\n",
|
|
"----> 1 f(x, y)\n",
|
|
"\n",
|
|
" ... stack trace ...\n",
|
|
"\n",
|
|
"<ipython-input-5-619b39acbaac> in f(x, y)\n",
|
|
" 2 def f(x, y):\n",
|
|
" 3 a = x * y\n",
|
|
"----> 4 b = (x + y) / (x - y)\n",
|
|
" 5 c = a + 2\n",
|
|
" 6 return a + b * c\n",
|
|
"\n",
|
|
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
|
|
" 343 return floor_divide(x1, x2)\n",
|
|
" 344 else:\n",
|
|
"--> 345 return true_divide(x1, x2)\n",
|
|
" 346\n",
|
|
" 347\n",
|
|
"\n",
|
|
".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n",
|
|
" 332 x1, x2 = _promote_shapes(x1, x2)\n",
|
|
" 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n",
|
|
"--> 334 lax.convert_element_type(x2, result_dtype))\n",
|
|
" 335\n",
|
|
" 336\n",
|
|
"\n",
|
|
".../jax/jax/lax.pyc in div(x, y)\n",
|
|
" 244 def div(x, y):\n",
|
|
" 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n",
|
|
"--> 246 return div_p.bind(x, y)\n",
|
|
" 247\n",
|
|
" 248 def rem(x, y):\n",
|
|
"\n",
|
|
" ... stack trace ...\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"When this code sees a NaN in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. We can then run a post-mortem debugger with `%debug` to inspect all the values and figure out the error.\n",
|
|
"\n",
|
|
"⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!"
|
|
]
|
|
},
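{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same checks can be exercised from a plain script. The snippet below is a minimal sketch, assuming a JAX version in which `jax.config.update` is available at the top level (equivalent to the `config.update` call listed earlier): it turns the NaN-checker on, triggers a `0/0` both outside and inside a `@jit` function, catches the resulting `FloatingPointError`s, and then turns the checker off again. The function `f` mirrors the IPython example above; the `caught` dictionary is a hypothetical helper that exists only for illustration.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Turn the NaN-checker on (same effect as JAX_DEBUG_NANS=True).\n",
"jax.config.update(\"jax_debug_nans\", True)\n",
"\n",
"@jax.jit\n",
"def f(x, y):\n",
"    # (x + y) / (x - y) is 0/0 = NaN wherever x == y == 0.\n",
"    return (x + y) / (x - y)\n",
"\n",
"caught = {}  # hypothetical: records which calls raised\n",
"try:\n",
"    # Not wrapped in a user-level @jit: the NaN is reported right away.\n",
"    jnp.divide(0., 0.)\n",
"except FloatingPointError as e:\n",
"    caught[\"eager\"] = str(e)\n",
"\n",
"try:\n",
"    # NaN in a @jit output: JAX re-runs the function de-optimized,\n",
"    # and the offending division raises.\n",
"    f(jnp.array([2., 0.]), jnp.array([3., 0.]))\n",
"except FloatingPointError as e:\n",
"    caught[\"jit\"] = str(e)\n",
"finally:\n",
"    # Switch the checker off again; it adds device-host round-trips.\n",
"    jax.config.update(\"jax_debug_nans\", False)\n",
"\n",
"print(sorted(caught))\n",
"```\n",
"\n",
"In an actual debugging session you would let the exception propagate and inspect it with `%debug`, rather than catching it."
]
},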
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "YTktlwTTMgFl"
|
|
},
|
|
"source": [
|
|
"## Double (64bit) precision\n",
|
|
"\n",
|
|
"At the moment, JAX by default enforces single-precision numbers to mitigate the NumPy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 164,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "CNNGtzM3NDkO",
|
|
"outputId": "d1384021-d9bf-450f-a9ae-82024fa5fc1a"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"dtype('float32')"
|
|
]
|
|
},
|
|
"execution_count": 164,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
|
|
"x.dtype"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "VcvqzobxNPbd"
|
|
},
|
|
"source": [
|
|
"To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__. \n",
|
|
"\n",
|
|
"There are a few ways to do this:\n",
|
|
"\n",
|
|
"1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n",
|
|
"\n",
|
|
"2. You can manually set the `jax_enable_x64` configuration flag at startup:\n",
|
|
"\n",
|
|
"```\n",
|
|
"# again, this only works on startup!\n",
|
|
"from jax.config import config\n",
|
|
"config.update(\"jax_enable_x64\", True)\n",
|
|
"```\n",
|
|
"\n",
|
|
"3. You can parse command-line flags with `absl.app.run(main)` if you first register JAX's options with absl:\n",
|
|
"\n",
|
|
"```\n",
|
|
"from jax.config import config\n",
|
|
"config.config_with_absl()\n",
|
|
"```\n",
|
|
"\n",
|
|
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use:\n",
|
|
"\n",
|
|
"```\n",
|
|
"from jax.config import config\n",
|
|
"if __name__ == '__main__':\n",
|
|
" # calls config.config_with_absl() *and* runs absl parsing\n",
|
|
" config.parse_flags_with_absl()\n",
|
|
"```\n",
|
|
"\n",
|
|
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",
|
|
"\n",
|
|
"We can then confirm that `x64` mode is enabled:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 165,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
},
|
|
"colab_type": "code",
|
|
"id": "HqGbBa9Rr-2g",
|
|
"outputId": "cd241d63-3d00-4fd7-f9c0-afc6af01ecf4"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"dtype('float32')"
|
|
]
|
|
},
|
|
"execution_count": 165,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import jax.numpy as jnp\n",
|
|
"from jax import random\n",
|
|
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
|
|
"x.dtype # --> dtype('float64')"
|
|
]
|
|
},
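{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting option 2 into a standalone script, the flag is set before anything else creates JAX arrays. This is a minimal sketch, assuming a JAX version in which `jax.config.update` is available at the top level; with 64-bit mode on, both the explicitly requested dtype and the default floating-point dtype become `float64`.\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"# Enable 64-bit mode first, before any JAX arrays are created.\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"\n",
"# The explicit float64 request is now honored...\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"# ...and the default floating-point dtype is float64 as well.\n",
"y = jnp.ones(3)\n",
"\n",
"print(x.dtype, y.dtype)\n",
"```"
]
},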
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "6Cks2_gKsXaW"
|
|
},
|
|
"source": [
|
|
"### Caveats\n",
|
|
"⚠️ XLA doesn't support 64-bit convolutions on all backends!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "WAHjmL0E2XwO"
|
|
},
|
|
"source": [
|
|
"## Fin.\n",
|
|
"\n",
|
|
"If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory _advisos_!"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"name": "Common Gotchas in JAX",
|
|
"provenance": [],
|
|
"toc_visible": true
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|