mirror of
https://github.com/ROCm/jax.git
synced 2025-04-27 11:06:07 +00:00
836 lines
45 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "sAgUgR5Mzzz2"
},
"source": [
"# XLA in Python\n",
"\n",
"[Open in Colab](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/XLA_in_Python.ipynb)\n",
"\n",
"<img style=\"height:100px;\" src=\"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/compiler/xla/g3doc/images/xlalogo.png\"> <img style=\"height:100px;\" src=\"https://upload.wikimedia.org/wikipedia/commons/c/c3/Python-logo-notext.svg\">\n",
"\n",
"_Anselm Levskaya_, _Qiao Zhang_\n",
"\n",
"XLA is the compiler that JAX uses. It is also the compiler that TensorFlow uses for TPUs and, at the time of writing, was expected to use for all devices, so it is worth studying. However, it is not easy to experiment with XLA computations directly through the raw C++ interface. JAX exposes the underlying XLA computation builder API through a Python wrapper, which makes the XLA compute model accessible for experimentation and prototyping.\n",
"\n",
"XLA computations are built as computation graphs in HLO IR. These graphs are then lowered to device-specific low-level code for CPU, GPU, TPU, and other targets (for example, LLVM IR on CPU and GPU, or LLO on TPU).\n",
"\n",
"As end users, we interact with the computational primitives offered by the HLO spec.\n",
"\n",
"**Caution: This is a pedagogical notebook covering some low-level XLA details. The APIs used here are neither public nor stable!**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EZK5RseuvZkr"
},
"source": [
"## References\n",
"\n",
"__xla__: the document that defines the operations in HLO. Note that it is incomplete and omits some ops.\n",
"\n",
"https://www.tensorflow.org/xla/operation_semantics\n",
"\n",
"More details on the ops are in the source code.\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
"\n",
"__python xla client__: the XLA Python client used by JAX, and the one we use here.\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n",
"\n",
"__jax__: these files show how JAX interacts with the XLA compute layer for execution and JIT compilation.\n",
"\n",
"https://github.com/google/jax/blob/master/jax/lax.py\n",
"\n",
"https://github.com/google/jax/blob/master/jax/lib/xla_bridge.py\n",
"\n",
"https://github.com/google/jax/blob/master/jax/interpreters/xla.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3XR2NGmrzBGe"
},
"source": [
"## Colab Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Ogo2SBd3u18P"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# We only need to import JAX's xla_client, not all of JAX.\n",
"from jax.lib import xla_client as xc\n",
"xops = xc.ops\n",
"\n",
"# Plotting\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import gridspec\n",
"from matplotlib import rcParams\n",
"rcParams['image.interpolation'] = 'nearest'\n",
"rcParams['image.cmap'] = 'viridis'\n",
"rcParams['axes.grid'] = False"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "odmjXyhMuNJ5"
},
"source": [
"## Simple Computations"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "UYUtxVzMYIiv",
"outputId": "5c603ab4-0295-472c-b462-9928b2a9520d"
},
"outputs": [
{
"data": {
"text/plain": [
"array(0.14112, dtype=float32)"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# make a computation builder\n",
"c = xc.XlaBuilder(\"simple_scalar\")\n",
"\n",
"# define a parameter shape and parameter\n",
"param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
"x = xops.Parameter(c, 0, param_shape)\n",
"\n",
"# define computation graph\n",
"y = xops.Sin(x)\n",
"\n",
"# build computation graph\n",
"# Keep in mind that incorrectly constructed graphs can cause\n",
"# your notebook kernel to crash!\n",
"computation = c.Build()\n",
"\n",
"# get a CPU backend\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(computation)\n",
"\n",
"# define a host variable with above parameter shape\n",
"host_input = np.array(3.0, dtype=np.float32)\n",
"\n",
"# place host variable on device and execute\n",
"device_input = cpu_backend.buffer_from_pyval(host_input)\n",
"device_out = compiled_computation.execute([device_input])\n",
"\n",
"# retrieve the result\n",
"device_out[0].to_py()"
]
},
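{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the next cell prints the HLO text that `XlaBuilder` recorded for `simple_scalar`. It then compares the device result with NumPy on the host. This is a sketch: it assumes that `XlaComputation.as_hlo_text()` is available in the installed jaxlib (these internal APIs change between versions). It also reuses `computation`, `host_input`, and `device_out` from the cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print the HLO IR recorded by the builder.\n",
"# Assumption: XlaComputation.as_hlo_text() exists in this jaxlib version.\n",
"print(computation.as_hlo_text())\n",
"\n",
"# Cross-check the device result against NumPy on the host.\n",
"np.testing.assert_allclose(device_out[0].to_py(), np.sin(host_input), rtol=1e-6)"
]
},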
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "rIA-IVMVvQs2",
"outputId": "a4d8ef32-43f3-4a48-f732-e85e158b602e"
},
"outputs": [
{
"data": {
"text/plain": [
"array([0.14112 , 0.7568025, 0.9589243], dtype=float32)"
]
},
"execution_count": 3,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# same as above, with a vector type:\n",
"\n",
"c = xc.XlaBuilder(\"simple_vector\")\n",
"param_shape = xc.Shape.array_shape(np.dtype(np.float32), (3,))\n",
"x = xops.Parameter(c, 0, param_shape)\n",
"\n",
"# chain steps by reference:\n",
"y = xops.Sin(x)\n",
"z = xops.Abs(y)\n",
"computation = c.Build()\n",
"\n",
"# get a CPU backend\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(computation)\n",
"\n",
"host_input = np.array([3.0, 4.0, 5.0], dtype=np.float32)\n",
"\n",
"device_input = cpu_backend.buffer_from_pyval(host_input)\n",
"device_out = compiled_computation.execute([device_input])\n",
"\n",
"# retrieve the result\n",
"device_out[0].to_py()"
]
},
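{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the same Sin-then-Abs chain can be written directly in NumPy. The sketch below reuses `host_input` and `device_out` from the cell above and checks that both results agree within float32 tolerance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Host-side NumPy equivalent of the XLA Sin -> Abs graph.\n",
"expected = np.abs(np.sin(host_input))\n",
"np.testing.assert_allclose(device_out[0].to_py(), expected, rtol=1e-6)\n",
"print(expected)"
]
},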
{
"cell_type": "markdown",
"metadata": {
"id": "F8kWlLaVuQ1b"
},
"source": [
"## Simple While Loop"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "MDQP1qW515Ao",
"outputId": "53245817-b5fb-4285-ee62-7eb33a822be4"
},
"outputs": [
{
"data": {
"text/plain": [
"array(0, dtype=int32)"
]
},
"execution_count": 4,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# trivial while loop, decrement until 0\n",
"# x = 5\n",
"# while x > 0:\n",
"#   x = x - 1\n",
"#\n",
"in_shape = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"\n",
"# body computation:\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"x = xops.Parameter(bcb, 0, in_shape)\n",
"const1 = xops.Constant(bcb, np.int32(1))\n",
"y = xops.Sub(x, const1)\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation:\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"x = xops.Parameter(tcb, 0, in_shape)\n",
"const0 = xops.Constant(tcb, np.int32(0))\n",
"y = xops.Gt(x, const0)\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"x = xops.Parameter(wcb, 0, in_shape)\n",
"xops.While(test_computation, body_computation, x)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"# get a CPU backend\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"host_input = np.array(5, dtype=np.int32)\n",
"\n",
"device_input = cpu_backend.buffer_from_pyval(host_input)\n",
"device_out = compiled_computation.execute([device_input])\n",
"\n",
"# retrieve the result\n",
"device_out[0].to_py()"
]
},
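{
"cell_type": "markdown",
"metadata": {},
"source": [
"`xops.While` takes a condition computation and a body computation. Each is a function of the loop-carried value, and the op returns the final carry. The plain-Python model below only illustrates those semantics: `xla_while` is a hypothetical helper, not part of `xla_client`. It reproduces the decrement-to-zero loop above on the host."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def xla_while(cond_fn, body_fn, init):\n",
"  # Hypothetical helper: mirrors XLA While semantics on the host.\n",
"  # cond_fn and body_fn are pure functions of the loop carry.\n",
"  carry = init\n",
"  while cond_fn(carry):\n",
"    carry = body_fn(carry)\n",
"  return carry\n",
"\n",
"result = xla_while(lambda x: x > 0, lambda x: x - 1, np.int32(5))\n",
"print(result)  # prints 0"
]
},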
{
"cell_type": "markdown",
"metadata": {
"id": "7UOnXlY8slI6"
},
"source": [
"## While loops w/ Tuples - Newton's Method for sqrt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "HEWz-vzd6QPR",
"outputId": "ad4c4247-8e81-4739-866f-2950fec5e759"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"square root of 2.0 is 1.4142156839370728\n"
]
}
],
"source": [
"Xsqr = 2\n",
"guess = 1.0\n",
"converged_delta = 0.001\n",
"maxit = 1000\n",
"\n",
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
"in_shape_2 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1, in_shape_2])\n",
"\n",
"# body computation:\n",
"# x_{i+1} = x_i - (x_i**2 - y) / (2 * x_i)\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
"y = xops.GetTupleElement(intuple, 0)\n",
"x = xops.GetTupleElement(intuple, 1)\n",
"guard_cntr = xops.GetTupleElement(intuple, 2)\n",
"new_x = xops.Sub(x, xops.Div(xops.Sub(xops.Mul(x, x), y), xops.Add(x, x)))\n",
"result = xops.Tuple(bcb, [y, new_x, xops.Sub(guard_cntr, xops.Constant(bcb, np.int32(1)))])\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- convergence and max iteration test\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
"y = xops.GetTupleElement(intuple, 0)\n",
"x = xops.GetTupleElement(intuple, 1)\n",
"guard_cntr = xops.GetTupleElement(intuple, 2)\n",
"criterion = xops.Abs(xops.Sub(xops.Mul(x, x), y))\n",
"# stop at convergence or after too many iterations\n",
"test = xops.And(xops.Gt(criterion, xops.Constant(tcb, np.float32(converged_delta))),\n",
"                xops.Gt(guard_cntr, xops.Constant(tcb, np.int32(0))))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"# Since JAX does not let users create a tuple input directly, we take\n",
"# multiple parameters and build an intermediate tuple before feeding it\n",
"# to the while loop as the initial carry.\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"y = xops.Parameter(wcb, 0, in_shape_0)\n",
"x = xops.Parameter(wcb, 1, in_shape_1)\n",
"guard_cntr = xops.Parameter(wcb, 2, in_shape_2)\n",
"tuple_init_carry = xops.Tuple(wcb, [y, x, guard_cntr])\n",
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"y = np.array(Xsqr, dtype=np.float32)\n",
"x = np.array(guess, dtype=np.float32)\n",
"maxit = np.array(maxit, dtype=np.int32)\n",
"\n",
"device_input_y = cpu_backend.buffer_from_pyval(y)\n",
"device_input_x = cpu_backend.buffer_from_pyval(x)\n",
"device_input_maxit = cpu_backend.buffer_from_pyval(maxit)\n",
"device_out = compiled_computation.execute([device_input_y, device_input_x, device_input_maxit])\n",
"\n",
"# retrieve the result\n",
"print(\"square root of {y} is {x}\".format(y=y, x=device_out[1].to_py()))"
]
},
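{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each body step is Newton's method applied to $f(x) = x^2 - y$, that is $x_{i+1} = x_i - f(x_i)/f'(x_i) = x_i - (x_i^2 - y)/(2 x_i)$. Below is a host-side float32 sketch of the same loop, using the same convergence test and iteration guard. `newton_sqrt` is a hypothetical helper, and the final check reuses `device_out` from the cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def newton_sqrt(y, x, converged_delta=1e-3, maxit=1000):\n",
"  # Hypothetical float32 reference for the XLA while loop above.\n",
"  y, x = np.float32(y), np.float32(x)\n",
"  while abs(x * x - y) > converged_delta and maxit > 0:\n",
"    x = x - (x * x - y) / (x + x)  # Newton step for f(x) = x**2 - y\n",
"    maxit -= 1\n",
"  return x\n",
"\n",
"ref = newton_sqrt(2.0, 1.0)\n",
"print(ref)\n",
"np.testing.assert_allclose(device_out[1].to_py(), ref, rtol=1e-5)"
]
},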
{
"cell_type": "markdown",
"metadata": {
"id": "yETVIzTInFYr"
},
"source": [
"## Calculate Symmetric Eigenvalues"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AiyR1e2NubKa"
},
"source": [
"Let's use XLA's QR implementation to compute the eigenvalues of symmetric matrices.\n",
"\n",
"This is the naive (unshifted) QR algorithm. It has no acceleration for the convergence of eigenvalues of nearly equal magnitude, and no permutation to sort the eigenvalues by magnitude."
]
},
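{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does this loop converge to the eigenvalues? Since $Q_i$ is orthogonal, $R_i = Q_i^\\top X_i$, so each step is a similarity transform that preserves the spectrum:\n",
"\n",
"$$X_i = Q_i R_i, \\qquad X_{i+1} = R_i Q_i = Q_i^\\top X_i Q_i.$$\n",
"\n",
"For a symmetric $X_0$, every iterate stays symmetric. When the eigenvalues have distinct magnitudes, the off-diagonal entries decay, so the diagonal of $X_i$ approaches the eigenvalues. The decay rate is governed by the ratios $|\\lambda_{j+1} / \\lambda_j|$ of consecutive eigenvalue magnitudes. This is why, without shifts, eigenvalues of nearly equal magnitude converge slowly."
]
},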
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "wjxDPbqCcuXT",
"outputId": "2380db52-799d-494e-ded2-856e91f01b0f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sorted eigenvalues\n",
"[-1.1406534 -0.5946617 -0.29557052 -0.09876542 0.07503236 0.19509281\n",
" 0.47496718 0.858686 1.09709 5.281351 ]\n",
"sorted eigenvalues from numpy\n",
"[-1.140657 -0.5946614 -0.29557055 -0.09876533 0.07503222 0.19509293\n",
" 0.47496703 0.85868585 1.0970895 5.2813535 ]\n",
"sorted error\n",
"[ 3.5762787e-06 -2.9802322e-07 2.9802322e-08 -8.9406967e-08\n",
" 1.4156103e-07 -1.1920929e-07 1.4901161e-07 1.1920929e-07\n",
" 4.7683716e-07 -2.3841858e-06]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAK1klEQVR4nO3dUayeBX3H8e9v5xxmWx0osCW2ne0FcSEkDnNm0GZegFl0MrmZGSaY6U2zZCoaE4e78XYXxuiFc2lAbyByUbkghohL1Au3BDwUFqRVRyqDAoa6BSQdrKfw38U5M11Le56+fR+ec/58PwlJz3teHn6B8+V537fv+zRVhaQ+fmfqAZLmy6ilZoxaasaopWaMWmrGqKVmjFpqxqjfoJI8keSlJC8meT7Jvyb5myT+TGxx/gd8Y/uLqnoL8A7gH4C/A+6YdpIullGLqnqhqu4F/gr46yTXTL1JszNq/VZVPQgcA/506i2anVHrTM8Ab5t6hGZn1DrTTuC/ph6h2Rm1fivJn7AW9Y+n3qLZGbVI8ntJbgTuBu6sqken3qTZxc9TvzEleQL4A+AU8CpwGLgT+KeqemXCabpIRi0148NvqRmjlpoxaqkZo5aaWRzjoFe8baH27F6a+3F/cfTyuR9T2opefvl5Tq6eyGt9b5So9+xe4sH7d8/9uH/20U/M/ZjSVvTgI/94zu/58FtqxqilZoxaasaopWaMWmrGqKVmBkWd5INJfp7k8SS3jT1K0uw2jDrJAvB14EPA1cDHklw99jBJsxlypn4P8HhVHa2qk6x9kP6mcWdJmtWQqHcCT5329bH12/6fJPuTrCRZOf6ffsZemsrcXiirqgNVtVxVy1devjCvw0q6QEOifho4/Y3cu9Zvk7QJDYn6J8BVSfYmuQS4Gbh33FmSZrXhp7Sq6lSSTwH3AwvAN6vqsdGXSZrJoI9eVtV9wH0jb5E0B76jTGrGqKVmjFpqxqilZoxaamaUCw/+4ujlo1wkMP/yyNyPCVD7/niU40pT8EwtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTUzytVExzLWVT+Pv2v73I955b/999yPKQ3hmVpqxqilZoxaasaopWaMWmrGqKVmjFpqZsOok+xO8sMkh5M8luTW12OYpNkMefPJKeDzVXUoyVuAh5L8c1UdHnmbpBlseKauqmer6tD6r18EjgA7xx4maTYX9Jw6yR7gWuCB1/je/iQrSVZWV0/MZ52kCzY46iRvBr4DfLaqfnPm96vqQFUtV9Xy0tKOeW6UdAEGRZ1kibWg76qqe8adJOliDHn1O8AdwJGq+sr4kyRdjCFn6n3Ax4Hrkzyy/tefj7xL0ow2/C2tqvoxkNdhi6Q58B1lUjNGLTVj1FIzRi01s6UuPDiWMS4S+PxV2+Z+TIDL/v2lUY6rPjxTS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNeDXRkYx11c+lJ389ynFX//CKUY6r159naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqmZwVEnWUjycJLvjjlI0sW5kDP1rcCRsYZImo9BUSfZBXwYuH3cOZIu1tAz9VeBLwCvnusOSfYnWUmysrp6Yi7jJF24DaNOciPwXFU9dL77VdWBqlququWlpR1zGyjpwgw5U+8DPpLkCeBu4Pokd466StLMNoy6qr5YVbuqag9wM/CDqrpl9GWSZuLvU0vNXNDnqavqR8CPRlkiaS48U0vNGLXUjFFLzRi11IxRS814NdEtZqyrfr70+787ynG3Pfc/oxxX5+aZWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopW
aMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxquJChjvqp8LL7w8ynFfufRNoxy3A8/UUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjODok5yWZKDSX6W5EiS9449TNJshr755GvA96rqL5NcAmwfcZOki7Bh1EkuBd4PfAKgqk4CJ8edJWlWQx5+7wWOA99K8nCS25PsOPNOSfYnWUmysrp6Yu5DJQ0zJOpF4N3AN6rqWuAEcNuZd6qqA1W1XFXLS0tnNS/pdTIk6mPAsap6YP3rg6xFLmkT2jDqqvoV8FSSd67fdANweNRVkmY29NXvTwN3rb/yfRT45HiTJF2MQVFX1SPA8shbJM2B7yiTmjFqqRmjlpoxaqkZo5aa8WqiGtVYV/08tW1h7sdcfOmVuR9zCp6ppWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGCw9qSxrjIoFZHefCg7U0/4skno9naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqmZQVEn+VySx5L8NMm3k4zzp55JumgbRp1kJ/AZYLmqrgEWgJvHHiZpNkMffi8C25IsAtuBZ8abJOlibBh1VT0NfBl4EngWeKGqvn/m/ZLsT7KSZGV19cT8l0oaZMjD77cCNwF7gbcDO5Lccub9qupAVS1X1fLS0o75L5U0yJCH3x8AfllVx6tqFbgHeN+4syTNakjUTwLXJdmeJMANwJFxZ0ma1ZDn1A8AB4FDwKPrf8+BkXdJmtGgz1NX1ZeAL428RdIc+I4yqRmjlpoxaqkZo5aaMWqpGa8mKq0b66qftZhRjnsunqmlZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWZSVfM/aHIc+I8Bd70C+PXcB4xnK+3dSltha+3dDFvfUVVXvtY3Rol6qCQrVbU82YALtJX2bqWtsLX2bvatPvyWmjFqqZmpo95qf3j9Vtq7lbbC1tq7qbdO+pxa0vxNfaaWNGdGLTUzWdRJPpjk50keT3LbVDs2kmR3kh8mOZzksSS3Tr1piCQLSR5O8t2pt5xPksuSHEzysyRHkrx36k3nk+Rz6z8HP03y7SRvmnrTmSaJOskC8HXgQ8DVwMeSXD3FlgFOAZ+vqquB64C/3cRbT3crcGTqEQN8DfheVf0R8C428eYkO4HPAMtVdQ2wANw87aqzTXWmfg/weFUdraqTwN3ATRNtOa+qeraqDq3/+kXWfuh2Trvq/JLsAj4M3D71lvNJcinwfuAOgKo6WVXPT7tqQ4vAtiSLwHbgmYn3nGWqqHcCT5329TE2eSgASfYA1wIPTLtkQ18FvgC8OvWQDewFjgPfWn+qcHuSHVOPOpeqehr4MvAk8CzwQlV9f9pVZ/OFsoGSvBn4DvDZqvrN1HvOJcmNwHNV9dDUWwZYBN4NfKOqrgVOAJv59ZW3svaIci/wdmBHklumXXW2qaJ+Gth92te71m/blJIssRb0XVV1z9R7NrAP+EiSJ1h7WnN9kjunnXROx4BjVfV/j3wOshb5ZvUB4JdVdbyqVoF7gPdNvOksU0X9E+CqJHuTXMLaiw33TrTlvJKEted8R6rqK1Pv2UhVfbGqdlXVHtb+vf6gqjbd2QSgqn4FPJXknes33QAcnnDSRp4Erkuyff3n4gY24Qt7i1P8Q6vqVJJPAfez9griN6vqsSm2DLAP+DjwaJJH1m/7+6q6b8JNnXwauGv9f+5HgU9OvOecquqBJAeBQ6z9rsjDbMK3jPo2UakZXyiTmjFqqRmjlpoxaqkZo5aaMWqpGaOWmvlfrtFe31cYfuIAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"Niter = 200\n",
"matrix_shape = (10, 10)\n",
"\n",
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1])\n",
"\n",
"# body computation -- QR loop: X_i = Q R, X_{i+1} = R Q\n",
"\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
"x = xops.GetTupleElement(intuple, 0)\n",
"cntr = xops.GetTupleElement(intuple, 1)\n",
"Q, R = xops.QR(x, True)\n",
"RQ = xops.Dot(R, Q)\n",
"xops.Tuple(bcb, [RQ, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- just a for loop condition\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
"cntr = xops.GetTupleElement(intuple, 1)\n",
"test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"x = xops.Parameter(wcb, 0, in_shape_0)\n",
"cntr = xops.Parameter(wcb, 1, in_shape_1)\n",
"tuple_init_carry = xops.Tuple(wcb, [x, cntr])\n",
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"X = np.random.random(matrix_shape).astype(np.float32)\n",
"X = (X + X.T) / 2.0\n",
"it = np.array(Niter, dtype=np.int32)\n",
"\n",
"device_input_x = cpu_backend.buffer_from_pyval(X)\n",
"device_input_it = cpu_backend.buffer_from_pyval(it)\n",
"device_out = compiled_computation.execute([device_input_x, device_input_it])\n",
"\n",
"host_out = device_out[0].to_py()\n",
"eigh_vals = host_out.diagonal()\n",
"\n",
"plt.title('D')\n",
"plt.imshow(host_out)\n",
"print('sorted eigenvalues')\n",
"print(np.sort(eigh_vals))\n",
"print('sorted eigenvalues from numpy')\n",
"print(np.sort(np.linalg.eigh(X)[0]))\n",
"print('sorted error')\n",
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
]
},
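{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, the same unshifted QR iteration can be written on the host with NumPy's `np.linalg.qr`. This sketch uses a hypothetical helper, `qr_eigvals`, and reuses `X`, `Niter`, and `eigh_vals` from the cell above. NumPy may pick different column signs for $Q$ than XLA does. Sign flips $Q \\to QD$, $R \\to DR$ with $D = \\mathrm{diag}(\\pm 1)$ leave the diagonal of $RQ$ unchanged, so the two sets of diagonals should agree up to rounding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def qr_eigvals(X, n_iter):\n",
"  # Hypothetical host reference: unshifted QR iteration, X_{i+1} = R_i Q_i.\n",
"  A = X.copy()\n",
"  for _ in range(n_iter):\n",
"    Q, R = np.linalg.qr(A)\n",
"    A = R @ Q\n",
"  return A.diagonal()\n",
"\n",
"host_vals = qr_eigvals(X, Niter)\n",
"print(np.sort(host_vals) - np.sort(eigh_vals))"
]
},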
{
"cell_type": "markdown",
"metadata": {
"id": "FpggTihknAOw"
},
"source": [
"## Calculate Full Symmetric Eigensystem"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qos4ankYuj1T"
},
"source": [
"We can also calculate the eigenbasis by accumulating (multiplying together) the Q factors from each iteration."
]
},
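{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a host-side sketch of the idea, using NumPy's `np.linalg.qr` in place of XLA's QR. Multiplying the orthogonal factors together gives $V_k = Q_0 Q_1 \\cdots Q_{k-1}$ with $X_k = V_k^\\top X_0 V_k$. Once $X_k$ is nearly diagonal, the columns of $V_k$ approximate the eigenvectors. `qr_eigensystem` is a hypothetical helper, and the cell reuses `X` from the eigenvalue example above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def qr_eigensystem(X, n_iter=200):\n",
"  # Hypothetical host reference: unshifted QR iteration that also\n",
"  # accumulates V = Q_0 Q_1 ... Q_{k-1}, so that X_k = V^T X V.\n",
"  A = X.copy()\n",
"  V = np.eye(X.shape[0], dtype=X.dtype)\n",
"  for _ in range(n_iter):\n",
"    Q, R = np.linalg.qr(A)\n",
"    A = R @ Q\n",
"    V = V @ Q\n",
"  return A.diagonal(), V\n",
"\n",
"D, V = qr_eigensystem(X)\n",
"# Max error of X ~= V diag(D) V^T; this equals the largest leftover\n",
"# off-diagonal entry of A (up to rounding), so it is small only once A is nearly diagonal.\n",
"print(np.abs(V @ np.diag(D) @ V.T - X).max())"
]
},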
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "Kp3A-aAiZk0g",
"outputId": "bbaff039-20f4-45cd-b8fe-5a664d413f5b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sorted eigenvalues\n",
"[-0.95164776 -0.5988633 -0.28330874 -0.07402738 0.15438193 0.19796501\n",
" 0.4779069 0.58893895 0.81445134 4.5762177 ]\n",
"sorted eigenvalues from numpy\n",
"[-0.95164794 -0.6314303 -0.28330857 -0.07402731 0.15438198 0.19796519\n",
" 0.47790694 0.62150407 0.81445104 4.5762167 ]\n",
"sorted error\n",
"[ 1.7881393e-07 3.2567024e-02 -1.7881393e-07 -6.7055225e-08\n",
" -4.4703484e-08 -1.7881393e-07 -2.9802322e-08 -3.2565117e-02\n",
" 2.9802322e-07 9.5367432e-07]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAK5UlEQVR4nO3dT4xdB3mG8efF4yjYoRASVBXbwl4gkJUKBU1RSFQWCaqgpGSD1CAFUTZWpQIBIdHQDdsuEIIFSmUlsElEFiaLCEWESsACITlM/lTBdkBJSB0nRrgtCZFd1x7ydTFT5Nqx5/j6Hp+ZL89PiuSZuTl+lczjc++de49TVUjq401TD5A0X0YtNWPUUjNGLTVj1FIzRi01Y9RSM0b9BpXk+ST/neTVJC8n+VmSv0/i98QG5//AN7a/qaq3AO8C/hn4R+DeaSfpUhm1qKpXquoh4G+BTye5bupNmp1R64+q6lHgCPCXU2/R7IxaZ3sJePvUIzQ7o9bZtgH/NfUIzc6o9UdJ/oKVqH869RbNzqhFkj9JcivwAHBfVT019SbNLr6f+o0pyfPAnwLLwGvAQeA+4F+q6g8TTtMlMmqpGe9+S80YtdSMUUvNGLXUzMIYB7327Ztq547Ncz/ur569Zu7HlDaik//zMqdOH8/rfW2UqHfu2Myjj+yY+3H/6hOfnvsxpY3o0SfvPu/XvPstNWPUUjNGLTVj1FIzRi01Y9RSM4OiTvKRJL9M8kySu8YeJWl2a0adZBPwLeCjwG7gk0l2jz1M0myGnKk/ADxTVc9V1SlW3kh/27izJM1qSNTbgBfO+PjI6uf+nyR7kiwlWTr2n77HXprK3J4oq6q9VbVYVYvvuGbTvA4r6SINifpF4MwXcm9f/ZykdWhI1D8H3p1kV5IrgNuBh8adJWlWa75Lq6qWk3wWeATYBHy7qg6MvkzSTAa99bKqHgYeHnmLpDnwFWVSM0YtNWPUUjNGLTVj1FIzo1x48FfPXjPKRQLzs3+b+zEB6sb3jXJcaQqeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZka5muhYxrrq5+/eu2Xux7z66RNzP6Y0hGdqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqZk1o06yI8mPkxxMciDJnZdjmKTZDHnxyTLwpap6PMlbgMeS/GtVHRx5m6QZrHmmrqqjVfX46q9fBQ4B28YeJmk2F/WYOslO4Hpg/+t8bU+SpSRLp5ePz2edpIs2OOokVwHfA75QVb8/++tVtbeqFqtqcfPC1nlulHQRBkWdZDMrQd9fVQ+OO0nSpRjy7HeAe4FDVfX18SdJuhRDztQ3AZ8Cbk7y5Oo/fz3yLkkzWvNHWlX1UyCXYYukOfAVZVIzRi01Y9RSM0YtNbOhLjw4ljEuErjw7NG5HxPg5J/vGOW4CyeWRzmuLj/P1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM15NdCRjXfXz5LWbRznuVYe9mmgXnqmlZoxaasaopWaMWmrGqKVmjFpqxqilZgZHnWRTkieSfH/MQZIuzcWcqe8EDo01RNJ8DIo6yXbgY8A9486RdKmGnqm/AXwZeO18N0iyJ8lSkqXTy8fnMk7SxVsz6iS3Ar+tqscudLuq2ltVi1W1uHlh69wGSro4Q87UNwEfT/I88ABwc5L7Rl0laWZrRl1VX6mq7VW1E7gd+FFV3TH6Mkkz8efUUjMX9X7qqvoJ8JNRlkiaC8/UUjNGLTVj1FIzRi01Y9RSM15NdCQLJ8a5OudYV/088WdXjnLcLUdPjnJcnZ9naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGa
OWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGa8mKmC8q34u/O7EKMddvnrLKMftwDO11IxRS80YtdSMUUvNGLXUjFFLzRi11MygqJO8Lcm+JE8nOZTkg2MPkzSboS8++Sbwg6r6RJIrAH/yL61Ta0ad5K3Ah4C/A6iqU8CpcWdJmtWQu9+7gGPAd5I8keSeJFvPvlGSPUmWkiydXj4+96GShhkS9QLwfuDuqroeOA7cdfaNqmpvVS1W1eLmhXOal3SZDIn6CHCkqvavfryPlcglrUNrRl1VvwFeSPKe1U/dAhwcdZWkmQ199vtzwP2rz3w/B3xmvEmSLsWgqKvqSWBx5C2S5sBXlEnNGLXUjFFLzRi11IxRS814NVGNaqyrfi5vmf+37sKJ5bkfcwqeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxgsPakMa4yKBf7hynBw2nby8FzT0TC01Y9RSM0YtNWPUUjNGLTVj1FIzRi01MyjqJF9MciDJL5J8N8mVYw+TNJs1o06yDfg8sFhV1wGbgNvHHiZpNkPvfi8Ab06yAGwBXhpvkqRLsWbUVfUi8DXgMHAUeKWqfnj27ZLsSbKUZOn08vH5L5U0yJC731cDtwG7gHcCW5PccfbtqmpvVS1W1eLmha3zXyppkCF3vz8M/LqqjlXVaeBB4MZxZ0ma1ZCoDwM3JNmSJMAtwKFxZ0ma1ZDH1PuBfcDjwFOr/87ekXdJmtGgN5BW1VeBr468RdIc+IoyqRmjlpoxaqkZo5aaMWqpGa8mKq0a66qfbzr92vwPWhf4/eb/u0maklFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11EyqLnBZwlkPmhwD/n3ATa8F/mPuA8azkfZupK2wsfauh63vqqp3vN4XRol6qCRLVbU42YCLtJH2bqStsLH2rvet3v2WmjFqqZmpo95of3n9Rtq7kbbCxtq7rrdO+pha0vxNfaaWNGdGLTUzWdRJPpLkl0meSXLXVDvWkmRHkh8nOZjkQJI7p940RJJNSZ5I8v2pt1xIkrcl2Zfk6SSHknxw6k0XkuSLq98Hv0jy3SRXTr3pbJNEnWQT8C3go8Bu4JNJdk+xZYBl4EtVtRu4AfiHdbz1THcCh6YeMcA3gR9U1XuB97GONyfZBnweWKyq64BNwO3TrjrXVGfqDwDPVNVzVXUKeAC4baItF1RVR6vq8dVfv8rKN922aVddWJLtwMeAe6beciFJ3gp8CLgXoKpOVdXL065a0wLw5iQLwBbgpYn3nGOqqLcBL5zx8RHWeSgASXYC1wP7p12ypm8AXwZG+NvO52oXcAz4zupDhXuSbJ161PlU1YvA14DDwFHglar64bSrzuUTZQMluQr4HvCFqvr91HvOJ8mtwG+r6rGptwywALwfuLuqrgeOA+v5+ZWrWblHuQt4J7A1yR3TrjrXVFG/COw44+Ptq59bl5JsZiXo+6vqwan3rOEm4ONJnmflYc3NSe6bdtJ5HQGOVNX/3fPZx0rk69WHgV9X1bGqOg08CNw48aZzTBX1z4F3J9mV5ApWnmx4aKItF5QkrDzmO1RVX596z1qq6itVtb2qdrLy3/VHVbXuziYAVfUb4IUk71n91C3AwQknreUwcEOSLavfF7ewDp/YW5jiN62q5SSfBR5h5RnEb1fVgSm2DHAT8CngqSRPrn7un6rq4Qk3dfI54P7VP9yfAz4z8Z7zqqr9SfYBj7PyU5EnWIcvGfVlolIzPlEmNWPUUjNGLTVj1FIzRi01Y9RSM0YtNfO/5H9lzlLkxRMAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM+klEQVR4nO3dfWyd9XnG8evCjomdhMQMNEoSiCfaQqi0BVkoNFpVEba1g5Vpa7dUoxPVtFRTX0LLWsE0jQrtv6KKSo3YokA1CVRYUyahDhG2tZWGVGWYwASJCSIv5JUl0JJASHCc3PvDnpQlOH5s/3489p3vR0KKzznc3DL5+jnn+PFjR4QA5HFB2wsAKIuogWSIGkiGqIFkiBpIhqiBZIgaSIaoz1O2w/ZVZ9z2bdsPt7UTyiBqIBmiBpIhaiAZogaSIerz10lJs864bZakEy3sgoKI+vy1W9KSM27rk/TaB78KSiLq89djkv7W9iLbF9i+SdIfSNrQ8l6YIvPz1Ocn292S7pX0OUm9krZL+nZEPNHqYpgyogaS4ek3kAxRA8kQNZAMUQPJdNYYOmt+d8y+bH7xuXHozHMlyjgxr/ybhbO765zDMTTcUWVu1/4qY3Xywjr76tLyn9+OXaeKz5Sk44u6is8cPvQrnXz7qN/vvipRz75svvof+LPic4//44eKz5Sk/SvLR33t1XuKz5SknW9eXGXuFfecrDL3nY+U/+IuSf7SweIz591+rPhMSRq89/LiMw/83dox7+PpN5AMUQPJEDWQDFEDyRA1kAxRA8k0itr2p2xvs/2q7btqLwVg8saN2naHpLWSPi1pqaTP215aezEAk9PkSH29pFcjYkdEDEl6VNKtddcCMFlNol4o6fTTo/aO3vb/2F5te8D2wIm33i21H4AJKvZGWUSsi4j+iOiftaCn1FgAE9Qk6n2SFp/28aLR2wBMQ02iflbSh2332e6StEoS17ECpqlxf0orIoZtf0XSRkkdkh6KiC3VNwMwKY1+9DIinpT0ZOVdABTAGWVAMkQNJEPUQDJEDSRD1EAyVS48OKdjSNddXP7Ce7d/57HiMyVp1do7i8/cOnRF8ZmS5OH3vYDklF3x0H9Xmbv7k3X2ffn3ri0+s3NNlRzUNfud4jPtsS+WyZEaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimyuUTj52apcEjlxWf+yc/Kn/VT0n6nVX/VXzm4F9dU3ymJG3/0zlV5nac4+qUU9H91Nwqc3sOvVt8Zu/TdXa9+7M/Kj7zGz2/HPM+jtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMuNGbXux7Z/Z3mp7i+01H8RiACanycknw5LujIjNtudJes72v0XE1sq7AZiEcY/UEXEgIjaP/vltSYOSFtZeDMDkTOg1te0lkpZJ2vQ+9622PWB7YOitY2W2AzBhjaO2PVfSjyXdERFHzrw/ItZFRH9E9Hct6C65I4AJaBS17VkaCfqRiHi87koApqLJu9+W9KCkwYj4bv2VAExFkyP1CklfkHSj7RdG//n9ynsBmKRxv6UVEc9I8gewC4ACOKMMSIaogWSIGkiGqIFkqlx4cG7He1p+8c7icw9vu6L4TEna+MT1xWd2frL4SElS9+t15j6zr6/K3GOvLKgy9z9Wfaf4zL9csKr4TEn6/h/9YfGZB7c/OOZ9HKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSqXE10OC7QGyfmFp+793ej+ExJumRT+blv/ladXXd87h+qzF32bJ0rad7w21uqzF31zb8uPvO9+XV+u1TPVaeKzzz5WseY93GkBpIhaiAZogaSIWogGa
IGkiFqIBmiBpJpHLXtDtvP2/5JzYUATM1EjtRrJA3WWgRAGY2itr1I0s2S1tddB8BUNT1S3y/pW5LGPN/N9mrbA7YHjv3qvSLLAZi4caO2fYukgxHx3LkeFxHrIqI/Ivq7ey8stiCAiWlypF4h6TO2d0l6VNKNth+uuhWASRs36oi4OyIWRcQSSask/TQibqu+GYBJ4fvUQDIT+nnqiPi5pJ9X2QRAERypgWSIGkiGqIFkiBpIhqiBZKpcTfS9U53afbS3+Nylf7+v+ExJ2nrvh4rP7L7oePGZktS38S+qzL1wzlCVua+sXVpl7oLt7xSfefM//WfxmZK08VD5z0HHS2P//+JIDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kU+VqohHW8ZOzis89sr7O772+9OTh8jO/9G7xmZK05/sXVZm76M/rXKl1+zevrTL3yJLyn4dnj1xZfKYkbdu0pPjM4+90jXkfR2ogGaIGkiFqIBmiBpIhaiAZogaSIWogmUZR215ge4Ptl20P2r6h9mIAJqfpySffk/RURHzWdpeknoo7AZiCcaO2PV/SJyTdLkkRMSSpzi8zBjBlTZ5+90k6JOkHtp+3vd72nDMfZHu17QHbA0OHjxVfFEAzTaLulHSdpAciYpmko5LuOvNBEbEuIvojor9rfnfhNQE01STqvZL2RsSm0Y83aCRyANPQuFFHxOuS9tj+6OhNKyVtrboVgElr+u73VyU9MvrO9w5JX6y3EoCpaBR1RLwgqb/yLgAK4IwyIBmiBpIhaiAZogaSIWogmSpXE53dcUIfuehg8bkHj88rPlOSlvS8WXzmHb/4RfGZkrT8X75RZe6uNb1V5v7xLc9Umfuvr5W/Sukbx+cWnylJvvJo+ZkXnhrzPo7UQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRT5cKDR4e7NHDoiuJz9+28pPhMSVqyvPyFB2tdIPCa+/+nytzBO3+9ytx//vcVVeZe9djbxWce7ltcfKYk/caud4vPPPj62PdxpAaSIWogGaIGkiFqIBmiBpIhaiAZogaSaRS17a/b3mL7Jds/tD279mIAJmfcqG0vlPQ1Sf0R8TFJHZJW1V4MwOQ0ffrdKanbdqekHkn7660EYCrGjToi9km6T9JuSQckHY6Ip898nO3VtgdsD5w4fKz8pgAaafL0u1fSrZL6JF0uaY7t2858XESsi4j+iOifNb+7/KYAGmny9PsmSTsj4lBEnJD0uKSP110LwGQ1iXq3pOW2e2xb0kpJg3XXAjBZTV5Tb5K0QdJmSS+O/jvrKu8FYJIa/Tx1RNwj6Z7KuwAogDPKgGSIGkiGqIFkiBpIhqiBZKpcTXR2x7Cu6S1/1cv9r1xafKYkPbes/Ne2C+5z8ZmSdPLX5lWZe/ELdb6+X3X7tipzX/zl1cVnDvVG8ZmS1PvlN4rP9OrhMe/jSA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJOOI8ldQtH1I0msNHnqJpPKXWqxnJu07k3aVZta+02HXKyPifS+vWyXqpmwPRER/awtM0EzadybtKs2sfaf7rjz9BpIhaiCZtqOeab+8fibtO5N2lWbWvtN611ZfUwMor+0jNYDCiBpIprWobX/K9jbbr9q+q609xmN7se2f2d5qe4vtNW3v1ITtDtvP2/5J27uci+0FtjfYftn2oO0b2t7pXGx/ffTvwUu2f2h7dts7namVqG13SFor6dOSlkr6vO2lbezSwLCkOyNiqaTlkr48jXc93RpJg20v0cD3JD0VEVdL+k1N451tL5T0NUn9EfExSR2SVr
W71dnaOlJfL+nViNgREUOSHpV0a0u7nFNEHIiIzaN/flsjf+kWtrvVudleJOlmSevb3uVcbM+X9AlJD0pSRAxFxFvtbjWuTkndtjsl9Uja3/I+Z2kr6oWS9pz28V5N81AkyfYSScskbWp3k3HdL+lbkk61vcg4+iQdkvSD0ZcK623PaXupsUTEPkn3Sdot6YCkwxHxdLtbnY03yhqyPVfSjyXdERFH2t5nLLZvkXQwIp5re5cGOiVdJ+mBiFgm6aik6fz+Sq9GnlH2Sbpc0hzbt7W71dnainqfpMWnfbxo9LZpyfYsjQT9SEQ83vY+41gh6TO2d2nkZc2Nth9ud6Ux7ZW0NyL+75nPBo1EPl3dJGlnRByKiBOSHpf08ZZ3OktbUT8r6cO2+2x3aeTNhida2uWcbFsjr/kGI+K7be8znoi4OyIWRcQSjXxefxoR0+5oIkkR8bqkPbY/OnrTSklbW1xpPLslLbfdM/r3YqWm4Rt7nW38RyNi2PZXJG3UyDuID0XEljZ2aWCFpC9IetH2C6O3/U1EPNniTpl8VdIjo1/cd0j6Ysv7jCkiNtneIGmzRr4r8rym4SmjnCYKJMMbZUAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAy/wvI0tLk+VKqNgAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light",
|
|
"tags": []
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAALx0lEQVR4nO3dX4xc9XmH8eeL1whsqoRCKjW2VbttSoRSpURbREDiAmiVlDRILRcgQDQ3qGqTkDRSRKpKVL2soii5SGktkvQCFFIZlCKKApWSXESRHMyfKmATxJ/UGEzjVOSfXWMvvL3YieTaeHe8nsPZfft8pJV2zgxnXlv7cM7MnP05VYWkPs4YewBJs2XUUjNGLTVj1FIzRi01Y9RSM0bdTJLzk9ye5LKxZ9E4jHoNSFJJfvu4bX+b5K7jtm0E/g34A+CBJBcdc98NSX4x+fqfJG8cc/sXSzx3kjyfZPes5tSwjLqJJOuBe4HdwOXAnwP3J/ktgKq6u6rOqapzgA8CL//y9mTbyVwO/Brwm0l+f9g/hWZhbuwBdPqSBPhn4AXgL2rxMsGvJXmNxbCvqKr/WuHubwb+FTh78v0jMxhZAzLqBiYR3/Am278OfH2l+02yAbgWuI7FqP8pyV9V1ZGV7lPD8/RbS/kT4DXgYRZfq68Hrh51Ii3LqNeG11kM6ljrgaMDP+/NwL9U1UJVHWbxNfvNSzx+rDl1DE+/14a9wFZgzzHbtgHPDPWESTYDVwAXJ/nTyeYNwFlJzq+qH6+GOXUij9Rrw9eAv0myOckZSa4C/hjYMeBz3sRijBcAvzf5+h1gH3D9KppTxzHqteHvgO8C3wFeBf4euKGqnhzwOW8G/qGqXjn2C/hHTn4KPsacOk5cJEHqxSO11IxRS80YtdSMUUvNDPI59fm/uq62bjn+GoTT98xz5818n9JadPi1n3Dk6MG82X2DRL11y3q+99CWme/3D69d6mIm6f+P7z1xx0nv8/RbasaopWaMWmrGqKVmjFpqxqilZqaKOskHkvwgybNJbht6KEkrt2zUSdYBX2RxBcoLgeuTXDj0YJJWZpoj9cXAs1X1/GTBuXuAa4YdS9JKTRP1JuDFY27vm2z7P5LckmRXkl0H/vv1Wc0n6RTN7I2yqtpeVfNVNf+O89bNareSTtE0Ub8EHHsh9+bJNkmr0DRRPwK8K8m2JGeyuLD7/cOOJWmllv0trapaSPJR4CFgHfDlqnpq8MkkrchUv3pZVQ8CDw48i6QZ8IoyqRmjlpoxaqkZo5aaMWqpmUEWHnzmufMGWSQw3/2Pme8ToC597yD7lcbgkVpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaamaQ1USHMtSqn6++e8PM93nu04dmvk9pGh6ppWaMWmrGqKVmjFpqxqilZoxaasaopWaWjTrJliTfSrI7yVNJbn0rBpO0MtNcfLIAfKqqHkvyK8CjSf69qnYPPJukFVj2SF1V+6vqscn3Pwf2AJuGHkzSypzSa+okW4GLgJ1vct8tSXYl2XV04eBsppN0yqaOOsk5wL3AJ6rqZ8ffX1Xbq2q+qubXz22c5YySTsFUUSdZz2LQd1fVfcOOJOl0TPPud4AvAXuq6nPDjyTpdExzpL4MuAm4IskTk68/GnguSSu07EdaVfUdIG/BLJJmwCvKpGaMWmrGqKVmjFpqZk0tPDiUIRYJnHtu/8z3CXD4d7cMst+5QwuD7FdvPY/UUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzriY6kKFW/Tx8/vpB9nvOXlcT7cIjtdSMUUvNGLXUjFFLzRi11IxRS80YtdTM1FEnWZfk8SQPDDmQpN
NzKkfqW4E9Qw0iaTamijrJZuBq4M5hx5F0uqY9Un8e+DTwxskekOSWJLuS7Dq6cHAmw0k6dctGneRDwI+q6tGlHldV26tqvqrm189tnNmAkk7NNEfqy4APJ/khcA9wRZK7Bp1K0ootG3VVfaaqNlfVVuA64JtVdePgk0laET+nlpo5pd+nrqpvA98eZBJJM+GRWmrGqKVmjFpqxqilZoxaasbVRAcyd2iY1TmHWvXz0K+fNch+N+w/PMh+dXIeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZlxNVMBwq37OvXpokP0unLthkP124JFaasaopWaMWmrGqKVmjFpqxqilZoxaamaqqJO8PcmOJE8n2ZPk/UMPJmllpr345AvAN6rq2iRnAn7yL61Sy0ad5G3A5cCfAVTVEeDIsGNJWqlpTr+3AQeAryR5PMmdSTYe/6AktyTZlWTX0YWDMx9U0nSmiXoOeB9wR1VdBBwEbjv+QVW1varmq2p+/dwJzUt6i0wT9T5gX1XtnNzewWLkklahZaOuqleAF5NcMNl0JbB70Kkkrdi0735/DLh78s7388BHhhtJ0umYKuqqegKYH3gWSTPgFWVSM0YtNWPUUjNGLTVj1FIzriaqQQ216ufChtn/6M4dWpj5PsfgkVpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZlx4UGvSEIsEvn7WMDmsO/zWLmjokVpqxqilZoxaasaopWaMWmrGqKVmjFpqZqqok3wyyVNJnkzy1SRnDT2YpJVZNuokm4CPA/NV9R5gHXDd0INJWplpT7/ngLOTzAEbgJeHG0nS6Vg26qp6CfgssBfYD/y0qh4+/nFJbkmyK8muowsHZz+ppKlMc/p9LnANsA14J7AxyY3HP66qtlfVfFXNr5/bOPtJJU1lmtPvq4AXqupAVR0F7gMuHXYsSSs1TdR7gUuSbEgS4Epgz7BjSVqpaV5T7wR2AI8B35/8N9sHnkvSCk31C6RVdTtw+8CzSJoBryiTmjFqqRmjlpoxaqkZo5aacTVRaWKoVT/POPrG7HdaSzzf7J9N0piMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmUrXEsoQr3WlyAPjPKR56PvDjmQ8wnLU071qaFdbWvKth1t+oqne82R2DRD2tJLuqan60AU7RWpp3Lc0Ka2ve1T6rp99SM0YtNTN21GvtH69fS/OupVlhbc27qmcd9TW1pNkb+0gtacaMWmpmtKiTfCDJD5I8m+S2seZYTpItSb6VZHeSp5LcOvZM00iyLsnjSR4Ye5alJHl7kh1Jnk6yJ8n7x55pKUk+Ofk5eDLJV5OcNfZMxxsl6iTrgC8CHwQuBK5PcuEYs0xhAfhUVV0IXAL85Sqe9Vi3AnvGHmIKXwC+UVXvBt7LKp45ySbg48B8Vb0HWAdcN+5UJxrrSH0x8GxVPV9VR4B7gGtGmmVJVbW/qh6bfP9zFn/oNo071dKSbAauBu4ce5alJHkbcDnwJYCqOlJVPxl3qmXNAWcnmQM2AC+PPM8Jxop6E/DiMbf3scpDAUiyFbgI2DnuJMv6PPBpYIB/7XymtgEHgK9MXircmWTj2EOdTFW9BHwW2AvsB35aVQ+PO9WJfKNsSknOAe4FPlFVPxt7npNJ8iHgR1X16NizTGEOeB9wR1VdBBwEVvP7K+eyeEa5DXgnsDHJjeNOdaKxon4J2HLM7c2TbatSkvUsBn13Vd039jzLuAz4cJIfsviy5ookd4070kntA/ZV1S/PfHawGPlqdRXwQlUdqKqjwH3ApSPPdIKxon4EeFeSbUnOZPHNhv
tHmmVJScLia749VfW5sedZTlV9pqo2V9VWFv9ev1lVq+5oAlBVrwAvJrlgsulKYPeIIy1nL3BJkg2Tn4srWYVv7M2N8aRVtZDko8BDLL6D+OWqemqMWaZwGXAT8P0kT0y2/XVVPTjiTJ18DLh78j/354GPjDzPSVXVziQ7gMdY/FTkcVbhJaNeJio14xtlUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjP/C52Cw8vnl5XSAAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light",
|
|
"tags": []
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"Niter = 100\n",
|
|
"matrix_shape = (10, 10)\n",
|
|
"\n",
|
|
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
|
|
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
|
|
"in_shape_2 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
|
|
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1, in_shape_2])\n",
|
|
"\n",
|
|
"# body computation -- QR loop: X_i = Q_i R_i, X_{i+1} = R_i Q_i, O_{i+1} = O_i Q_i\n",
|
|
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
|
|
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
|
|
"X = xops.GetTupleElement(intuple, 0)\n",
|
|
"O = xops.GetTupleElement(intuple, 1)\n",
|
|
"cntr = xops.GetTupleElement(intuple, 2)\n",
|
|
"Q, R = xops.QR(X, True)\n",
|
|
"RQ = xops.Dot(R, Q)\n",
|
|
"Onew = xops.Dot(O, Q)\n",
|
|
"xops.Tuple(bcb, [RQ, Onew, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
|
|
"body_computation = bcb.Build()\n",
|
|
"\n",
|
|
"# test computation -- just a for loop condition\n",
|
|
"tcb = xc.XlaBuilder(\"testcomp\")\n",
|
|
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
|
|
"cntr = xops.GetTupleElement(intuple, 2)\n",
|
|
"test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
|
|
"test_computation = tcb.Build()\n",
|
|
"\n",
|
|
"# while computation:\n",
|
|
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
|
|
"X = xops.Parameter(wcb, 0, in_shape_0)\n",
|
|
"O = xops.Parameter(wcb, 1, in_shape_1)\n",
|
|
"cntr = xops.Parameter(wcb, 2, in_shape_2)\n",
|
|
"tuple_init_carry = xops.Tuple(wcb, [X, O, cntr])\n",
|
|
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
|
|
"while_computation = wcb.Build()\n",
|
|
"\n",
|
|
"# Now compile and execute:\n",
|
|
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
|
|
"\n",
|
|
"# compile graph based on shape\n",
|
|
"compiled_computation = cpu_backend.compile(while_computation)\n",
|
|
"\n",
|
|
"X = np.random.random(matrix_shape).astype(np.float32)\n",
|
|
"X = (X + X.T) / 2.0\n",
|
|
"Omat = np.eye(matrix_shape[0], dtype=np.float32)\n",
|
|
"it = np.array(Niter, dtype=np.int32)\n",
|
|
"\n",
|
|
"device_input_X = cpu_backend.buffer_from_pyval(X)\n",
|
|
"device_input_Omat = cpu_backend.buffer_from_pyval(Omat)\n",
|
|
"device_input_it = cpu_backend.buffer_from_pyval(it)\n",
|
|
"device_out = compiled_computation.execute([device_input_X, device_input_Omat, device_input_it])\n",
|
|
"\n",
|
|
"host_out = device_out[0].to_py()\n",
|
|
"eigh_vals = host_out.diagonal()\n",
|
|
"eigh_mat = device_out[1].to_py()\n",
|
|
"\n",
|
|
"plt.title('D')\n",
|
|
"plt.imshow(host_out)\n",
|
|
"plt.figure()\n",
|
|
"plt.title('U')\n",
|
|
"plt.imshow(eigh_mat)\n",
|
|
"plt.figure()\n",
|
|
"plt.title('U^T A U')\n",
|
|
"plt.imshow(np.dot(np.dot(eigh_mat.T, X), eigh_mat))\n",
|
|
"print('sorted eigenvalues')\n",
|
|
"print(np.sort(eigh_vals))\n",
|
|
"print('sorted eigenvalues from numpy')\n",
|
|
"print(np.sort(np.linalg.eigh(X)[0]))\n",
|
|
"print('sorted error')\n",
|
|
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
|
|
]
|
|
},
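{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check, here is a minimal NumPy-only sketch of the unshifted QR iteration that the XLA `While` loop above runs. Each step factors `X_k = Q_k R_k` and sets `X_{k+1} = R_k Q_k = Q_k^T X_k Q_k`. This is an orthogonal similarity transform, so it preserves the eigenvalues. Each step also accumulates `U_{k+1} = U_k Q_k`. The function name `qr_iteration_eigh` and the test matrix are our own. The sketch assumes a symmetric input whose eigenvalues have distinct magnitudes: without shifts, the off-diagonal entries only shrink linearly, at a rate set by the ratios of neighboring eigenvalue magnitudes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def qr_iteration_eigh(a, n_iter=100):\n",
"  # Hypothetical NumPy reference for the XLA loop above: no shifts, no deflation.\n",
"  # Assumes `a` is symmetric with eigenvalues of distinct magnitude.\n",
"  x = np.array(a, dtype=np.float64)\n",
"  u = np.eye(x.shape[0])\n",
"  for _ in range(n_iter):\n",
"    q, r = np.linalg.qr(x)  # X_k = Q_k R_k\n",
"    x = r @ q               # X_{k+1} = R_k Q_k\n",
"    u = u @ q               # U_{k+1} = U_k Q_k\n",
"  return np.diag(x), u\n",
"\n",
"# Small symmetric test matrix with eigenvalues 3 - sqrt(3), 3 and 3 + sqrt(3).\n",
"A = np.array([[4., 1., 0.], [1., 3., 1.], [0., 1., 2.]])\n",
"vals, U = qr_iteration_eigh(A)\n",
"print(np.sort(vals))\n",
"print(np.linalg.eigvalsh(A))\n",
"print(np.allclose(U.T @ A @ U, np.diag(vals), atol=1e-6))"
]
},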
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Ee3LMzOvlCuK"
|
|
},
|
|
"source": [
|
|
"## Convolutions\n",
|
|
"\n",
|
|
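"I keep hearing from the AGI folks that we can use convolutions to build artificial life. Let's try it out with Conway's Game of Life: a 3x3 convolution with an all-ones kernel counts the live cells in each neighborhood, and an XLA `While` loop applies the update rule repeatedly."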
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"id": "9xh6yeXKS9Vg"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Here we borrow convenience functions from LAX to handle conv dimension numbers.\n",
|
|
"from typing import NamedTuple, Sequence\n",
|
|
"\n",
|
|
"class ConvDimensionNumbers(NamedTuple):\n",
|
|
" \"\"\"Describes batch, spatial, and feature dimensions of a convolution.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" lhs_spec: a tuple of nonnegative integer dimension numbers containing\n",
|
|
" `(batch dimension, feature dimension, spatial dimensions...)`.\n",
|
|
" rhs_spec: a tuple of nonnegative integer dimension numbers containing\n",
|
|
" `(out feature dimension, in feature dimension, spatial dimensions...)`.\n",
|
|
" out_spec: a tuple of nonnegative integer dimension numbers containing\n",
|
|
" `(batch dimension, feature dimension, spatial dimensions...)`.\n",
|
|
" \"\"\"\n",
|
|
" lhs_spec: Sequence[int]\n",
|
|
" rhs_spec: Sequence[int]\n",
|
|
" out_spec: Sequence[int]\n",
|
|
"\n",
|
|
"def _conv_general_proto(dimension_numbers):\n",
|
|
" assert type(dimension_numbers) is ConvDimensionNumbers\n",
|
|
" lhs_spec, rhs_spec, out_spec = dimension_numbers\n",
|
|
" proto = xc.ConvolutionDimensionNumbers()\n",
|
|
" proto.input_batch_dimension = lhs_spec[0]\n",
|
|
" proto.input_feature_dimension = lhs_spec[1]\n",
|
|
" proto.output_batch_dimension = out_spec[0]\n",
|
|
" proto.output_feature_dimension = out_spec[1]\n",
|
|
" proto.kernel_output_feature_dimension = rhs_spec[0]\n",
|
|
" proto.kernel_input_feature_dimension = rhs_spec[1]\n",
|
|
" proto.input_spatial_dimensions.extend(lhs_spec[2:])\n",
|
|
" proto.kernel_spatial_dimensions.extend(rhs_spec[2:])\n",
|
|
" proto.output_spatial_dimensions.extend(out_spec[2:])\n",
|
|
" return proto"
|
|
]
|
|
},
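{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ConvDimensionNumbers` only records which axis of each operand plays which role. The Conway example below keeps every operand in NCHW / OIHW / NCHW order, so all three specs are the identity `(0, 1, 2, 3)`. To make the mapping concrete, here is a sketch of a hypothetical helper, `layout_to_spec`, which is not part of XLA or LAX. It derives a spec tuple from a layout string and treats the remaining letters as spatial axes in the order they appear. The returned tuples have the form that `ConvDimensionNumbers` expects. LAX's `conv_general_dilated` also accepts layout strings such as `('NCHW', 'OIHW', 'NCHW')`, but this helper is only a simplified illustration of the idea."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def layout_to_spec(layout, first, second):\n",
"  # Hypothetical helper: returns (axis of `first`, axis of `second`, spatial axes...).\n",
"  # For lhs/out use first='N', second='C'; for the kernel use first='O', second='I'.\n",
"  spatial = tuple(i for i, c in enumerate(layout) if c not in (first, second))\n",
"  return (layout.index(first), layout.index(second)) + spatial\n",
"\n",
"print(layout_to_spec('NCHW', 'N', 'C'))  # (0, 1, 2, 3), as used below\n",
"print(layout_to_spec('OIHW', 'O', 'I'))  # (0, 1, 2, 3)\n",
"print(layout_to_spec('NHWC', 'N', 'C'))  # (0, 3, 1, 2)\n",
"print(layout_to_spec('HWIO', 'O', 'I'))  # (3, 2, 0, 1)"
]
},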
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"id": "J8QkirDalBse",
|
|
"outputId": "543a03fd-f038-46f2-9a76-a6532b86874e"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEYAAABdCAYAAACo5mNeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAADjUlEQVR4nO3dwW0TURRAUTtKFVRBE4gKqJIKEE1QBWUwWSEFL4ideTMT+56zs7Kw9fRnc/Xn5bwsywkAAACg6OnoHwAAAABwFGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAICs5//98cvTN/+yZoWff76fX382z3Vez9Ms1zHLOZ7zWc7mHLOc4zmf5WzOcTZnOZtzzHKO53zW5Tz/cmMEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyHo+8st//P71z+evnz4f9EsAAACAIjdGAAAAgCxhBAAAAMgSRgAAAICsQ3eMXO4UsXPkNm/NyzznmOX7md22zBcAANZxYwQAAADIEkYAAACALGEEAAAAyNp1x8jlu/Cs89ZOEa5nT8N2bj2nZn8bu4Xez56m/ZjlOua3HbMF4HRyYwQAAAAIE0YAAACALGEEAAAAyNp1x4j3Nvdl3tezr2U7dorMcjbneO7n2NOwLbuatmO30Dp2Ne3HLN/P7Lb1KPN1YwQAAADIEkYAAACALGEEAAAAyNp1xwjbutf3uT4is5xjlrPMcz9mfT37WrZlp8gcZ3OWZ3/Oo+xp+IjsadrWo+wWcmMEAAAAyBJGAAAAgCxhBAAAAMiyYwQATvfzDuw9MMtZ5jnHLPdl3tezr2U7dorMetSz6cYIAAAAkCWMAAAAAFnCCAAAAJBlxwgAALCaXQ1zzHKOWc561Hm6MQIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGSdl2U5+jcAAAAAHMKNEQAAACBLGAEAAACyhBEAAAAgSxgBAAAAsoQRAAAAIEsYAQAAALJeAPxyuoaRade9AAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 1080x144 with 13 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light",
|
|
"tags": []
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"Niter=13\n",
|
|
"matrix_shape = (1, 1, 20, 20)\n",
|
|
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.int32), matrix_shape)\n",
|
|
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
|
|
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1])\n",
|
|
"\n",
|
|
"# Body computation -- Conway Update\n",
|
|
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
|
|
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
|
|
"x = xops.GetTupleElement(intuple, 0)\n",
|
|
"cntr = xops.GetTupleElement(intuple, 1)\n",
|
|
"# convs require floating-point type\n",
|
|
"xf = xops.ConvertElementType(x, xc.DTYPE_TO_XLA_ELEMENT_TYPE['float32'])\n",
|
|
"stamp = xops.Constant(bcb, np.ones((1,1,3,3), dtype=np.float32))\n",
|
|
"conv_dim_num_proto = _conv_general_proto(ConvDimensionNumbers(lhs_spec=(0,1,2,3), rhs_spec=(0,1,2,3), out_spec=(0,1,2,3)))\n",
|
|
"convd = xops.ConvGeneralDilated(xf, stamp, [1, 1], [(1, 1), (1, 1)], (), (), conv_dim_num_proto)\n",
|
|
"# convert neighbor counts to int32 to match the int32 constants they are compared with\n",
|
|
"convd = xops.ConvertElementType(convd, xc.DTYPE_TO_XLA_ELEMENT_TYPE['int32'])\n",
|
|
"bool_x = xops.Eq(x, xops.Constant(bcb, np.int32(1)))\n",
|
|
"# core update rule\n",
|
|
"res = xops.Or(\n",
|
|
" # birth rule\n",
|
|
" xops.And(xops.Not(bool_x), xops.Eq(convd, xops.Constant(bcb, np.int32(3)))),\n",
|
|
" # survival rule\n",
|
|
" xops.And(bool_x, xops.Or(\n",
|
|
"    # survival counts are the usual 2 and 3 neighbors, plus 1 because the conv sum includes the cell itself\n",
|
|
" xops.Eq(convd, xops.Constant(bcb, np.int32(4))),\n",
|
|
" xops.Eq(convd, xops.Constant(bcb, np.int32(3))))\n",
|
|
" )\n",
|
|
")\n",
|
|
"# Convert the boolean (PRED) result back to int32 so the loop-carried type stays constant\n",
|
|
"int_res = xops.ConvertElementType(res, xc.DTYPE_TO_XLA_ELEMENT_TYPE['int32'])\n",
|
|
"xops.Tuple(bcb, [int_res, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
|
|
"body_computation = bcb.Build()\n",
|
|
"\n",
|
|
"# Test computation -- just a for loop condition\n",
|
|
"tcb = xc.XlaBuilder(\"testcomp\")\n",
|
|
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
|
|
"cntr = xops.GetTupleElement(intuple, 1)\n",
|
|
"test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
|
|
"test_computation = tcb.Build()\n",
|
|
"\n",
|
|
"# While computation:\n",
|
|
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
|
|
"x = xops.Parameter(wcb, 0, in_shape_0)\n",
|
|
"cntr = xops.Parameter(wcb, 1, in_shape_1)\n",
|
|
"tuple_init_carry = xops.Tuple(wcb, [x, cntr])\n",
|
|
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
|
|
"while_computation = wcb.Build()\n",
|
|
"\n",
|
|
"# Now compile and execute:\n",
|
|
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
|
|
"\n",
|
|
"# compile graph based on shape\n",
|
|
"compiled_computation = cpu_backend.compile(while_computation)\n",
|
|
"\n",
|
|
"# Set up initial state\n",
|
|
"X = np.zeros(matrix_shape, dtype=np.int32)\n",
|
|
"X[0,0, 5:8, 5:8] = np.array([[0,1,0],[0,0,1],[1,1,1]])\n",
|
|
"\n",
|
|
"# Evolve: frame `it` re-runs the compiled while loop for `it` steps starting from the initial board X\n",
|
|
"movie = np.zeros((Niter,)+matrix_shape[-2:], dtype=np.int32)\n",
|
|
"for it in range(Niter):\n",
|
|
" itr = np.array(it, dtype=np.int32)\n",
|
|
" device_input_x = cpu_backend.buffer_from_pyval(X)\n",
|
|
" device_input_it = cpu_backend.buffer_from_pyval(itr)\n",
|
|
" device_out = compiled_computation.execute([device_input_x, device_input_it])\n",
|
|
" movie[it] = device_out[0].to_py()[0,0]\n",
|
|
"\n",
|
|
"# Plot\n",
|
|
"fig = plt.figure(figsize=(15,2))\n",
|
|
"gs = gridspec.GridSpec(1,Niter)\n",
|
|
"for i in range(Niter):\n",
|
|
" ax1 = plt.subplot(gs[:, i])\n",
|
|
" ax1.axis('off')\n",
|
|
" ax1.imshow(movie[i])\n",
|
|
"plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, hspace=0.0, wspace=0.05)"
|
|
]
|
|
},
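{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check of the update rule, here is a minimal NumPy sketch of a single step that mirrors the XLA body. It first computes a zero-padded 3x3 box sum, equivalent to the convolution with the all-ones `stamp`, which counts the cell itself. It then applies the birth and survival rules. The helper names `box_sum_3x3` and `life_step` are our own. The sketch assumes, like the convolution padding, that cells outside the board are dead. Unlike the cell above, it produces generations by stepping sequentially rather than re-running the loop from the initial board."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def box_sum_3x3(x):\n",
"  # Sum of each cell's 3x3 neighborhood, the cell itself included, with one row/column\n",
"  # of zero padding on each side (matching the padding [(1, 1), (1, 1)] used above).\n",
"  p = np.pad(x, 1)\n",
"  h, w = x.shape\n",
"  return sum(p[i:i + h, j:j + w] for i in range(3) for j in range(3))\n",
"\n",
"def life_step(x):\n",
"  s = box_sum_3x3(x)\n",
"  alive = x == 1\n",
"  birth = ~alive & (s == 3)                # dead cell with exactly 3 live neighbors\n",
"  survive = alive & ((s == 3) | (s == 4))  # live cell with 2 or 3 neighbors (+1 for itself)\n",
"  return (birth | survive).astype(np.int32)\n",
"\n",
"# The glider used above should reappear shifted by one row and one column after 4 steps.\n",
"g = np.zeros((20, 20), dtype=np.int32)\n",
"g[5:8, 5:8] = np.array([[0, 1, 0], [0, 0, 1], [1, 1, 1]])\n",
"y = g\n",
"for _ in range(4):\n",
"  y = life_step(y)\n",
"print(np.array_equal(y, np.roll(g, (1, 1), axis=(0, 1))))  # True"
]
},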
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "9-0PJlqv237S"
|
|
},
|
|
"source": [
|
|
"## Fin \n",
|
|
"\n",
|
|
"There's much more to XLA, but this hopefully highlights how easy it is to play with via the Python client!"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"name": "XLA in Python.ipynb",
|
|
"provenance": []
|
|
},
|
|
"jupytext": {
|
|
"formats": "ipynb,md:myst"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.2"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|