rocm_jax/docs/notebooks/XLA_in_Python.ipynb
George Necula eae59d0b2c
Moved all notebooks to docs/notebooks. (#1493)
* Moved all notebooks to docs/notebooks.

Now all notebooks are in the same place, thus all are subject
to auto-doc generation at readthedocs.io and to automated testing
with travis.

Some notebooks are too slow, exclude them at docs/conf.py:exclude_patterns.

Clean up the section headings in notebooks a bit so that they show
up well in readthedocs.io.

* Increase the cell timeout for executing notebooks
* Exclude also the neural network notebook from auto-generation (timing out)
* Disable the score_matching notebook from auto-doc (travis does not have sklearn)
2019-10-17 08:58:25 +02:00


{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "XLA in Python.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sAgUgR5Mzzz2"
},
"source": [
"# XLA in Python\n",
"\n",
"<img style=\"height:100px;\" src=\"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/compiler/xla/g3doc/images/xlalogo.png\"> <img style=\"height:100px;\" src=\"https://upload.wikimedia.org/wikipedia/commons/c/c3/Python-logo-notext.svg\">\n",
"\n",
"_Anselm Levskaya_ \n",
"\n",
"XLA is the compiler that JAX uses, and the compiler that TF uses for TPUs and will soon use for all devices, so it's worth some study. However, it's not easy to play with XLA computations directly through the raw C++ interface. JAX exposes the underlying XLA computation builder API through a Python wrapper, which makes the XLA compute model accessible for experimenting and prototyping.\n",
"\n",
"XLA computations are built as computation graphs in HLO IR, which is then lowered to device-specific code for each backend (for example, LLVM IR on CPU and GPU, and LLO on TPU).\n",
"\n",
"As end users, we work with the computational primitives defined by the HLO spec."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "EZK5RseuvZkr"
},
"source": [
"## References \n",
"\n",
"__xla__: the doc that defines what's in HLO. Note that it is incomplete and omits some ops.\n",
"\n",
"https://www.tensorflow.org/xla/operation_semantics\n",
"\n",
"More details on the ops are in the source code:\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
"\n",
"__python xla client__: the XLA Python client used by JAX, and what we use here.\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n",
"\n",
"__jax__: these files show how JAX interacts with the XLA compute layer for execution and JIT compilation.\n",
"\n",
"https://github.com/google/jax/blob/master/jax/lax.py\n",
"\n",
"https://github.com/google/jax/blob/master/jax/lib/xla_bridge.py\n",
"\n",
"https://github.com/google/jax/blob/master/jax/interpreters/xla.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "3XR2NGmrzBGe"
},
"source": [
"## Colab Setup and Imports"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "Ogo2SBd3u18P",
"colab": {}
},
"source": [
"# We import as onp to emphasize that we're using vanilla numpy, not jax numpy.\n",
"import numpy as onp\n",
"\n",
"# We only need to import JAX's xla_client, not all of JAX.\n",
"from jaxlib import xla_client\n",
"\n",
"# Plotting\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import gridspec\n",
"from matplotlib import rcParams\n",
"rcParams['image.interpolation'] = 'nearest'\n",
"rcParams['image.cmap'] = 'viridis'\n",
"rcParams['axes.grid'] = False"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0cf7swaobc5l"
},
"source": [
"## Convenience Functions"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "5I50k0rhbg6W",
"colab": {}
},
"source": [
"# Here we borrow convenience functions from JAX to convert numpy shapes/dtypes\n",
"# to XLA-appropriate shapes/dtypes.\n",
"def canonicalize_dtype(dtype):\n",
"  \"\"\"We restrict ourselves to 32bit types for this demo.\"\"\"\n",
"  _dtype_to_32bit_dtype = {\n",
"      str(onp.dtype('int64')): onp.dtype('int32'),\n",
"      str(onp.dtype('uint64')): onp.dtype('uint32'),\n",
"      str(onp.dtype('float64')): onp.dtype('float32'),\n",
"      str(onp.dtype('complex128')): onp.dtype('complex64'),\n",
"  }\n",
"  dtype = onp.dtype(dtype)\n",
"  return _dtype_to_32bit_dtype.get(str(dtype), dtype)\n",
"\n",
"def shape_of(value):\n",
"  \"\"\"Given a Python or XLA value, return its canonicalized XLA Shape.\"\"\"\n",
"  if hasattr(value, 'shape') and hasattr(value, 'dtype'):\n",
"    return xla_client.Shape.array_shape(canonicalize_dtype(value.dtype),\n",
"                                        value.shape)\n",
"  elif onp.isscalar(value):\n",
"    return shape_of(onp.asarray(value))\n",
"  elif isinstance(value, (tuple, list)):\n",
"    return xla_client.Shape.tuple_shape(tuple(shape_of(elt) for elt in value))\n",
"  else:\n",
"    raise TypeError('Unexpected type: {}'.format(type(value)))\n",
"\n",
"def to_xla_type(dtype):\n",
"  \"\"\"Convert a dtype to its integer XLA element type, e.g. for ConvertElementType.\"\"\"\n",
"  if isinstance(dtype, str):\n",
"    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype]\n",
"  elif isinstance(dtype, type):\n",
"    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[onp.dtype(dtype).name]\n",
"  elif isinstance(dtype, onp.dtype):\n",
"    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype.name]\n",
"  else:\n",
"    raise TypeError('Unexpected type: {}'.format(type(dtype)))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "odmjXyhMuNJ5"
},
"source": [
"## Simple Computations"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "UYUtxVzMYIiv",
"outputId": "bd8aa18e-26d9-4df4-ebc3-20026119de17",
"colab": {
"height": 33
}
},
"source": [
"# make a computation builder\n",
"c = xla_client.ComputationBuilder(\"simple_scalar\")\n",
"\n",
"# define a parameter shape and parameter\n",
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), ())\n",
"x = c.ParameterWithShape(param_shape)\n",
"\n",
"# define computation graph\n",
"y = c.Sin(x)\n",
"\n",
"# build computation graph\n",
"# Keep in mind that incorrectly constructed graphs can cause \n",
"# your notebook kernel to crash!\n",
"computation = c.Build()\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = computation.Compile([param_shape,])\n",
"\n",
"# define a host variable with above parameter shape\n",
"host_input = onp.array(3.0, dtype=onp.float32)\n",
"\n",
"# place host variable on device and execute\n",
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
"device_out = compiled_computation.Execute([device_input])\n",
"\n",
"# retrieve the result\n",
"device_out.to_py()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array(0.14112, dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "rIA-IVMVvQs2",
"outputId": "ce88ec6f-d2ea-4ec2-80b4-ddd1afd36957",
"colab": {
"height": 33
}
},
"source": [
"# same as above with vector type:\n",
"\n",
"c = xla_client.ComputationBuilder(\"simple_vector\")\n",
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), (3,))\n",
"x = c.ParameterWithShape(param_shape)\n",
"\n",
"# can also use this function to define a shape from an example:\n",
"#x = c.ParameterFromNumpy(onp.array([0.0, 0.0, 0.0], dtype=onp.float32))\n",
"\n",
"# which is the same as using our convenience function above:\n",
"#x = c.ParameterWithShape(shape_of(onp.array([0.0, 0.0, 0.0], \n",
"# dtype=onp.float32)))\n",
"\n",
"# chain steps by reference:\n",
"y = c.Sin(x)\n",
"z = c.Abs(y)\n",
"computation = c.Build()\n",
"compiled_computation = computation.Compile([param_shape,])\n",
"\n",
"host_input = onp.array([3.0, 4.0, 5.0], dtype=onp.float32)\n",
"\n",
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
"device_out = compiled_computation.Execute([device_input])\n",
"\n",
"# retrieve the result\n",
"device_out.to_py()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.14112 , 0.7568025, 0.9589243], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
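{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"As a quick host-side sanity check (an addition to the original walkthrough, not part of `xla_client`), both results above can be reproduced with plain NumPy in float32. The scalar computation is `sin(3)` and the vector one is `abs(sin([3, 4, 5]))`."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as onp\n",
"\n",
"# Reference values computed with plain NumPy, in float32 to match\n",
"# the XLA computations above.\n",
"scalar_ref = onp.sin(onp.float32(3.0))\n",
"vector_ref = onp.abs(onp.sin(onp.array([3.0, 4.0, 5.0], dtype=onp.float32)))\n",
"print(scalar_ref)\n",
"print(vector_ref)"
],
"execution_count": 0,
"outputs": []
},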
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "F8kWlLaVuQ1b"
},
"source": [
"## Simple While Loop"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "MDQP1qW515Ao",
"outputId": "4da894b5-2b0e-455e-a720-3bdadc57d164",
"colab": {
"height": 33
}
},
"source": [
"# trivial while loop, decrement until 0\n",
"# x = 5\n",
"# while x > 0:\n",
"#   x = x - 1\n",
"#\n",
"in_shape = shape_of(5)\n",
"\n",
"# body computation:\n",
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
"x = bcb.ParameterWithShape(in_shape)\n",
"const1 = bcb.Constant(onp.int32(1))\n",
"y = bcb.Sub(x, const1)\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation:\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"x = tcb.ParameterWithShape(in_shape)\n",
"const0 = tcb.Constant(onp.int32(0))\n",
"y = tcb.Gt(x, const0)\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
"x = wcb.ParameterWithShape(in_shape)\n",
"wcb.While(test_computation, body_computation, x)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"host_input = onp.array(5, dtype=onp.int32)\n",
"\n",
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
"device_out = compiled_computation.Execute([device_input])\n",
"\n",
"# retrieve the result\n",
"device_out.to_py()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array(0, dtype=int32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
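{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"XLA's `While` op takes a condition computation, a body computation, and an initial value. It keeps applying the body to the loop-carried value for as long as the condition returns true, and returns the final value. The plain-Python function below is only a mental model of those semantics, not how XLA executes the loop. Its name `while_loop_reference` is made up for this illustration. It reproduces the decrement-to-zero example above."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as onp\n",
"\n",
"def while_loop_reference(cond_fun, body_fun, init_val):\n",
"  \"\"\"Hypothetical host-side model of XLA's While(cond, body, init).\n",
"\n",
"  cond_fun and body_fun play the roles of the test and body computations:\n",
"  both receive the current loop-carried value, and body_fun returns the\n",
"  next value, which must keep the same shape and dtype.\n",
"  \"\"\"\n",
"  val = init_val\n",
"  while cond_fun(val):\n",
"    val = body_fun(val)\n",
"  return val\n",
"\n",
"# Same loop as the XLA example above: decrement an int32 until it reaches 0.\n",
"result = while_loop_reference(lambda x: x > 0,\n",
"                              lambda x: x - onp.int32(1),\n",
"                              onp.int32(5))\n",
"print(result, result.dtype)"
],
"execution_count": 0,
"outputs": []
},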
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7UOnXlY8slI6"
},
"source": [
"## While Loops with Tuples: Newton's Method for sqrt"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "HEWz-vzd6QPR",
"outputId": "6ef10855-232d-4701-a442-0e2667b2fd97",
"colab": {
"height": 33
}
},
"source": [
"Xsqr = 2\n",
"guess = 1.0\n",
"converged_delta = 0.001\n",
"maxit = 1000\n",
"\n",
"in_shape = shape_of((1.0, 1.0, 1))\n",
"\n",
"# body computation:\n",
"# x_{i+1} = x_{i} - (x_i**2 - y) / (2 * x_i)\n",
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
"intuple = bcb.ParameterWithShape(in_shape)\n",
"y = bcb.GetTupleElement(intuple, 0)\n",
"x = bcb.GetTupleElement(intuple, 1)\n",
"guard_cntr = bcb.GetTupleElement(intuple, 2)\n",
"new_x = bcb.Sub(x, bcb.Div(bcb.Sub(bcb.Mul(x, x), y), bcb.Add(x, x)))\n",
"result = bcb.Tuple(y, new_x, bcb.Sub(guard_cntr, bcb.Constant(onp.int32(1))))\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- convergence and max iteration test\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"intuple = tcb.ParameterWithShape(in_shape)\n",
"y = tcb.GetTupleElement(intuple, 0)\n",
"x = tcb.GetTupleElement(intuple, 1)\n",
"guard_cntr = tcb.GetTupleElement(intuple, 2)\n",
"criterion = tcb.Abs(tcb.Sub(tcb.Mul(x, x), y))\n",
"# stop at convergence criteria or too many iterations\n",
"test = tcb.And(tcb.Gt(criterion, tcb.Constant(onp.float32(converged_delta))), \n",
" tcb.Gt(guard_cntr, tcb.Constant(onp.int32(0))))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
"intuple = wcb.ParameterWithShape(in_shape)\n",
"wcb.While(test_computation, body_computation, intuple)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"y = onp.array(Xsqr, dtype=onp.float32)\n",
"x = onp.array(guess, dtype=onp.float32)\n",
"maxit = onp.array(maxit, dtype=onp.int32)\n",
"\n",
"device_input = xla_client.LocalBuffer.from_pyval((y, x, maxit))\n",
"device_out = compiled_computation.Execute([device_input])\n",
"\n",
"host_out = device_out.to_py()\n",
"print(\"square root of {y} is {x}\".format(y=y, x=host_out[1]))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"square root of 2.0 is 1.4142156839370728\n"
],
"name": "stdout"
}
]
},
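{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"For comparison, here is the same Newton iteration as an ordinary NumPy loop on the host. It uses the same convergence test (`|x*x - y| > converged_delta`) and the same iteration guard as the XLA version. This is a reference sketch only: the defaults mirror the cell above, and the helper name `newton_sqrt_reference` is invented for this illustration."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as onp\n",
"\n",
"def newton_sqrt_reference(y, guess=1.0, converged_delta=0.001, maxit=1000):\n",
"  \"\"\"Host-side reference for the XLA while loop above (hypothetical helper).\"\"\"\n",
"  y = onp.float32(y)\n",
"  x = onp.float32(guess)\n",
"  cntr = maxit\n",
"  # Same loop condition as the XLA test computation: keep going while the\n",
"  # residual is above the tolerance and the iteration guard is positive.\n",
"  while abs(x * x - y) > converged_delta and cntr > 0:\n",
"    # x_{i+1} = x_i - (x_i**2 - y) / (2 * x_i)\n",
"    x = x - (x * x - y) / (x + x)\n",
"    cntr -= 1\n",
"  return x\n",
"\n",
"root = newton_sqrt_reference(2.0)\n",
"print('square root of 2.0 is {}'.format(root))"
],
"execution_count": 0,
"outputs": []
},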
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "yETVIzTInFYr"
},
"source": [
"## Calculate Symmetric Eigenvalues"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "AiyR1e2NubKa"
},
"source": [
"Let's use XLA's QR decomposition to compute the eigenvalues of a symmetric matrix.\n",
"\n",
"This is the naive (unshifted) QR algorithm. It has no shifts to speed up convergence when eigenvalues are close in magnitude, and no permutation to sort the eigenvalues by magnitude."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "wjxDPbqCcuXT",
"outputId": "9683e40b-3c5f-4f3e-c971-0613b182c68c",
"colab": {
"height": 455
}
},
"source": [
"Niter = 200\n",
"matrix_shape = (10, 10)\n",
"in_shape = shape_of(\n",
" (onp.zeros(matrix_shape, dtype=onp.float32), 1)\n",
")\n",
"# NB: in_shape is the same as the manually constructed:\n",
"# xla_client.Shape.tuple_shape(\n",
"# (xla_client.Shape.array_shape(onp.dtype(onp.float32), matrix_shape), \n",
"# xla_client.Shape.array_shape(onp.dtype(onp.int32), ()))\n",
"# )\n",
"\n",
"# body computation -- QR loop: X_i = Q R, X_{i+1} = R Q\n",
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
"intuple = bcb.ParameterWithShape(in_shape)\n",
"x = bcb.GetTupleElement(intuple, 0)\n",
"cntr = bcb.GetTupleElement(intuple, 1)\n",
"QR = bcb.QR(x)\n",
"Q = bcb.GetTupleElement(QR, 0)\n",
"R = bcb.GetTupleElement(QR, 1)\n",
"RQ = bcb.Dot(R, Q)\n",
"bcb.Tuple(RQ, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- just a for loop condition\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"intuple = tcb.ParameterWithShape(in_shape)\n",
"cntr = tcb.GetTupleElement(intuple, 1)\n",
"test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
"intuple = wcb.ParameterWithShape(in_shape)\n",
"wcb.While(test_computation, body_computation, intuple)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"X = onp.random.random(matrix_shape).astype(onp.float32)\n",
"X = (X + X.T) / 2.0\n",
"it = onp.array(Niter, dtype=onp.int32)\n",
"\n",
"device_in = xla_client.LocalBuffer.from_pyval((X, it))\n",
"device_out = compiled_computation.Execute([device_in,])\n",
"\n",
"host_out = device_out.to_py()\n",
"eigh_vals = host_out[0].diagonal()\n",
"\n",
"plt.title('D')\n",
"plt.imshow(host_out[0])\n",
"print('sorted eigenvalues')\n",
"print(onp.sort(eigh_vals))\n",
"print('sorted eigenvalues from numpy')\n",
"print(onp.sort(onp.linalg.eigh(X)[0]))\n",
"print('sorted error') \n",
"print(onp.sort(eigh_vals) - onp.sort(onp.linalg.eigh(X)[0]))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"sorted eigenvalues\n",
"[-1.190547 -0.91282177 -0.32339668 -0.14050038 -0.09441247 0.08265306\n",
" 0.49015656 0.731502 1.0677357 5.3513203 ]\n",
"sorted eigenvalues from numpy\n",
"[-1.1905469 -0.9128221 -0.32339665 -0.14050038 -0.09441243 0.08265309\n",
" 0.49015662 0.7315014 1.0677353 5.351319 ]\n",
"sorted error\n",
"[-1.1920929e-07 3.5762787e-07 -2.9802322e-08 0.0000000e+00\n",
" -3.7252903e-08 -2.9802322e-08 -5.9604645e-08 5.9604645e-07\n",
" 3.5762787e-07 1.4305115e-06]\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACtdJREFUeJzt3V+snwV9x/H3h9MWaGWg6JbYNrZLjEvD/mDODMrmBbhEJ5ObZcMEM71plkxFY+JwN97uwhi9cC4N6A1ELioXxBBxiXpBliClJcO2akjtoIChcwFJIzut/e7inJmuW8956nkennO+vF9Jk55ffzx8Aufd53ee/s7TVBWSerpi7gGSpmPgUmMGLjVm4FJjBi41ZuBSYwYuNWbgr1NJTib5ZZJXkryU5F+T/G0SPyca8X/m69tfVNU1wNuAfwT+Hrh33kkak4GLqnq5qh4C/hr4myQ3zL1J4zBw/VpV/QA4Bfzp3Fs0DgPXxZ4H3jT3CI3DwHWxncB/zj1C4zBw/VqSP2Y58Efn3qJxGLhI8ltJbgMeAO6rqqfm3qRxxO8Hf31KchL4HeAccB44BtwH/HNV/WrGaRqRgUuN+RJdaszApcYMXGrMwKXGtkxx0De/aaH27N46+nF/cuL60Y8pbUavvvoSS2fPZK3nTRL4nt1b+cEju0c/7p/91UdHP6a0GT1+5J8GPc+X6FJjBi41ZuBSYwYuNWbgUmMGLjU2KPAk70/y4yRPJ7l76lGSxrFm4EkWgK8AHwD2AR9Osm/qYZLWb8gZ/F3A01V1oqqWWL4pwO3TzpI0hiGB7wSeveDjUyuP/S9J9ic5lOTQ6Z97vwBpIxjtIltVHaiqxapafMv1C2MdVtI6DAn8OeDCN5bvWnlM0gY3JPDHgbcn2ZtkG3AH8NC0sySNYc3vJquqc0k+DjwCLABfq6qjky+TtG6Dvl20qh4GHp54i6SR+U42qTEDlxozcKkxA5caM3CpsUluuviTE9dPcoPEKx59cvRjApz/kz+a5LjS3DyDS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNTXJX1alMdffTn//+1aMf8/qnfjn6MaXL5RlcaszApcYMXGrMwKXGDFxqzMClxgxcamzNwJPsTvK9JMeSHE1y12sxTNL6DXmjyzngM1V1OMk1wBNJ/qWqjk28TdI6rXkGr6oXqurwys9fAY4DO6ceJmn9Luutqkn2ADcCj/0/v7Yf2A9w5ZXXjjBN0noNvsiW5A3AN4FPVdUvLv71qjpQVYtVtbht644xN0r6DQ0KPMlWluO+v6oenHaSpLEMuYoe4F7geFV9cfpJksYy5Ax+M/AR4JYkT678+POJd0kawZoX2arqUSCvwRZJI/OdbFJjBi41ZuBSYwYuNbapbro4lSlukLj1306MfkyAs3/wu5McVz15BpcaM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGvOuqhOZ6u6nr+y5apLjXnPy1UmOq3l5BpcaM3CpMQOXGjNwqTEDlxozcKkxA5caGxx4koUkR5J8a8pBksZzOWfwu4DjUw2RNL5BgSfZBXwQuGfaOZLGNPQM/iXgs8D5Sz0hyf4kh5IcWjp7ZpRxktZnzcCT3Aa8WFVPrPa8qjpQVYtVtbht647RBkr6zQ05g98MfCjJSeAB4JYk9026StIo1gy8qj5XVbuqag9wB/Ddqrpz8mWS1s0/B5cau6zvB6+q7wPfn2SJpNF5BpcaM3CpMQOXGjNwqTEDlxrzrqqbzFR3P331t6+c5LhXvfhfkxxXw3gGlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqT
EDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5ca866qAqa7++mWl6a5C+y5666a5LjdeAaXGjNwqTEDlxozcKkxA5caM3CpMQOXGhsUeJLrkhxM8qMkx5O8e+phktZv6Btdvgx8u6r+Msk2YPuEmySNZM3Ak1wLvBf4KEBVLQFL086SNIYhL9H3AqeBryc5kuSeJDsuflKS/UkOJTm0dPbM6EMlXb4hgW8B3gl8tapuBM4Ad1/8pKo6UFWLVbW4bev/6V/SDIYEfgo4VVWPrXx8kOXgJW1wawZeVT8Dnk3yjpWHbgWOTbpK0iiGXkX/BHD/yhX0E8DHppskaSyDAq+qJ4HFibdIGpnvZJMaM3CpMQOXGjNwqTEDlxrzrqqa1FR3P71i6VejH/P8toXRjzk3z+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNeZNF7UpTXGDxCvOnR/9mADnt8x3HvUMLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjU2KPAkn05yNMkPk3wjyTR/o5ykUa0ZeJKdwCeBxaq6AVgA7ph6mKT1G/oSfQtwdZItwHbg+ekmSRrLmoFX1XPAF4BngBeAl6vqOxc/L8n+JIeSHFo6e2b8pZIu25CX6G8Ebgf2Am8FdiS58+LnVdWBqlqsqsVtW3eMv1TSZRvyEv19wE+r6nRVnQUeBN4z7SxJYxgS+DPATUm2JwlwK3B82lmSxjDka/DHgIPAYeCplX/mwMS7JI1g0PeDV9Xngc9PvEXSyHwnm9SYgUuNGbjUmIFLjRm41Jh3VZVWTHX305riuMmgp3kGlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caS1WNf9DkNPDvA576ZuA/Rh8wnc20dzNthc21dyNsfVtVvWWtJ00S+FBJDlXV4mwDLtNm2ruZtsLm2ruZtvoSXWrMwKXG5g78wMz//su1mfZupq2wufZumq2zfg0uaVpzn8ElTcjApcZmCzzJ+5P8OMnTSe6ea8dakuxO8r0kx5IcTXLX3JuGSLKQ5EiSb829ZTVJrktyMMmPkhxP8u65N60myadXPg9+mOQbSa6ae9NqZgk8yQLwFeADwD7gw0n2zbFlgHPAZ6pqH3AT8HcbeOuF7gKOzz1igC8D366q3wP+kA28OclO4JPAYlXdACwAd8y7anVzncHfBTxdVSeqagl4ALh9pi2rqqoXqurwys9fYfkTcOe8q1aXZBfwQeCeubesJsm1wHuBewGqaqmqXpp31Zq2AFcn2QJsB56fec+q5gp8J/DsBR+fYoNHA5BkD3Aj8Ni8S9b0JeCzwPm5h6xhL3Aa+PrKlxP3JNkx96hLqarngC8AzwAvAC9X1XfmXbU6L7INlOQNwDeBT1XVL+becylJbgNerKon5t4ywBbgncBXq+pG4Aywka/HvJHlV5p7gbcCO5LcOe+q1c0V+HPA7gs+3rXy2IaUZCvLcd9fVQ/OvWcNNwMfSnKS5S99bkly37yTLukUcKqq/ucV0UGWg9+o3gf8tKpOV9VZ4EHgPTNvWtVcgT8OvD3J3iTbWL5Q8dBMW1aVJCx/jXi8qr449561VNXnqmpXVe1h+b/rd6tqQ55lqupnwLNJ3rHy0K3AsRknreUZ4KYk21c+L25lA18UhOWXSK+5qjqX5OPAIyxfifxaVR2dY8sANwMfAZ5K8uTKY/9QVQ/PuKmTTwD3r/xGfwL42Mx7LqmqHktyEDjM8p+uHGGDv23Vt6pKjXmRTWrMwKXGDFxqzMClxgxcaszApcYMXGrsvwFxd1oMhwT6kgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
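{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To check the XLA result on the host, the same unshifted QR iteration can be written in a few lines of NumPy with `numpy.linalg.qr`. This is a sketch added here, and the helper name `qr_eigvals_reference` is hypothetical. It uses a matrix with known, well-separated positive eigenvalues rather than a random one. A random symmetric matrix can have two eigenvalues of nearly equal magnitude (for example an eigenvalue and its negation), and then the unshifted iteration converges slowly or not at all."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as onp\n",
"\n",
"def qr_eigvals_reference(X, niter=200):\n",
"  \"\"\"Unshifted QR iteration: X_i = Q R, X_{i+1} = R Q (hypothetical helper).\"\"\"\n",
"  X = onp.array(X, dtype=onp.float64)\n",
"  for _ in range(niter):\n",
"    Q, R = onp.linalg.qr(X)\n",
"    X = R @ Q  # equals Q^T X Q, so the eigenvalues are preserved\n",
"  # For a symmetric input the iterates approach a diagonal matrix.\n",
"  return onp.diagonal(X)\n",
"\n",
"# Test matrix with known, well-separated positive eigenvalues d.\n",
"rng = onp.random.RandomState(0)\n",
"V, _ = onp.linalg.qr(rng.randn(10, 10))  # random orthogonal matrix\n",
"d = onp.array([10., 8., 6., 5., 4., 3., 2., 1., 0.5, 0.25])\n",
"A = V @ onp.diag(d) @ V.T\n",
"\n",
"approx = onp.sort(qr_eigvals_reference(A))\n",
"print(approx)\n",
"print(onp.max(onp.abs(approx - onp.sort(d))))"
],
"execution_count": 0,
"outputs": []
},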
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "FpggTihknAOw"
},
"source": [
"## Calculate Full Symmetric Eigensystem"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qos4ankYuj1T"
},
"source": [
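"We can also compute the eigenvectors by accumulating the Q factors. Each step replaces `X` with `R Q = Q^T X Q`, so after k steps the iterate equals `U^T X U` with `U = Q_1 Q_2 ... Q_k`. Once the iterate is (nearly) diagonal, the columns of `U` approximate the eigenvectors."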
]
},
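{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Before building the XLA version, here is a host-side NumPy sketch of the same accumulation. It is not part of the original notebook, and the helper name `qr_eigensystem_reference` is hypothetical. The accumulated matrix `U` plays the role of `O` in the XLA loop below. The test matrix has known, well-separated positive eigenvalues, so the unshifted iteration converges within the same 100 steps."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as onp\n",
"\n",
"def qr_eigensystem_reference(X, niter=100):\n",
"  \"\"\"Unshifted QR iteration that also accumulates U = Q_1 Q_2 ... Q_k.\n",
"\n",
"  Hypothetical helper; U corresponds to O in the XLA loop below.\n",
"  \"\"\"\n",
"  X = onp.array(X, dtype=onp.float64)\n",
"  U = onp.eye(X.shape[0])\n",
"  for _ in range(niter):\n",
"    Q, R = onp.linalg.qr(X)\n",
"    X = R @ Q  # X_{k+1} = Q^T X_k Q\n",
"    U = U @ Q  # accumulate the orthogonal similarity transform\n",
"  return onp.diagonal(X), U\n",
"\n",
"# Test matrix with known, well-separated positive eigenvalues.\n",
"rng = onp.random.RandomState(1)\n",
"V, _ = onp.linalg.qr(rng.randn(6, 6))  # random orthogonal matrix\n",
"A = V @ onp.diag([32., 16., 8., 4., 2., 1.]) @ V.T\n",
"\n",
"vals, U = qr_eigensystem_reference(A)\n",
"print(onp.sort(vals))\n",
"print(onp.allclose(U.T @ U, onp.eye(6)))  # U is orthogonal\n",
"print(onp.allclose(U.T @ A @ U, onp.diag(vals)))  # U^T A U is diagonal"
],
"execution_count": 0,
"outputs": []
},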
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "Kp3A-aAiZk0g",
"outputId": "ebdc1ecc-c9e1-4e95-b989-9645f8648ee0",
"colab": {
"height": 1000
}
},
"source": [
"Niter = 100\n",
"matrix_shape = (10, 10)\n",
"in_shape = shape_of(\n",
" (onp.zeros(matrix_shape, dtype=onp.float32), \n",
" onp.eye(matrix_shape[0]),\n",
" 1)\n",
")\n",
"\n",
"# body computation -- QR loop: X_i = Q R, X_{i+1} = R Q, and accumulate O_{i+1} = O_i Q\n",
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
"intuple = bcb.ParameterWithShape(in_shape)\n",
"X = bcb.GetTupleElement(intuple, 0)\n",
"O = bcb.GetTupleElement(intuple, 1)\n",
"cntr = bcb.GetTupleElement(intuple, 2)\n",
"QR = bcb.QR(X)\n",
"Q = bcb.GetTupleElement(QR, 0)\n",
"R = bcb.GetTupleElement(QR, 1)\n",
"RQ = bcb.Dot(R, Q)\n",
"Onew = bcb.Dot(O, Q)\n",
"bcb.Tuple(RQ, Onew, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- just a for loop condition\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"intuple = tcb.ParameterWithShape(in_shape)\n",
"cntr = tcb.GetTupleElement(intuple, 2)\n",
"test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
"intuple = wcb.ParameterWithShape(in_shape)\n",
"wcb.While(test_computation, body_computation, intuple)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"X = onp.random.random(matrix_shape).astype(onp.float32)\n",
"X = (X + X.T) / 2.0\n",
"Omat = onp.eye(matrix_shape[0], dtype=onp.float32)\n",
"it = onp.array(Niter, dtype=onp.int32)\n",
"\n",
"device_in = xla_client.LocalBuffer.from_pyval((X, Omat, it))\n",
"device_out = compiled_computation.Execute([device_in,])\n",
"\n",
"host_out = device_out.to_py()\n",
"eigh_vals = host_out[0].diagonal()\n",
"eigh_mat = host_out[1]\n",
"\n",
"plt.title('D')\n",
"plt.imshow(host_out[0])\n",
"plt.figure()\n",
"plt.title('U')\n",
"plt.imshow(eigh_mat)\n",
"plt.figure()\n",
"plt.title('U^T A U')\n",
"plt.imshow(onp.dot(onp.dot(eigh_mat.T, X), eigh_mat))\n",
"print('sorted eigenvalues')\n",
"print(onp.sort(eigh_vals))\n",
"print('sorted eigenvalues from numpy')\n",
"print(onp.sort(onp.linalg.eigh(X)[0]))\n",
"print('sorted error') \n",
"print(onp.sort(eigh_vals) - onp.sort(onp.linalg.eigh(X)[0]))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"sorted eigenvalues\n",
"[-0.94551486 -0.63820213 -0.57944936 -0.28589356 -0.05510262 0.16862962\n",
" 0.4192178 0.4671099 0.88734317 4.990509 ]\n",
"sorted eigenvalues from numpy\n",
"[-0.9455159 -0.63820285 -0.5794492 -0.28589386 -0.05510259 0.16862962\n",
" 0.41921794 0.46710995 0.88734376 4.9905105 ]\n",
"sorted error\n",
"[ 1.0132790e-06 7.1525574e-07 -1.7881393e-07 2.9802322e-07\n",
" -2.9802322e-08 0.0000000e+00 -1.4901161e-07 -5.9604645e-08\n",
" -5.9604645e-07 -1.4305115e-06]\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACtZJREFUeJzt3V+snwV9x/H3h3NKoMWhIHGxbWyXGBfGsmCOBiXzArzAyeRmyTCBbN6QJVPRmDjcjbe7MEYvjFsDegORi8oFMURcol6YZZVDIYG2apraQfkzuxmQVUhb+93FOTOsW8956nkennO+vF8JSc+vPx4+gfPm+f1+5/d7mqpCUk+XzD1A0nQMXGrMwKXGDFxqzMClxgxcaszApcYM/E0qyfEkryZ5JclLSf4lyd8k8XuiEf9jvrn9eVW9BXgX8A/A3wH3zTtJYzJwUVUvV9XDwF8Cf5Xkurk3aRwGrt+qqh8DJ4A/nXuLxmHgOt/zwFVzj9A4DFzn2wn8cu4RGoeB67eSvI+VwH809xaNw8BFkt9LcivwIHB/VT019yaNI34e/M0pyXHgHcBZ4BxwGLgf+Meq+s2M0zQiA5ca8yG61JiBS40ZuNSYgUuNLU5x0LdftVB7dm8b/bg/O3r16MeUtqLXTr/E6TOnst79Jgl8z+5t/PjR3aMf95bb7hz9mNJW9K9P/9Og+/kQXWrMwKXGDFxqzMClxgxcaszApcYGBZ7kliQ/TXI0yT1Tj5I0jnUDT7IAfA34CHAt8PEk1049TNLGDTmDvx84WlXHquo0KxcFuG3aWZLGMCTwncCzr/v6xOpt/0uSu5IsJ1k++Z9eL0DaDEZ7ka2q9lXVUlUtXXP1wliHlbQBQwJ/Dnj9G8t3rd4maZMbEvhjwLuT7E1yKXA78PC0sySNYd1Pk1XV2SSfBB4FFoBvVNWhyZdJ2rBBHxetqkeARybeImlkvpNNaszApcYMXGrMwKXGDFxqbJKLLv7s6NWTXCCxHpvmz8TL+/54kuNKc/MMLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41NslVVacy1dVPf/lHV4x+zKsO/dfox5QulmdwqTEDlxozcKkxA5caM3CpMQOXGjNwqbF1A0+yO8kPkhxOcijJ3W/EMEkbN+SNLmeBz1XVwSRvAR5P8s9VdXjibZI2aN0zeFW9UFUHV3/9CnAE2Dn1MEkbd1HPwZPsAa4HDvw/v3dXkuUky2fOnhpnnaQNGRx4kiuAbwOfqapfnf/7VbWvqpaqamnb4o4xN0r6HQ0KPMk2VuJ+oKoemnaSpLEMeRU9wH3Akar68vSTJI1lyBn8RuBO4KYkT67+9WcT75I0gnV/TFZVPwLyBmyRNDLfySY1ZuBSYwYuNWbgUmNb6qKLU5niAomXHHtu9GMCnPsD3yWs4TyDS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNeVXViUx19dNLjr84yXHP7fn9SY6reXkGlxozcKkxA5caM3CpMQOXGjNwqTEDlxobHHiShSRPJPnOlIMkjedizuB3A0emGiJpfIMCT7IL+Chw77RzJI1p6Bn8K8DngXMXukOSu5IsJ1k+c/bUKOMkbcy6gSe5FfhFVT2+1v2qal9VLVXV0rbFHaMNlPS7G3IGvxH4WJLjwIPATUnun3SVpFGsG3hVfaGqdlXVHuB24PtVdcfkyyRtmD8Hlxq7qM+DV9UPgR9OskTS6DyDS40ZuNSYgUuNGbjUmIFLjXlV1S1mqqufvvqOyyc57uX//uokx9UwnsGlxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxg
xcaszApcYMXGrMwKXGDFxqzMClxgxcaszApca8qqqA6a5++to1l01y3MtOvjbJcbvxDC41ZuBSYwYuNWbgUmMGLjVm4FJjBi41NijwJG9Nsj/JT5IcSfKBqYdJ2rihb3T5KvDdqvqLJJcC2yfcJGkk6wae5ErgQ8BfA1TVaeD0tLMkjWHIQ/S9wEngm0meSHJvkh3n3ynJXUmWkyyfOXtq9KGSLt6QwBeB9wJfr6rrgVPAPeffqar2VdVSVS1tW/w//UuawZDATwAnqurA6tf7WQle0ia3buBV9SLwbJL3rN50M3B40lWSRjH0VfRPAQ+svoJ+DPjEdJMkjWVQ4FX1JLA08RZJI/OdbFJjBi41ZuBSYwYuNWbgUmNeVVWTmurqp5e8/OvRj3nuyn6fofIMLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjXnRRW9IUF0j8zfZpclj49dlJjjuEZ3CpMQOXGjNwqTEDlxozcKkxA5caM3CpsUGBJ/lskkNJnk7yrSSXTT1M0satG3iSncCngaWqug5YAG6fepikjRv6EH0RuDzJIrAdeH66SZLGsm7gVfUc8CXgGeAF4OWq+t7590tyV5LlJMtnzp4af6mkizbkIfrbgNuAvcA7gR1J7jj/flW1r6qWqmpp2+KO8ZdKumhDHqJ/GPh5VZ2sqjPAQ8AHp50laQxDAn8GuCHJ9iQBbgaOTDtL0hiGPAc/AOwHDgJPrf49+ybeJWkEgz4AW1VfBL448RZJI/OdbFJjBi41ZuBSYwYuNWbgUmNeVVVaNdnVT8+dm+CgNehensGlxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcZSNezqjBd10OQk8G8D7vp24D9GHzCdrbR3K22FrbV3M2x9V1Vds96dJgl8qCTLVbU024CLtJX2bqWtsLX2bqWtPkSXGjNwqbG5A9838z//Ym2lvVtpK2ytvVtm66zPwSVNa+4zuKQJGbjU2GyBJ7klyU+THE1yz1w71pNkd5IfJDmc5FCSu+feNESShSRPJPnO3FvWkuStSfYn+UmSI0k+MPemtST57Or3wdNJvpXksrk3rWWWwJMsAF8DPgJcC3w8ybVzbBngLPC5qroWuAH420289fXuBo7MPWKArwLfrao/BP6ETbw5yU7g08BSVV0HLAC3z7tqbXOdwd8PHK2qY1V1GngQuG2mLWuqqheq6uDqr19h5Rtw57yr1pZkF/BR4N65t6wlyZXAh4D7AKrqdFW9NO+qdS0ClydZBLYDz8+8Z01zBb4TePZ1X59gk0cDkGQPcD1wYN4l6/oK8Hlgij95fkx7gZPAN1efTtybZMfcoy6kqp4DvgQ8A7wAvFxV35t31dp8kW2gJFcA3wY+U1W/mnvPhSS5FfhFVT0+95YBFoH3Al+vquuBU8Bmfj3mbaw80twLvBPYkeSOeVetba7AnwN2v+7rXau3bUpJtrES9wNV9dDce9ZxI/CxJMdZeepzU5L75510QSeAE1X1P4+I9rMS/Gb1YeDnVXWyqs4ADwEfnHnTmuYK/DHg3Un2JrmUlRcqHp5py5qShJXniEeq6stz71lPVX2hqnZV1R5W/r1+v6o25Vmmql4Enk3yntWbbgYOzzhpPc8ANyTZvvp9cTOb+EVBWHmI9IarqrNJPgk8ysorkd+oqkNzbBngRuBO4KkkT67e9vdV9ciMmzr5FPDA6v/ojwGfmHnPBVXVgST7gYOs/HTlCTb521Z9q6rUmC+ySY0ZuNSYgUuNGbjUmIFLjRm41JiBS439NyrdYAajSKUYAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADO9JREFUeJzt3W1snfV5x/HfL46fEjqHsm4dcUbCoLCoFQK5ERSp6yDS+rSirawLU2jLm2jqgLRiYxStKqu0NxtCIK1DuNBKE6h0DVnF2ohSrXRt1DatCag0SWnDQ43TIMKYAwkhsZNrL+xJWbb43Mb/P7d95fuRkOKTOxeXbH9zn3N8zh1HhADktKjtBQDUQ+BAYgQOJEbgQGIEDiRG4EBiBA4kRuCnKNth+5wTbrvF9r1t7YTyCBxIjMCBxAgcSIzAgcQI/NR1VFL3Cbd1S5poYRdUQuCnrlFJK0+4bZWkX77xq6AWAj91fUXS39getL3I9lpJfyhpU8t7oSDzfvBTk+1+SZ+T9CeSTpf0lKRbIuLBVhdDUQQOJMZddCAxAgcSI3AgMQIHEltcY2j3QH/0vXWg+NxFoy4+U5J81tHiMw/v7y0+U5KO9VZ6UrTOp1b9vUeqzD2j+2DxmQeO1vmajR9YWnzm5Esv6eiBgx2/alUC73vrgC76p/XF5/Zf31N8piT13TVefObPt5xbfKYkHVxV6YVmPceqjL3g7LEqc6/+rR8Un/n9V87pfNDr8K9b1xSfufcfbm90HHfRgcQIHEiMwIHECBxIjMCBxAgcSKxR4Lbfa/tJ27tt31R7KQBldAzcdpekz0t6n6TVkq6yvbr2YgDmrskZfI2k3RHxdEQckXS/pCvqrgWghCaBL5f03HEfj03f9r/Y3mB7xPbIxPirpfYDMAfFnmSLiOGIGIqIoe5lS0qNBTAHTQLfI2nFcR8PTt8GYJ5rEviPJZ1re5XtHknrJHHdLmAB6PhusoiYtH2tpG9K6pL0xYjYUX0zAHPW6O2iEbFF0pbKuwAojFeyAYkROJAYgQOJETiQGIEDiVW56GKXj+lNPYeLz/3Tr20tPlOS/vaBjxSf+ZF13ys+U5K2/9eKzge9Dv5onb/rd1x7dpW5t7ztjOIzI+pcWnbxq+XnuuE1MjmDA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJVbmq6uSxRfrPQ+X/jfDPfbX81U8lacUjR4rP/N73Lyk+U5L+Y3i4ytx33lnnc9v33TpXKn31QG/xmWffFcVnStJTVzW8BOosRFez4ziDA4kROJAYgQOJETiQGIEDiRE4kBiBA4l1DNz2CtuP2N5pe4ftjW/EYgDmrskLXSYl3RAR222/SdKjtr8VETsr7wZgjjqewSNib0Rsn/71K5J2SVpeezEAczerx+C2V0q6UNK2/+f3NtgesT0ysf9Qme0AzEnjwG2fJukBSZ+MiJdP/P2IGI6IoYgY6h7oL7kjgNepUeC2uzUV930RsbnuSgBKafIsuiXdI2lXRNxWfyUApTQ5g18q6WpJl9l+fPq/91feC0ABHX9MFhFbJdV5Uy+AqnglG5AYgQOJETiQGIEDiVW56KIldS8qf6G5vn11nut7/hOvFZ85+OEdxWdK0u/e9Ykqc8/6ux9VmfviNWdUmduzta/4zLG1db6/Bh+eLD7zxZebXSCSMziQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kFidq6o61NtV/kqSh36z2ZUkZ2vwrvL/3PEv7ri4+ExJ6jpc53PQ9+91rn468bU6Vyp9884jxWfefMO/FJ8pSff84zuLz+w6cLjRcZzBgcQIHEiMwIHECBxIjMCBxA
gcSIzAgcQaB267y/Zjtr9ecyEA5czmDL5R0q5aiwAor1HgtgclfUDS3XXXAVBS0zP47ZJulHTsZAfY3mB7xPbIxP5DRZYDMDcdA7f9QUkvRMSjMx0XEcMRMRQRQ90D5V/bDWD2mpzBL5X0IdvPSrpf0mW27626FYAiOgYeEZ+OiMGIWClpnaRvR8T66psBmDN+Dg4kNqv3g0fEdyR9p8omAIrjDA4kRuBAYgQOJEbgQGIEDiRW5aqqXT6mgd7yL1dd854673VZ+8c7i8+87QtXFp8pST5aZawGl4xXmXvkoX1V5o5e8ZbiM2/8xp8VnylJ53xlrPjM+PNm52bO4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYlWuqlrLD59dWWXuky/9RvGZB95xuPhMSRr4UW+VuQ/vPr/K3L9/8IEqcz/zhY8Wn3n+raPFZ0rS0TPPKD90rFm6nMGBxAgcSIzAgcQIHEiMwIHECBxIjMCBxBoFbnuZ7U22f2Z7l+1Lai8GYO6avtDlDkkPRcSVtnskLam4E4BCOgZue0DSuyV9XJIi4oikI3XXAlBCk7voqyTtk/Ql24/Zvtv20hMPsr3B9ojtkcPjrxVfFMDsNQl8saSLJN0ZERdKOijpphMPiojhiBiKiKHeZX2F1wTwejQJfEzSWERsm/54k6aCBzDPdQw8Ip6X9Jzt86ZvulzSzqpbASii6bPo10m6b/oZ9KclXVNvJQClNAo8Ih6XNFR5FwCF8Uo2IDECBxIjcCAxAgcSI3AgsSpXVV3kUF/XRPG55/3VC8VnStLPN55VfOb7f3978ZmS9I3XLqgyd9nW//Pq4yL+evfVVeae889PFZ85uv53is+UpDVX/qT4zJ6PNbtqL2dwIDECBxIjcCAxAgcSI3AgMQIHEiNwIDECBxIjcCAxAgcSI3AgMQIHEiNwIDECBxIjcCAxAgcSI3AgMQIHEiNwIDECBxKrctHFY7FIr072FJ/7zB1vLj5Tkt72l3uKz3ziu3Uujnj+6MtV5j77R3U+t+94zy+qzP1J97nFZ06cXv5CoZL02/0vFZ/Zs2iy0XGcwYHECBxIjMCBxAgcSIzAgcQIHEiMwIHEGgVu+1O2d9j+qe0v2+6rvRiAuesYuO3lkq6XNBQRb5fUJWld7cUAzF3Tu+iLJfXbXixpiaRf1VsJQCkdA4+IPZJulTQqaa+k/RHx8InH2d5ge8T2yOHxQ+U3BTBrTe6iny7pCkmrJJ0paant9SceFxHDETEUEUO9y/rLbwpg1prcRV8r6ZmI2BcRE5I2S3pX3bUAlNAk8FFJF9teYtuSLpe0q+5aAEpo8hh8m6RNkrZLemL6zwxX3gtAAY3eDx4Rn5X02cq7ACiMV7IBiRE4kBiBA4kROJAYgQOJVbmq6sSxRdp78NeKzx3YfFrxmZI0vqb83KMff7H4TEnq/0xvlblynbGHPlbna3bdv20pPvOrN/9B8ZmS9OCjv1d85vgL2xsdxxkcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEjMEVF+qL1P0i8bHPrrkupcfrSOhbTvQtpVWlj7zoddz4qIt3Q6qErgTdkeiYih1haYpYW070LaVVpY+y6kXbmLDiRG4EBibQc+3PL/f7YW0r4LaVdpYe27YHZt9TE4gLraPoMDqIjAgcRaC9z2e20/aXu37Zva2qMT2ytsP2J7p+0dtje2vVMTtrtsP2b7623vMhPby2xvsv0z27tsX9L2TjOx/anp74Of2v6y7b62d5pJK4Hb7pL0eUnvk7Ra0lW2V7exSwOTkm6IiNWSLpb0F/N41+NtlLSr7SUauEPSQxFxvqQLNI93tr1c0vWShiLi7ZK6JK1rd6uZtXUGXyNpd0
Q8HRFHJN0v6YqWdplRROyNiO3Tv35FU9+Ay9vdama2ByV9QNLdbe8yE9sDkt4t6R5JiogjETHe7lYdLZbUb3uxpCWSftXyPjNqK/Dlkp477uMxzfNoJMn2SkkXStrW7iYd3S7pRknH2l6kg1WS9kn60vTDibttL217qZOJiD2SbpU0KmmvpP0R8XC7W82MJ9kasn2apAckfTIiXm57n5Ox/UFJL0TEo23v0sBiSRdJujMiLpR0UNJ8fj7mdE3d01wl6UxJS22vb3ermbUV+B5JK477eHD6tnnJdrem4r4vIja3vU8Hl0r6kO1nNfXQ5zLb97a70kmNSRqLiP+5R7RJU8HPV2slPRMR+yJiQtJmSe9qeacZtRX4jyWda3uV7R5NPVHxYEu7zMi2NfUYcVdE3Nb2Pp1ExKcjYjAiVmrq8/rtiJiXZ5mIeF7Sc7bPm77pckk7W1ypk1FJF9teMv19cbnm8ZOC0tRdpDdcREzavlbSNzX1TOQXI2JHG7s0cKmkqyU9Yfvx6dtujogtLe6UyXWS7pv+i/5pSde0vM9JRcQ225skbdfUT1ce0zx/2SovVQUS40k2IDECBxIjcCAxAgcSI3AgMQIHEiNwILH/Bobry1k8oM1RAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAC71JREFUeJzt3X+sX3V9x/Hni7YEWhZUYFtsm7VujoW5bJgrQUn4A1iCw0my+QcECPMfsmwqOhODyxKW/bkYo384Z4O6PyDiUogjjAgm6h9mGaP8SKCtkqYwKD+0bhOlim3pe3/cr0nX2Xu/5Z7Tc+87z0dyk/s99/R835D77Dnf7z3301QVkno6Y+oBJI3HwKXGDFxqzMClxgxcaszApcYMvJkk5ye5PcllU8+i6Rn4GpCkkvzWCdv+NsmdJ2zbBPwr8IfA/UkuPu5rNyR5dfbxsyTHjnv86hLPnST7k+wZak6dPgbeRJINwD3AHuBy4M+B+5L8JkBV3VVV51TVOcB7gRd/8Xi27WQuB34VeFuSd437X6GhrZ96AK1ckgD/BDwD/EUt3p741SQ/ZzHyK6rq+2/w8DcD/wKcPfv8kQFG1mli4A3Mgr7hl2z/GvC1N3rcJBuBDwDXsRj4F5L8VVUdfqPH1OnlJbqW8ifAz4GHWHxtvwG4ZtKJdEoMfG14ncW4jrcBODLy894M/HNVHa2q11h8jX/zEvtPNadOwkv0teE5YBuw97ht24Gnx3rCJFuAK4BLkvzpbPNG4Kwk51fVD1fDnFqaZ/C14avA3yTZkuSMJFcBfwzsHPE5b2IxzAuBP5h9/DZwALh+Fc2pJRj42vB3wL8B3wH+B/h74IaqemrE57wZ+Ieqevn4D+AfOfll+hRzaglxwQepL8/gUmMGLjVm4FJjBi41NsrPwc9/y7ratvXE+x1W7ul95w1+TGkteu3wjzh85FCW22+UwLdt3cB/PLh18ONefe1Ngx9TWov+/akvzLWfl+hSYwYuNWbgUmMGLjVm4FJjBi41NlfgSa5O8r0k+5LcNvZQkoaxbOBJ1gGfY3ElzouA65NcNPZgklZunjP4JcC+qto/W2zvbuDacceSNIR5At8MPH/c4wOzbf9HkluS7Eqy6+B/vT7UfJJWYLA32apqR1UtVNXCBeetG+qwklZgnsBfAI6/sXzLbJukVW6ewB8B3p5ke5IzWVwE/75xx5I0hGV/m6yqjib5EPAgsA74UlXtHn0ySSs216+LVtUDwAMjzyJpYN7JJjVm4FJjBi41ZuBSYwYuNTbKootP7ztvlAUS65EnBz8mQN71e6McV5qaZ3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqbFRVlUdy1irn/73754z+DHfsvvVwY8pnSrP4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjywaeZGuSbyXZk2R3kltPx2CSVm6eG12OAh+vqseS/ArwaJJvVNWekWeTtELLnsGr6qWqemz2+U+AvcDmsQeTtHKn9Bo8yTbgYuDhX/K1W5LsSrLryNFDw0wnaUXmDjzJOcA9wEer6scnfr2qdlTVQlUtbFi/acgZJb1BcwWeZAOLcd9VVfeOO5KkoczzLnqALwJ7q+rT448kaSjznMEvA24CrkjyxOzjj0aeS9IAlv0xWVV9B8hpmEXSwLyTTWrMwKXGDFxqzMClxtbUootjGWOBxDP2vzD4MQGOvc27hDU/z+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmOuqjqSsVY/PePZl0c57rFtvz7KcTUtz+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY3MHnmRdkseT3D/mQJKGcypn8F
uBvWMNIml4cwWeZAtwDXDHuONIGtK8Z/DPAJ8Ajp1shyS3JNmVZNeRo4cGGU7SyiwbeJL3AT+oqkeX2q+qdlTVQlUtbFi/abABJb1x85zBLwPen+RZ4G7giiR3jjqVpEEsG3hVfbKqtlTVNuA64JtVdePok0laMX8OLjV2Sr8PXlXfBr49yiSSBucZXGrMwKXGDFxqzMClxgxcasxVVdeYsVY//dmvnT3Kcc/+/s9GOa7m4xlcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMVVUFjLf66WsXnDXKcc86+Noox+3GM7jUmIFLjRm41JiBS40ZuNSYgUuNGbjU2FyBJ3lTkp1Jvptkb5J3jz2YpJWb90aXzwJfr6oPJDkT2DjiTJIGsmzgSc4FLgf+DKCqDgOHxx1L0hDmuUTfDhwEvpzk8SR3JNl04k5JbkmyK8muI0cPDT6opFM3T+DrgXcCn6+qi4FDwG0n7lRVO6pqoaoWNqz/f/1LmsA8gR8ADlTVw7PHO1kMXtIqt2zgVfUy8HySC2ebrgT2jDqVpEHM+y76h4G7Zu+g7wc+ON5IkoYyV+BV9QSwMPIskgbmnWxSYwYuNWbgUmMGLjVm4FJjrqqqUY21+ukZr/x08GMeO7ff71B5BpcaM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMRdd1Jo0xgKJr28cJ4d1Pz06ynHn4RlcaszApcYMXGrMwKXGDFxqzMClxgxcamyuwJN8LMnuJE8l+UqSs8YeTNLKLRt4ks3AR4CFqnoHsA64buzBJK3cvJfo64Gzk6wHNgIvjjeSpKEsG3hVvQB8CngOeAl4paoeOnG/JLck2ZVk15Gjh4afVNIpm+cS/c3AtcB24K3ApiQ3nrhfVe2oqoWqWtiwftPwk0o6ZfNcol8FPFNVB6vqCHAv8J5xx5I0hHkCfw64NMnGJAGuBPaOO5akIczzGvxhYCfwGPDk7M/sGHkuSQOY6xdgq+p24PaRZ5E0MO9kkxozcKkxA5caM3CpMQOXGnNVVWlmtNVPjx0b4aA1116ewaXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxlI13+qMp3TQ5CDwn3Psej7ww8EHGM9amnctzQpra97VMOtvVNUFy+00SuDzSrKrqhYmG+AUraV519KssLbmXUuzeokuNWbgUmNTB75j4uc/VWtp3rU0K6ytedfMrJO+Bpc0rqnP4JJGZOBSY5MFnuTqJN9Lsi/JbVPNsZwkW5N8K8meJLuT3Dr1TPNIsi7J40nun3qWpSR5U5KdSb6bZG+Sd08901KSfGz2ffBUkq8kOWvqmZYySeBJ1gGfA94LXARcn+SiKWaZw1Hg41V1EXAp8JereNbj3QrsnXqIOXwW+HpV/Q7w+6zimZNsBj4CLFTVO4B1wHXTTrW0qc7glwD7qmp/VR0G7gaunWiWJVXVS1X12Ozzn7D4Dbh52qmWlmQLcA1wx9SzLCXJucDlwBcBqupwVf1o2qmWtR44O8l6YCPw4sTzLGmqwDcDzx/3+ACrPBqAJNuAi4GHp51kWZ8BPgGM8S/PD2k7cBD48uzlxB1JNk091MlU1QvAp4DngJeAV6rqoWmnWppvss0pyTnAPcBHq+rHU89zMkneB/ygqh6depY5rAfeCXy+qi4GDgGr+f2YN7N4pbkdeCuwKcmN0061tKkCfwHYetzjLbNtq1KSDSzGfVdV3Tv1PMu4DHh/kmdZfOlzRZI7px3ppA4AB6rqF1dEO1kMfrW6Cnimqg5W1RHgXuA9E8+0pKkCfwR4e5LtSc5k8Y2K+yaaZUlJwuJrxL
1V9emp51lOVX2yqrZU1TYW/79+s6pW5Vmmql4Gnk9y4WzTlcCeCUdaznPApUk2zr4vrmQVvykIi5dIp11VHU3yIeBBFt+J/FJV7Z5iljlcBtwEPJnkidm2v66qByacqZMPA3fN/qLfD3xw4nlOqqoeTrITeIzFn648ziq/bdVbVaXGfJNNaszApcYMXGrMwKXGDFxqzMClxgxcaux/AWPSvgMr5OSGAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Ee3LMzOvlCuK"
},
"source": [
"## Convolutions\n",
"\n",
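"I keep hearing from the AGI folks that we can use convolutions to build artificial life. Let's try it out with Conway's Game of Life: convolving the board with a 3x3 kernel of ones counts the live cells in each cell's neighborhood (including the cell itself), and a few elementwise logic ops turn that count into the next generation. An XLA `While` loop then applies the update repeatedly."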
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "J8QkirDalBse",
"outputId": "73c53980-8dbd-497b-fe56-7e606a29c19f",
"colab": {
"height": 132
}
},
"source": [
"Niter=13\n",
"matrix_shape = (1,1, 20, 20)\n",
"in_shape = shape_of(\n",
" (onp.zeros(matrix_shape, dtype=onp.int32), 1)\n",
")\n",
"\n",
"# Body computation -- Conway Update\n",
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
"intuple = bcb.ParameterWithShape(in_shape)\n",
"x = bcb.GetTupleElement(intuple, 0)\n",
"cntr = bcb.GetTupleElement(intuple, 1)\n",
"# convs require floating-point type\n",
"xf = bcb.ConvertElementType(x, to_xla_type('float32'))\n",
"stamp = bcb.Constant(onp.ones((1,1,3,3), dtype=onp.float32))\n",
"convd = bcb.Conv(xf, stamp, onp.array([1, 1]), xla_client.PaddingType.SAME)\n",
"# logic ops require integer types\n",
"convd = bcb.ConvertElementType(convd, to_xla_type('int32'))\n",
"bool_x = bcb.Eq(x, bcb.ConstantS32Scalar(1))\n",
"# core update rule\n",
"res = bcb.Or(\n",
" # birth rule\n",
" bcb.And(bcb.Not(bool_x), bcb.Eq(convd, bcb.ConstantS32Scalar(3))),\n",
" # survival rule\n",
" bcb.And(bool_x, bcb.Or(\n",
" # thresholds are 4 and 3 instead of the usual 3 and 2 because the 3x3 sum also counts the cell itself\n",
" bcb.Eq(convd, bcb.ConstantS32Scalar(4)),\n",
" bcb.Eq(convd, bcb.ConstantS32Scalar(3)))\n",
" )\n",
")\n",
"# Convert the boolean result back to int32 so the loop-carried state keeps a fixed type\n",
"int_res = bcb.ConvertElementType(res, to_xla_type('int32'))\n",
"bcb.Tuple(int_res, bcb.Sub(cntr, bcb.ConstantS32Scalar(1)))\n",
"body_computation = bcb.Build()\n",
"\n",
"# Condition computation -- keep looping while the counter is positive (a simple for-loop condition)\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"intuple = tcb.ParameterWithShape(in_shape)\n",
"cntr = tcb.GetTupleElement(intuple, 1)\n",
"test = tcb.Gt(cntr, tcb.ConstantS32Scalar(0))\n",
"test_computation = tcb.Build()\n",
"\n",
"# While computation:\n",
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
"intuple = wcb.ParameterWithShape(in_shape)\n",
"wcb.While(test_computation, body_computation, intuple)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Compile the while computation once; it is executed repeatedly below\n",
"compiled_computation = while_computation.Compile([in_shape,])\n",
"\n",
"# Set up initial state: a glider on the 20x20 board\n",
"X = onp.zeros(matrix_shape, dtype=onp.int32)\n",
"X[0,0, 5:8, 5:8] = onp.array([[0,1,0],[0,0,1],[1,1,1]])\n",
"\n",
"# Evolve: for frame it, run the While loop for it generations starting from the glider\n",
"movie = onp.zeros((Niter,)+matrix_shape[-2:], dtype=onp.int32)\n",
"for it in range(Niter):\n",
" itr = onp.array(it, dtype=onp.int32)\n",
" device_in = xla_client.LocalBuffer.from_pyval((X, itr))\n",
" device_out = compiled_computation.Execute([device_in,])\n",
" movie[it] = device_out.to_py()[0][0,0]\n",
"\n",
"# Plot\n",
"fig = plt.figure(figsize=(15,2))\n",
"gs = gridspec.GridSpec(1,Niter)\n",
"for i in range(Niter):\n",
" ax1 = plt.subplot(gs[:, i])\n",
" ax1.axis('off')\n",
" ax1.imshow(movie[i])\n",
"plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, hspace=0.0, wspace=0.05)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABFoAAABwCAYAAAAuRhTQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABBBJREFUeJzt3MFt20AQQFHKUBWpwk0EqSBVpoIgTbiKlBH6FMDRxZL4lxI3753sk4jB8vKxnNO6rgsAAAAA2708+gEAAAAAZiG0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIic9/yxry/f1z1/bza//vw4/f3bLLf5OMtlMc+tnM2OWXa85y1ns+NstpzNjll2vOctZ7Njlp3L9/wjN1oAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACInB/9AKWfv9/++f/bl9cHPQkAAADwP3KjBQAAACAitAAAAABEhBYAAACAyFQ7Wi53stjZcr3PZmWWLfO8n9mNY7YAALCdGy0AAAAAEaEFAAAAICK0AAAAAEQOvaPlcp8A9/tsJwu3setinFvPqtlfz26mbey62o9ZbmN+45gtAMviRgsAAABARmgBAAAAiAgtAAAAAJFD72jx3et+zPo2dt6MYydLx7lsee879lyMZdfVOHYz3c+eq32Z5/3MbpyZZutGCwAAAEBEaAEAAACICC0AAAAAkUPvaGGcI38P94zMs2OWHbPcl3lfz76bsexk6TibHe99a6ZdF8/GnqtxZtrN5EYLAAAAQERoAQAAAIgILQAAAAARO1oAYIAjfUf87MyyZZ4ds9yPWd/Gzptx7GTpzHwu3WgBAAAAiAgtAAAAABGhBQAAACBiRwsAAPBU7LlomWfHLDszz9KNFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACICC0AAAAAEaEFAAAAICK0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAip3VdH/0MAAAAAFNwowUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACICC0AAAAAEaEFAAAAICK0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIi8A8vjuqwsx0TPAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 1080x144 with 13 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9-0PJlqv237S"
},
"source": [
"## Fin\n",
"\n",
"There's much more to XLA, but hopefully this shows how easy it is to experiment with XLA computations through the Python client!"
]
}
]
}