rocm_jax/docs/notebooks/XLA_in_Python.ipynb
2021-02-25 10:29:43 -08:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "sAgUgR5Mzzz2"
},
"source": [
"# XLA in Python\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/XLA_in_Python.ipynb)\n",
"\n",
"<img style=\"height:100px;\" src=\"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/compiler/xla/g3doc/images/xlalogo.png\"> <img style=\"height:100px;\" src=\"https://upload.wikimedia.org/wikipedia/commons/c/c3/Python-logo-notext.svg\">\n",
"\n",
"_Anselm Levskaya_, _Qiao Zhang_\n",
"\n",
"XLA is the compiler that JAX uses, and the compiler that TF uses for TPUs and will soon use for all devices, so it's worth some study. However, it's not exactly easy to play with XLA computations directly using the raw C++ interface. JAX exposes the underlying XLA computation builder API through a python wrapper, and makes interacting with the XLA compute model accessible for messing around and prototyping.\n",
"\n",
"XLA computations are built as computation graphs in HLO IR, which is then lowered to LLO that is device specific (CPU, GPU, TPU, etc.). \n",
"\n",
"As end users we interact with the computational primitives offered to us by the HLO spec.\n",
"\n",
"**Caution: This is a pedagogical notebook covering some low level XLA details, the APIs herein are neither public nor stable!**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EZK5RseuvZkr"
},
"source": [
"## References \n",
"\n",
"__xla__: the doc that defines what's in HLO - but note that the doc is incomplete and omits some ops.\n",
"\n",
"https://www.tensorflow.org/xla/operation_semantics\n",
"\n",
"more details on ops in the source code.\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
"\n",
"__python xla client__: this is the XLA python client for JAX, and what we're using here.\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n",
"\n",
"__jax__: you can see how jax interacts with the XLA compute layer for execution and JITing in these files.\n",
"\n",
"https://github.com/google/jax/blob/master/jax/lax.py\n",
"\n",
"https://github.com/google/jax/blob/master/jax/lib/xla_bridge.py\n",
"\n",
"https://github.com/google/jax/blob/master/jax/interpreters/xla.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3XR2NGmrzBGe"
},
"source": [
"## Colab Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Ogo2SBd3u18P"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# We only need to import JAX's xla_client, not all of JAX.\n",
"from jax.lib import xla_client as xc\n",
"xops = xc.ops\n",
"\n",
"# Plotting\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import gridspec\n",
"from matplotlib import rcParams\n",
"rcParams['image.interpolation'] = 'nearest'\n",
"rcParams['image.cmap'] = 'viridis'\n",
"rcParams['axes.grid'] = False"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "odmjXyhMuNJ5"
},
"source": [
"## Simple Computations"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "UYUtxVzMYIiv",
"outputId": "5c603ab4-0295-472c-b462-9928b2a9520d"
},
"outputs": [
{
"data": {
"text/plain": [
"array(0.14112, dtype=float32)"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# make a computation builder\n",
"c = xc.XlaBuilder(\"simple_scalar\")\n",
"\n",
"# define a parameter shape and parameter\n",
"param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
"x = xops.Parameter(c, 0, param_shape)\n",
"\n",
"# define computation graph\n",
"y = xops.Sin(x)\n",
"\n",
"# build computation graph\n",
"# Keep in mind that incorrectly constructed graphs can cause \n",
"# your notebook kernel to crash!\n",
"computation = c.Build()\n",
"\n",
"# get a cpu backend\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(computation)\n",
"\n",
"# define a host variable with above parameter shape\n",
"host_input = np.array(3.0, dtype=np.float32)\n",
"\n",
"# place host variable on device and execute\n",
"device_input = cpu_backend.buffer_from_pyval(host_input)\n",
"device_out = compiled_computation.execute([device_input ,])\n",
"\n",
"# retrive the result\n",
"device_out[0].to_py()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "rIA-IVMVvQs2",
"outputId": "a4d8ef32-43f3-4a48-f732-e85e158b602e"
},
"outputs": [
{
"data": {
"text/plain": [
"array([0.14112 , 0.7568025, 0.9589243], dtype=float32)"
]
},
"execution_count": 3,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# same as above with vector type:\n",
"\n",
"c = xc.XlaBuilder(\"simple_vector\")\n",
"param_shape = xc.Shape.array_shape(np.dtype(np.float32), (3,))\n",
"x = xops.Parameter(c, 0, param_shape)\n",
"\n",
"# chain steps by reference:\n",
"y = xops.Sin(x)\n",
"z = xops.Abs(y)\n",
"computation = c.Build()\n",
"\n",
"# get a cpu backend\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(computation)\n",
"\n",
"host_input = np.array([3.0, 4.0, 5.0], dtype=np.float32)\n",
"\n",
"device_input = cpu_backend.buffer_from_pyval(host_input)\n",
"device_out = compiled_computation.execute([device_input ,])\n",
"\n",
"# retrive the result\n",
"device_out[0].to_py()"
]
},
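{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick cross-check, the two computations above can be reproduced on the host with plain NumPy. This is only a reference sketch, not part of the XLA client API; the `ref_` names are illustrative and do not appear elsewhere in this notebook. The printed values should agree with the XLA results above up to float32 rounding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# scalar case: the first computation is just sin(x)\n",
"ref_scalar = np.sin(np.float32(3.0))\n",
"print(ref_scalar)\n",
"\n",
"# vector case: Abs was the last op added to the builder, so the result\n",
"# is abs(sin(x)) applied elementwise\n",
"ref_vector_in = np.array([3.0, 4.0, 5.0], dtype=np.float32)\n",
"ref_vector = np.abs(np.sin(ref_vector_in))\n",
"print(ref_vector)"
]
},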
{
"cell_type": "markdown",
"metadata": {
"id": "F8kWlLaVuQ1b"
},
"source": [
"## Simple While Loop"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "MDQP1qW515Ao",
"outputId": "53245817-b5fb-4285-ee62-7eb33a822be4"
},
"outputs": [
{
"data": {
"text/plain": [
"array(0, dtype=int32)"
]
},
"execution_count": 4,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# trivial while loop, decrement until 0\n",
"# x = 5\n",
"# while x > 0:\n",
"# x = x - 1\n",
"#\n",
"in_shape = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"\n",
"# body computation:\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"x = xops.Parameter(bcb, 0, in_shape)\n",
"const1 = xops.Constant(bcb, np.int32(1))\n",
"y = xops.Sub(x, const1)\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation:\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"x = xops.Parameter(tcb, 0, in_shape)\n",
"const0 = xops.Constant(tcb, np.int32(0))\n",
"y = xops.Gt(x, const0)\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"x = xops.Parameter(wcb, 0, in_shape)\n",
"xops.While(test_computation, body_computation, x)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"# get a cpu backend\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"host_input = np.array(5, dtype=np.int32)\n",
"\n",
"device_input = cpu_backend.buffer_from_pyval(host_input)\n",
"device_out = compiled_computation.execute([device_input ,])\n",
"\n",
"# retrive the result\n",
"device_out[0].to_py()"
]
},
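{
"cell_type": "markdown",
"metadata": {},
"source": [
"The control flow that `xops.While` expresses can be modeled in plain Python: a carry value is passed through the body for as long as the test computation returns true, and the carry keeps the same shape and dtype on every iteration. The helper below, `reference_while`, is a hypothetical name used only for illustration; it is a host-side sketch of the semantics, not how XLA executes the loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def reference_while(cond_fn, body_fn, init):\n",
"    # host-side model of While(test_computation, body_computation, init)\n",
"    carry = init\n",
"    while cond_fn(carry):\n",
"        carry = body_fn(carry)\n",
"    return carry\n",
"\n",
"# mirrors the cell above: start at 5 and decrement while x > 0\n",
"print(reference_while(lambda x: x > 0, lambda x: x - 1, 5))"
]
},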
{
"cell_type": "markdown",
"metadata": {
"id": "7UOnXlY8slI6"
},
"source": [
"## While loops w/ Tuples - Newton's Method for sqrt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "HEWz-vzd6QPR",
"outputId": "ad4c4247-8e81-4739-866f-2950fec5e759"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"square root of 2.0 is 1.4142156839370728\n"
]
}
],
"source": [
"Xsqr = 2\n",
"guess = 1.0\n",
"converged_delta = 0.001\n",
"maxit = 1000\n",
"\n",
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.float32), ())\n",
"in_shape_2 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1, in_shape_2])\n",
"\n",
"# body computation:\n",
"# x_{i+1} = x_i - (x_i**2 - y) / (2 * x_i)\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
"y = xops.GetTupleElement(intuple, 0)\n",
"x = xops.GetTupleElement(intuple, 1)\n",
"guard_cntr = xops.GetTupleElement(intuple, 2)\n",
"new_x = xops.Sub(x, xops.Div(xops.Sub(xops.Mul(x, x), y), xops.Add(x, x)))\n",
"result = xops.Tuple(bcb, [y, new_x, xops.Sub(guard_cntr, xops.Constant(bcb, np.int32(1)))])\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- convergence and max iteration test\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
"y = xops.GetTupleElement(intuple, 0)\n",
"x = xops.GetTupleElement(intuple, 1)\n",
"guard_cntr = xops.GetTupleElement(intuple, 2)\n",
"criterion = xops.Abs(xops.Sub(xops.Mul(x, x), y))\n",
"# stop at convergence criteria or too many iterations\n",
"test = xops.And(xops.Gt(criterion, xops.Constant(tcb, np.float32(converged_delta))), \n",
" xops.Gt(guard_cntr, xops.Constant(tcb, np.int32(0))))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"# since jax does not allow users to create a tuple input directly, we need to\n",
"# take multiple parameters and make a intermediate tuple before feeding it as\n",
"# an initial carry to while loop\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"y = xops.Parameter(wcb, 0, in_shape_0)\n",
"x = xops.Parameter(wcb, 1, in_shape_1)\n",
"guard_cntr = xops.Parameter(wcb, 2, in_shape_2)\n",
"tuple_init_carry = xops.Tuple(wcb, [y, x, guard_cntr])\n",
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"y = np.array(Xsqr, dtype=np.float32)\n",
"x = np.array(guess, dtype=np.float32)\n",
"maxit = np.array(maxit, dtype=np.int32)\n",
"\n",
"device_input_y = cpu_backend.buffer_from_pyval(y)\n",
"device_input_x = cpu_backend.buffer_from_pyval(x)\n",
"device_input_maxit = cpu_backend.buffer_from_pyval(maxit)\n",
"device_out = compiled_computation.execute([device_input_y, device_input_x, device_input_maxit])\n",
"\n",
"# retrive the result\n",
"print(\"square root of {y} is {x}\".format(y=y, x=device_out[1].to_py()))"
]
},
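{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is the same Newton iteration written as a host-side Python loop over a `(y, x, guard_cntr)` tuple carry. It is a minimal sketch under a few assumptions: `reference_while`, `newton_body` and `newton_test` are hypothetical helper names, and the constants simply repeat the values used in the cell above. The result should land close to the value printed by the XLA computation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def reference_while(cond_fn, body_fn, carry):\n",
"    # host-side model of an XLA While loop\n",
"    while cond_fn(carry):\n",
"        carry = body_fn(carry)\n",
"    return carry\n",
"\n",
"def newton_body(carry):\n",
"    y, x, guard = carry\n",
"    # x_{i+1} = x_i - (x_i**2 - y) / (2 * x_i)\n",
"    new_x = x - (x * x - y) / (x + x)\n",
"    return (y, new_x, guard - 1)\n",
"\n",
"def newton_test(carry):\n",
"    y, x, guard = carry\n",
"    # continue while not converged and iterations remain\n",
"    return abs(x * x - y) > 0.001 and guard > 0\n",
"\n",
"ref_init = (np.float32(2.0), np.float32(1.0), 1000)\n",
"ref_y, ref_x, ref_guard = reference_while(newton_test, newton_body, ref_init)\n",
"print('square root of {y} is {x}'.format(y=ref_y, x=ref_x))"
]
},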
{
"cell_type": "markdown",
"metadata": {
"id": "yETVIzTInFYr"
},
"source": [
"## Calculate Symm Eigenvalues"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AiyR1e2NubKa"
},
"source": [
"Let's exploit the XLA QR implementation to solve some eigenvalues for symmetric matrices. \n",
"\n",
"This is the naive QR algorithm, without acceleration for closely-spaced eigenvalue convergence, nor any permutation to sort eigenvalues by magnitude."
]
},
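{
"cell_type": "markdown",
"metadata": {},
"source": [
"The iteration can be sketched on the host with NumPy. Each step factors `A_k = Q R` and forms `A_{k+1} = R Q = Q^T A_k Q`, an orthogonal similarity transform, so the eigenvalues are unchanged while the off-diagonal entries typically shrink. The cell below is a reference sketch rather than the notebook's XLA code: it assumes float64, an arbitrary fixed seed, and illustrative `ref_` names. Without acceleration, convergence can be slow when eigenvalues are close in magnitude, so the reported gap is not guaranteed to be small."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"ref_rng = np.random.default_rng(0)  # fixed seed, chosen arbitrarily\n",
"ref_A = ref_rng.random((10, 10))\n",
"ref_A = (ref_A + ref_A.T) / 2.0     # symmetrize\n",
"\n",
"ref_Ak = ref_A.copy()\n",
"for _ in range(200):\n",
"    ref_Q, ref_R = np.linalg.qr(ref_Ak)\n",
"    ref_Ak = ref_R @ ref_Q          # similar to ref_A, so eigenvalues are preserved\n",
"\n",
"# the diagonal of the iterate approximates the eigenvalues\n",
"ref_qr_eigs = np.sort(np.diag(ref_Ak))\n",
"ref_np_eigs = np.sort(np.linalg.eigvalsh(ref_A))\n",
"print('largest gap vs. numpy:', np.max(np.abs(ref_qr_eigs - ref_np_eigs)))"
]
},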
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "wjxDPbqCcuXT",
"outputId": "2380db52-799d-494e-ded2-856e91f01b0f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sorted eigenvalues\n",
"[-1.1406534 -0.5946617 -0.29557052 -0.09876542 0.07503236 0.19509281\n",
" 0.47496718 0.858686 1.09709 5.281351 ]\n",
"sorted eigenvalues from numpy\n",
"[-1.140657 -0.5946614 -0.29557055 -0.09876533 0.07503222 0.19509293\n",
" 0.47496703 0.85868585 1.0970895 5.2813535 ]\n",
"sorted error\n",
"[ 3.5762787e-06 -2.9802322e-07 2.9802322e-08 -8.9406967e-08\n",
" 1.4156103e-07 -1.1920929e-07 1.4901161e-07 1.1920929e-07\n",
" 4.7683716e-07 -2.3841858e-06]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAK1klEQVR4nO3dUayeBX3H8e9v5xxmWx0osCW2ne0FcSEkDnNm0GZegFl0MrmZGSaY6U2zZCoaE4e78XYXxuiFc2lAbyByUbkghohL1Au3BDwUFqRVRyqDAoa6BSQdrKfw38U5M11Le56+fR+ec/58PwlJz3teHn6B8+V537fv+zRVhaQ+fmfqAZLmy6ilZoxaasaopWaMWmrGqKVmjFpqxqjfoJI8keSlJC8meT7Jvyb5myT+TGxx/gd8Y/uLqnoL8A7gH4C/A+6YdpIullGLqnqhqu4F/gr46yTXTL1JszNq/VZVPQgcA/506i2anVHrTM8Ab5t6hGZn1DrTTuC/ph6h2Rm1fivJn7AW9Y+n3qLZGbVI8ntJbgTuBu6sqken3qTZxc9TvzEleQL4A+AU8CpwGLgT+KeqemXCabpIRi0148NvqRmjlpoxaqkZo5aaWRzjoFe8baH27F6a+3F/cfTyuR9T2opefvl5Tq6eyGt9b5So9+xe4sH7d8/9uH/20U/M/ZjSVvTgI/94zu/58FtqxqilZoxaasaopWaMWmrGqKVmBkWd5INJfp7k8SS3jT1K0uw2jDrJAvB14EPA1cDHklw99jBJsxlypn4P8HhVHa2qk6x9kP6mcWdJmtWQqHcCT5329bH12/6fJPuTrCRZOf6ffsZemsrcXiirqgNVtVxVy1devjCvw0q6QEOifho4/Y3cu9Zvk7QJDYn6J8BVSfYmuQS4Gbh33FmSZrXhp7Sq6lSSTwH3AwvAN6vqsdGXSZrJoI9eVtV9wH0jb5E0B76jTGrGqKVmjFpqxqilZoxaamaUCw/+4ujlo1wkMP/yyNyPCVD7/niU40pT8EwtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTUzytVExzLWVT+Pv2v73I955b/999yPKQ3hmVpqxqilZoxaasaopWaMWmrGqKVmjFpqZsOok+xO8sMkh5M8luTW12OYpNkMefPJKeDzVXUoyVuAh5L8c1UdHnmbpBlseKauqmer6tD6r18EjgA7xx4maTYX9Jw6yR7gWuCB1/je/iQrSVZWV0/MZ52kCzY46iRvBr4DfLaqfnPm96vqQFUtV9Xy0tKOeW6UdAEGRZ1kibWg76qqe8adJOliDHn1O8AdwJGq+sr4kyRdjCFn6n3Ax4Hrkzyy/tefj7xL0ow2/C2tqvoxkNdhi6Q58B1lUjNGLTVj1FIzRi01s6UuPDiWMS4S+PxV2+Z+TIDL/v2lUY6rPjxTS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNeDXRkYx11c+lJ389ynFX//CKUY6r159naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqmZwVEnWUjycJLvjjlI0sW5kDP1rcCRsYZImo9BUSfZBXwYuH3cOZIu1tAz9VeBLwCvnusOSfYnWUmysrp6Yi7jJF24DaNOciPwXFU9dL77VdWBqlququWlpR1zGyjpwgw5U+8DPpLkCeBu4Pokd466StLMNoy6qr5YVbuqag9wM/CDqrpl9GWSZuLvU0vNXNDnqavqR8CPRlkiaS48U0vNGLXUjFFLzRi11IxRS814NdEtZqyrfr70+787ynG3Pfc/oxxX5+aZWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopW
aMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxquJChjvqp8LL7w8ynFfufRNoxy3A8/UUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjODok5yWZKDSX6W5EiS9449TNJshr755GvA96rqL5NcAmwfcZOki7Bh1EkuBd4PfAKgqk4CJ8edJWlWQx5+7wWOA99K8nCS25PsOPNOSfYnWUmysrp6Yu5DJQ0zJOpF4N3AN6rqWuAEcNuZd6qqA1W1XFXLS0tnNS/pdTIk6mPAsap6YP3rg6xFLmkT2jDqqvoV8FSSd67fdANweNRVkmY29NXvTwN3rb/yfRT45HiTJF2MQVFX1SPA8shbJM2B7yiTmjFqqRmjlpoxaqkZo5aa8WqiGtVYV/08tW1h7sdcfOmVuR9zCp6ppWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGCw9qSxrjIoFZHefCg7U0/4skno9naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqmZQVEn+VySx5L8NMm3k4zzp55JumgbRp1kJ/AZYLmqrgEWgJvHHiZpNkMffi8C25IsAtuBZ8abJOlibBh1VT0NfBl4EngWeKGqvn/m/ZLsT7KSZGV19cT8l0oaZMjD77cCNwF7gbcDO5Lccub9qupAVS1X1fLS0o75L5U0yJCH3x8AfllVx6tqFbgHeN+4syTNakjUTwLXJdmeJMANwJFxZ0ma1ZDn1A8AB4FDwKPrf8+BkXdJmtGgz1NX1ZeAL428RdIc+I4yqRmjlpoxaqkZo5aaMWqpGa8mKq0b66qftZhRjnsunqmlZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWZSVfM/aHIc+I8Bd70C+PXcB4xnK+3dSltha+3dDFvfUVVXvtY3Rol6qCQrVbU82YALtJX2bqWtsLX2bvatPvyWmjFqqZmpo95qf3j9Vtq7lbbC1tq7qbdO+pxa0vxNfaaWNGdGLTUzWdRJPpjk50keT3LbVDs2kmR3kh8mOZzksSS3Tr1piCQLSR5O8t2pt5xPksuSHEzysyRHkrx36k3nk+Rz6z8HP03y7SRvmnrTmSaJOskC8HXgQ8DVwMeSXD3FlgFOAZ+vqquB64C/3cRbT3crcGTqEQN8DfheVf0R8C428eYkO4HPAMtVdQ2wANw87aqzTXWmfg/weFUdraqTwN3ATRNtOa+qeraqDq3/+kXWfuh2Trvq/JLsAj4M3D71lvNJcinwfuAOgKo6WVXPT7tqQ4vAtiSLwHbgmYn3nGWqqHcCT5329TE2eSgASfYA1wIPTLtkQ18FvgC8OvWQDewFjgPfWn+qcHuSHVOPOpeqehr4MvAk8CzwQlV9f9pVZ/OFsoGSvBn4DvDZqvrN1HvOJcmNwHNV9dDUWwZYBN4NfKOqrgVOAJv59ZW3svaIci/wdmBHklumXXW2qaJ+Gth92te71m/blJIssRb0XVV1z9R7NrAP+EiSJ1h7WnN9kjunnXROx4BjVfV/j3wOshb5ZvUB4JdVdbyqVoF7gPdNvOksU0X9E+CqJHuTXMLaiw33TrTlvJKEted8R6rqK1Pv2UhVfbGqdlXVHtb+vf6gqjbd2QSgqn4FPJXknes33QAcnnDSRp4Erkuyff3n4gY24Qt7i1P8Q6vqVJJPAfez9griN6vqsSm2DLAP+DjwaJJH1m/7+6q6b8JNnXwauGv9f+5HgU9OvOecquqBJAeBQ6z9rsjDbMK3jPo2UakZXyiTmjFqqRmjlpoxaqkZo5aaMWqpGaOWmvlfrtFe31cYfuIAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"Niter = 200\n",
"matrix_shape = (10, 10)\n",
"\n",
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1])\n",
"\n",
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",
"\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
"x = xops.GetTupleElement(intuple, 0)\n",
"cntr = xops.GetTupleElement(intuple, 1)\n",
"Q, R = xops.QR(x, True)\n",
"RQ = xops.Dot(R, Q)\n",
"xops.Tuple(bcb, [RQ, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- just a for loop condition\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
"cntr = xops.GetTupleElement(intuple, 1)\n",
"test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"x = xops.Parameter(wcb, 0, in_shape_0)\n",
"cntr = xops.Parameter(wcb, 1, in_shape_1)\n",
"tuple_init_carry = xops.Tuple(wcb, [x, cntr])\n",
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"X = np.random.random(matrix_shape).astype(np.float32)\n",
"X = (X + X.T) / 2.0\n",
"it = np.array(Niter, dtype=np.int32)\n",
"\n",
"device_input_x = cpu_backend.buffer_from_pyval(X)\n",
"device_input_it = cpu_backend.buffer_from_pyval(it)\n",
"device_out = compiled_computation.execute([device_input_x, device_input_it])\n",
"\n",
"host_out = device_out[0].to_py()\n",
"eigh_vals = host_out.diagonal()\n",
"\n",
"plt.title('D')\n",
"plt.imshow(host_out)\n",
"print('sorted eigenvalues')\n",
"print(np.sort(eigh_vals))\n",
"print('sorted eigenvalues from numpy')\n",
"print(np.sort(np.linalg.eigh(X)[0]))\n",
"print('sorted error') \n",
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FpggTihknAOw"
},
"source": [
"## Calculate Full Symm Eigensystem"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qos4ankYuj1T"
},
"source": [
"We can also calculate the eigenbasis by accumulating the Qs."
]
},
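{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why accumulating the Qs works: if `O_0 = I` and `O_{k+1} = O_k Q_k`, then `O_k A_k O_k^T = A` holds at every step, so once `A_k` is close to diagonal the columns of `O_k` approximate the eigenvectors. The NumPy sketch below checks that invariant. Starting the accumulator from the identity is an assumption made here, and the `ref_` names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"ref_rng = np.random.default_rng(0)  # fixed seed, chosen arbitrarily\n",
"ref_A = ref_rng.random((10, 10))\n",
"ref_A = (ref_A + ref_A.T) / 2.0     # symmetrize\n",
"\n",
"ref_Ak = ref_A.copy()\n",
"ref_O = np.eye(10)                  # assumed initial accumulator\n",
"for _ in range(100):\n",
"    ref_Q, ref_R = np.linalg.qr(ref_Ak)\n",
"    ref_Ak = ref_R @ ref_Q\n",
"    ref_O = ref_O @ ref_Q           # accumulate the orthogonal factors\n",
"\n",
"# the similarity invariant holds however far the iteration has converged\n",
"ref_err = np.max(np.abs(ref_O @ ref_Ak @ ref_O.T - ref_A))\n",
"print('reconstruction error:', ref_err)"
]
},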
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "Kp3A-aAiZk0g",
"outputId": "bbaff039-20f4-45cd-b8fe-5a664d413f5b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sorted eigenvalues\n",
"[-0.95164776 -0.5988633 -0.28330874 -0.07402738 0.15438193 0.19796501\n",
" 0.4779069 0.58893895 0.81445134 4.5762177 ]\n",
"sorted eigenvalues from numpy\n",
"[-0.95164794 -0.6314303 -0.28330857 -0.07402731 0.15438198 0.19796519\n",
" 0.47790694 0.62150407 0.81445104 4.5762167 ]\n",
"sorted error\n",
"[ 1.7881393e-07 3.2567024e-02 -1.7881393e-07 -6.7055225e-08\n",
" -4.4703484e-08 -1.7881393e-07 -2.9802322e-08 -3.2565117e-02\n",
" 2.9802322e-07 9.5367432e-07]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAK5UlEQVR4nO3dT4xdB3mG8efF4yjYoRASVBXbwl4gkJUKBU1RSFQWCaqgpGSD1CAFUTZWpQIBIdHQDdsuEIIFSmUlsElEFiaLCEWESsACITlM/lTBdkBJSB0nRrgtCZFd1x7ydTFT5Nqx5/j6Hp+ZL89PiuSZuTl+lczjc++de49TVUjq401TD5A0X0YtNWPUUjNGLTVj1FIzRi01Y9RSM0b9BpXk+ST/neTVJC8n+VmSv0/i98QG5//AN7a/qaq3AO8C/hn4R+DeaSfpUhm1qKpXquoh4G+BTye5bupNmp1R64+q6lHgCPCXU2/R7IxaZ3sJePvUIzQ7o9bZtgH/NfUIzc6o9UdJ/oKVqH869RbNzqhFkj9JcivwAHBfVT019SbNLr6f+o0pyfPAnwLLwGvAQeA+4F+q6g8TTtMlMmqpGe9+S80YtdSMUUvNGLXUzMIYB7327Ztq547Ncz/ur569Zu7HlDaik//zMqdOH8/rfW2UqHfu2Myjj+yY+3H/6hOfnvsxpY3o0SfvPu/XvPstNWPUUjNGLTVj1FIzRi01Y9RSM4OiTvKRJL9M8kySu8YeJWl2a0adZBPwLeCjwG7gk0l2jz1M0myGnKk/ADxTVc9V1SlW3kh/27izJM1qSNTbgBfO+PjI6uf+nyR7kiwlWTr2n77HXprK3J4oq6q9VbVYVYvvuGbTvA4r6SINifpF4MwXcm9f/ZykdWhI1D8H3p1kV5IrgNuBh8adJWlWa75Lq6qWk3wWeATYBHy7qg6MvkzSTAa99bKqHgYeHnmLpDnwFWVSM0YtNWPUUjNGLTVj1FIzo1x48FfPXjPKRQLzs3+b+zEB6sb3jXJcaQqeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZka5muhYxrrq5+/eu2Xux7z66RNzP6Y0hGdqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqZk1o06yI8mPkxxMciDJnZdjmKTZDHnxyTLwpap6PMlbgMeS/GtVHRx5m6QZrHmmrqqjVfX46q9fBQ4B28YeJmk2F/WYOslO4Hpg/+t8bU+SpSRLp5ePz2edpIs2OOokVwHfA75QVb8/++tVtbeqFqtqcfPC1nlulHQRBkWdZDMrQd9fVQ+OO0nSpRjy7HeAe4FDVfX18SdJuhRDztQ3AZ8Cbk7y5Oo/fz3yLkkzWvNHWlX1UyCXYYukOfAVZVIzRi01Y9RSM0YtNbOhLjw4ljEuErjw7NG5HxPg5J/vGOW4CyeWRzmuLj/P1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM15NdCRjXfXz5LWbRznuVYe9mmgXnqmlZoxaasaopWaMWmrGqKVmjFpqxqilZgZHnWRTkieSfH/MQZIuzcWcqe8EDo01RNJ8DIo6yXbgY8A9486RdKmGnqm/AXwZeO18N0iyJ8lSkqXTy8fnMk7SxVsz6iS3Ar+tqscudLuq2ltVi1W1uHlh69wGSro4Q87UNwEfT/I88ABwc5L7Rl0laWZrRl1VX6mq7VW1E7gd+FFV3TH6Mkkz8efUUjMX9X7qqvoJ8JNRlkiaC8/UUjNGLTVj1FIzRi01Y9RSM15NdCQLJ8a5OudYV/088WdXjnLcLUdPjnJcnZ9naqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGa
OWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGa8mKmC8q34u/O7EKMddvnrLKMftwDO11IxRS80YtdSMUUvNGLXUjFFLzRi11MygqJO8Lcm+JE8nOZTkg2MPkzSboS8++Sbwg6r6RJIrAH/yL61Ta0ad5K3Ah4C/A6iqU8CpcWdJmtWQu9+7gGPAd5I8keSeJFvPvlGSPUmWkiydXj4+96GShhkS9QLwfuDuqroeOA7cdfaNqmpvVS1W1eLmhXOal3SZDIn6CHCkqvavfryPlcglrUNrRl1VvwFeSPKe1U/dAhwcdZWkmQ199vtzwP2rz3w/B3xmvEmSLsWgqKvqSWBx5C2S5sBXlEnNGLXUjFFLzRi11IxRS814NVGNaqyrfi5vmf+37sKJ5bkfcwqeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxgsPakMa4yKBf7hynBw2nby8FzT0TC01Y9RSM0YtNWPUUjNGLTVj1FIzRi01MyjqJF9MciDJL5J8N8mVYw+TNJs1o06yDfg8sFhV1wGbgNvHHiZpNkPvfi8Ab06yAGwBXhpvkqRLsWbUVfUi8DXgMHAUeKWqfnj27ZLsSbKUZOn08vH5L5U0yJC731cDtwG7gHcCW5PccfbtqmpvVS1W1eLmha3zXyppkCF3vz8M/LqqjlXVaeBB4MZxZ0ma1ZCoDwM3JNmSJMAtwKFxZ0ma1ZDH1PuBfcDjwFOr/87ekXdJmtGgN5BW1VeBr468RdIc+IoyqRmjlpoxaqkZo5aaMWqpGa8mKq0a66qfbzr92vwPWhf4/eb/u0maklFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11EyqLnBZwlkPmhwD/n3ATa8F/mPuA8azkfZupK2wsfauh63vqqp3vN4XRol6qCRLVbU42YCLtJH2bqStsLH2rvet3v2WmjFqqZmpo95of3n9Rtq7kbbCxtq7rrdO+pha0vxNfaaWNGdGLTUzWdRJPpLkl0meSXLXVDvWkmRHkh8nOZjkQJI7p940RJJNSZ5I8v2pt1xIkrcl2Zfk6SSHknxw6k0XkuSLq98Hv0jy3SRXTr3pbJNEnWQT8C3go8Bu4JNJdk+xZYBl4EtVtRu4AfiHdbz1THcCh6YeMcA3gR9U1XuB97GONyfZBnweWKyq64BNwO3TrjrXVGfqDwDPVNVzVXUKeAC4baItF1RVR6vq8dVfv8rKN922aVddWJLtwMeAe6beciFJ3gp8CLgXoKpOVdXL065a0wLw5iQLwBbgpYn3nGOqqLcBL5zx8RHWeSgASXYC1wP7p12ypm8AXwZG+NvO52oXcAz4zupDhXuSbJ161PlU1YvA14DDwFHglar64bSrzuUTZQMluQr4HvCFqvr91HvOJ8mtwG+r6rGptwywALwfuLuqrgeOA+v5+ZWrWblHuQt4J7A1yR3TrjrXVFG/COw44+Ptq59bl5JsZiXo+6vqwan3rOEm4ONJnmflYc3NSe6bdtJ5HQGOVNX/3fPZx0rk69WHgV9X1bGqOg08CNw48aZzTBX1z4F3J9mV5ApWnmx4aKItF5QkrDzmO1RVX596z1qq6itVtb2qdrLy3/VHVbXuziYAVfUb4IUk71n91C3AwQknreUwcEOSLavfF7ewDp/YW5jiN62q5SSfBR5h5RnEb1fVgSm2DHAT8CngqSRPrn7un6rq4Qk3dfI54P7VP9yfAz4z8Z7zqqr9SfYBj7PyU5EnWIcvGfVlolIzPlEmNWPUUjNGLTVj1FIzRi01Y9RSM0YtNfO/5H9lzlLkxRMAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM+klEQVR4nO3dfWyd9XnG8evCjomdhMQMNEoSiCfaQqi0BVkoNFpVEba1g5Vpa7dUoxPVtFRTX0LLWsE0jQrtv6KKSo3YokA1CVRYUyahDhG2tZWGVGWYwASJCSIv5JUl0JJASHCc3PvDnpQlOH5s/3489p3vR0KKzznc3DL5+jnn+PFjR4QA5HFB2wsAKIuogWSIGkiGqIFkiBpIhqiBZIgaSIaoz1O2w/ZVZ9z2bdsPt7UTyiBqIBmiBpIhaiAZogaSIerz10lJs864bZakEy3sgoKI+vy1W9KSM27rk/TaB78KSiLq89djkv7W9iLbF9i+SdIfSNrQ8l6YIvPz1Ocn292S7pX0OUm9krZL+nZEPNHqYpgyogaS4ek3kAxRA8kQNZAMUQPJdNYYOmt+d8y+bH7xuXHozHMlyjgxr/ybhbO765zDMTTcUWVu1/4qY3Xywjr76tLyn9+OXaeKz5Sk44u6is8cPvQrnXz7qN/vvipRz75svvof+LPic4//44eKz5Sk/SvLR33t1XuKz5SknW9eXGXuFfecrDL3nY+U/+IuSf7SweIz591+rPhMSRq89/LiMw/83dox7+PpN5AMUQPJEDWQDFEDyRA1kAxRA8k0itr2p2xvs/2q7btqLwVg8saN2naHpLWSPi1pqaTP215aezEAk9PkSH29pFcjYkdEDEl6VNKtddcCMFlNol4o6fTTo/aO3vb/2F5te8D2wIm33i21H4AJKvZGWUSsi4j+iOiftaCn1FgAE9Qk6n2SFp/28aLR2wBMQ02iflbSh2332e6StEoS17ECpqlxf0orIoZtf0XSRkkdkh6KiC3VNwMwKY1+9DIinpT0ZOVdABTAGWVAMkQNJEPUQDJEDSRD1EAyVS48OKdjSNddXP7Ce7d/57HiMyVp1do7i8/cOnRF8ZmS5OH3vYDklF3x0H9Xmbv7k3X2ffn3ri0+s3NNlRzUNfud4jPtsS+WyZEaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimyuUTj52apcEjlxWf+yc/Kn/VT0n6nVX/VXzm4F9dU3ymJG3/0zlV5nac4+qUU9H91Nwqc3sOvVt8Zu/TdXa9+7M/Kj7zGz2/HPM+jtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMuNGbXux7Z/Z3mp7i+01H8RiACanycknw5LujIjNtudJes72v0XE1sq7AZiEcY/UEXEgIjaP/vltSYOSFtZeDMDkTOg1te0lkpZJ2vQ+9622PWB7YOitY2W2AzBhjaO2PVfSjyXdERFHzrw/ItZFRH9E9Hct6C65I4AJaBS17VkaCfqRiHi87koApqLJu9+W9KCkwYj4bv2VAExFkyP1CklfkHSj7RdG//n9ynsBmKRxv6UVEc9I8gewC4ACOKMMSIaogWSIGkiGqIFkqlx4cG7He1p+8c7icw9vu6L4TEna+MT1xWd2frL4SElS9+t15j6zr6/K3GOvLKgy9z9Wfaf4zL9csKr4TEn6/h/9YfGZB7c/OOZ9HKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSqXE10OC7QGyfmFp+793ej+ExJumRT+blv/ladXXd87h+qzF32bJ0rad7w21uqzF31zb8uPvO9+XV+u1TPVaeKzzz5WseY93GkBpIhaiAZogaSIWogGa
IGkiFqIBmiBpJpHLXtDtvP2/5JzYUATM1EjtRrJA3WWgRAGY2itr1I0s2S1tddB8BUNT1S3y/pW5LGPN/N9mrbA7YHjv3qvSLLAZi4caO2fYukgxHx3LkeFxHrIqI/Ivq7ey8stiCAiWlypF4h6TO2d0l6VNKNth+uuhWASRs36oi4OyIWRcQSSask/TQibqu+GYBJ4fvUQDIT+nnqiPi5pJ9X2QRAERypgWSIGkiGqIFkiBpIhqiBZKpcTfS9U53afbS3+Nylf7+v+ExJ2nrvh4rP7L7oePGZktS38S+qzL1wzlCVua+sXVpl7oLt7xSfefM//WfxmZK08VD5z0HHS2P//+JIDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kU+VqohHW8ZOzis89sr7O772+9OTh8jO/9G7xmZK05/sXVZm76M/rXKl1+zevrTL3yJLyn4dnj1xZfKYkbdu0pPjM4+90jXkfR2ogGaIGkiFqIBmiBpIhaiAZogaSIWogmUZR215ge4Ptl20P2r6h9mIAJqfpySffk/RURHzWdpeknoo7AZiCcaO2PV/SJyTdLkkRMSSpzi8zBjBlTZ5+90k6JOkHtp+3vd72nDMfZHu17QHbA0OHjxVfFEAzTaLulHSdpAciYpmko5LuOvNBEbEuIvojor9rfnfhNQE01STqvZL2RsSm0Y83aCRyANPQuFFHxOuS9tj+6OhNKyVtrboVgElr+u73VyU9MvrO9w5JX6y3EoCpaBR1RLwgqb/yLgAK4IwyIBmiBpIhaiAZogaSIWogmSpXE53dcUIfuehg8bkHj88rPlOSlvS8WXzmHb/4RfGZkrT8X75RZe6uNb1V5v7xLc9Umfuvr5W/Sukbx+cWnylJvvJo+ZkXnhrzPo7UQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRT5cKDR4e7NHDoiuJz9+28pPhMSVqyvPyFB2tdIPCa+/+nytzBO3+9ytx//vcVVeZe9djbxWce7ltcfKYk/caud4vPPPj62PdxpAaSIWogGaIGkiFqIBmiBpIhaiAZogaSaRS17a/b3mL7Jds/tD279mIAJmfcqG0vlPQ1Sf0R8TFJHZJW1V4MwOQ0ffrdKanbdqekHkn7660EYCrGjToi9km6T9JuSQckHY6Ip898nO3VtgdsD5w4fKz8pgAaafL0u1fSrZL6JF0uaY7t2858XESsi4j+iOifNb+7/KYAGmny9PsmSTsj4lBEnJD0uKSP110LwGQ1iXq3pOW2e2xb0kpJg3XXAjBZTV5Tb5K0QdJmSS+O/jvrKu8FYJIa/Tx1RNwj6Z7KuwAogDPKgGSIGkiGqIFkiBpIhqiBZKpcTXR2x7Cu6S1/1cv9r1xafKYkPbes/Ne2C+5z8ZmSdPLX5lWZe/ELdb6+X3X7tipzX/zl1cVnDvVG8ZmS1PvlN4rP9OrhMe/jSA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJOOI8ldQtH1I0msNHnqJpPKXWqxnJu07k3aVZta+02HXKyPifS+vWyXqpmwPRER/awtM0EzadybtKs2sfaf7rjz9BpIhaiCZtqOeab+8fibtO5N2lWbWvtN611ZfUwMor+0jNYDCiBpIprWobX/K9jbbr9q+q609xmN7se2f2d5qe4vtNW3v1ITtDtvP2/5J27uci+0FtjfYftn2oO0b2t7pXGx/ffTvwUu2f2h7dts7namVqG13SFor6dOSlkr6vO2lbezSwLCkOyNiqaTlkr48jXc93RpJg20v0cD3JD0VEVdL+k1N451tL5T0NUn9EfExSR2SVr
W71dnaOlJfL+nViNgREUOSHpV0a0u7nFNEHIiIzaN/flsjf+kWtrvVudleJOlmSevb3uVcbM+X9AlJD0pSRAxFxFvtbjWuTkndtjsl9Uja3/I+Z2kr6oWS9pz28V5N81AkyfYSScskbWp3k3HdL+lbkk61vcg4+iQdkvSD0ZcK623PaXupsUTEPkn3Sdot6YCkwxHxdLtbnY03yhqyPVfSjyXdERFH2t5nLLZvkXQwIp5re5cGOiVdJ+mBiFgm6aik6fz+Sq9GnlH2Sbpc0hzbt7W71dnainqfpMWnfbxo9LZpyfYsjQT9SEQ83vY+41gh6TO2d2nkZc2Nth9ud6Ux7ZW0NyL+75nPBo1EPl3dJGlnRByKiBOSHpf08ZZ3OktbUT8r6cO2+2x3aeTNhida2uWcbFsjr/kGI+K7be8znoi4OyIWRcQSjXxefxoR0+5oIkkR8bqkPbY/OnrTSklbW1xpPLslLbfdM/r3YqWm4Rt7nW38RyNi2PZXJG3UyDuID0XEljZ2aWCFpC9IetH2C6O3/U1EPNniTpl8VdIjo1/cd0j6Ysv7jCkiNtneIGmzRr4r8rym4SmjnCYKJMMbZUAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAy/wvI0tLk+VKqNgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAEICAYAAACHyrIWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAALx0lEQVR4nO3dX4xc9XmH8eeL1whsqoRCKjW2VbttSoRSpURbREDiAmiVlDRILRcgQDQ3qGqTkDRSRKpKVL2soii5SGktkvQCFFIZlCKKApWSXESRHMyfKmATxJ/UGEzjVOSfXWMvvL3YieTaeHe8nsPZfft8pJV2zgxnXlv7cM7MnP05VYWkPs4YewBJs2XUUjNGLTVj1FIzRi01Y9RSM0bdTJLzk9ye5LKxZ9E4jHoNSFJJfvu4bX+b5K7jtm0E/g34A+CBJBcdc98NSX4x+fqfJG8cc/sXSzx3kjyfZPes5tSwjLqJJOuBe4HdwOXAnwP3J/ktgKq6u6rOqapzgA8CL//y9mTbyVwO/Brwm0l+f9g/hWZhbuwBdPqSBPhn4AXgL2rxMsGvJXmNxbCvqKr/WuHubwb+FTh78v0jMxhZAzLqBiYR3/Am278OfH2l+02yAbgWuI7FqP8pyV9V1ZGV7lPD8/RbS/kT4DXgYRZfq68Hrh51Ii3LqNeG11kM6ljrgaMDP+/NwL9U1UJVHWbxNfvNSzx+rDl1DE+/14a9wFZgzzHbtgHPDPWESTYDVwAXJ/nTyeYNwFlJzq+qH6+GOXUij9Rrw9eAv0myOckZSa4C/hjYMeBz3sRijBcAvzf5+h1gH3D9KppTxzHqteHvgO8C3wFeBf4euKGqnhzwOW8G/qGqXjn2C/hHTn4KPsacOk5cJEHqxSO11IxRS80YtdSMUUvNDPI59fm/uq62bjn+GoTT98xz5818n9JadPi1n3Dk6MG82X2DRL11y3q+99CWme/3D69d6mIm6f+P7z1xx0nv8/RbasaopWaMWmrGqKVmjFpqxqilZqaKOskHkvwgybNJbht6KEkrt2zUSdYBX2RxBcoLgeuTXDj0YJJWZpoj9cXAs1X1/GTBuXuAa4YdS9JKTRP1JuDFY27vm2z7P5LckmRXkl0H/vv1Wc0n6RTN7I2yqtpeVfNVNf+O89bNareSTtE0Ub8EHHsh9+bJNkmr0DRRPwK8K8m2JGeyuLD7/cOOJWmllv0trapaSPJR4CFgHfDlqnpq8MkkrchUv3pZVQ8CDw48i6QZ8IoyqRmjlpoxaqkZo5aaMWqpmUEWHnzmufMGWSQw3/2Pme8ToC597yD7lcbgkVpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaamaQ1USHMtSqn6++e8PM93nu04dmvk9pGh6ppWaMWmrGqKVmjFpqxqilZoxaasaopWaWjTrJliTfSrI7yVNJbn0rBpO0MtNcfLIAfKqqHkvyK8CjSf69qnYPPJukFVj2SF1V+6vqscn3Pwf2AJuGHkzSypzSa+okW4GLgJ1vct8tSXYl2XV04eBsppN0yqaOOsk5wL3AJ6rqZ8ffX1Xbq2q+qubXz22c5YySTsFUUSdZz2LQd1fVfcOOJOl0TPPud4AvAXuq6nPDjyTpdExzpL4MuAm4IskTk68/GnguSSu07EdaVfUdIG/BLJJmwCvKpGaMWmrGqKVmjFpqZk0tPDiUIRYJnHtu/8z3CXD4d7cMst+5QwuD7FdvPY/UUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzriY6kKFW/Tx8/vpB9nvOXlcT7cIjtdSMUUvNGLXUjFFLzRi11IxRS80YtdTM1FEnWZfk8SQPDDmQpN
NzKkfqW4E9Qw0iaTamijrJZuBq4M5hx5F0uqY9Un8e+DTwxskekOSWJLuS7Dq6cHAmw0k6dctGneRDwI+q6tGlHldV26tqvqrm189tnNmAkk7NNEfqy4APJ/khcA9wRZK7Bp1K0ootG3VVfaaqNlfVVuA64JtVdePgk0laET+nlpo5pd+nrqpvA98eZBJJM+GRWmrGqKVmjFpqxqilZoxaasbVRAcyd2iY1TmHWvXz0K+fNch+N+w/PMh+dXIeqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZlxNVMBwq37OvXpokP0unLthkP124JFaasaopWaMWmrGqKVmjFpqxqilZoxaamaqqJO8PcmOJE8n2ZPk/UMPJmllpr345AvAN6rq2iRnAn7yL61Sy0ad5G3A5cCfAVTVEeDIsGNJWqlpTr+3AQeAryR5PMmdSTYe/6AktyTZlWTX0YWDMx9U0nSmiXoOeB9wR1VdBBwEbjv+QVW1varmq2p+/dwJzUt6i0wT9T5gX1XtnNzewWLkklahZaOuqleAF5NcMNl0JbB70Kkkrdi0735/DLh78s7388BHhhtJ0umYKuqqegKYH3gWSTPgFWVSM0YtNWPUUjNGLTVj1FIzriaqQQ216ufChtn/6M4dWpj5PsfgkVpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZlx4UGvSEIsEvn7WMDmsO/zWLmjokVpqxqilZoxaasaopWaMWmrGqKVmjFpqZqqok3wyyVNJnkzy1SRnDT2YpJVZNuokm4CPA/NV9R5gHXDd0INJWplpT7/ngLOTzAEbgJeHG0nS6Vg26qp6CfgssBfYD/y0qh4+/nFJbkmyK8muowsHZz+ppKlMc/p9LnANsA14J7AxyY3HP66qtlfVfFXNr5/bOPtJJU1lmtPvq4AXqupAVR0F7gMuHXYsSSs1TdR7gUuSbEgS4Epgz7BjSVqpaV5T7wR2AI8B35/8N9sHnkvSCk31C6RVdTtw+8CzSJoBryiTmjFqqRmjlpoxaqkZo5aacTVRaWKoVT/POPrG7HdaSzzf7J9N0piMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmUrXEsoQr3WlyAPjPKR56PvDjmQ8wnLU071qaFdbWvKth1t+oqne82R2DRD2tJLuqan60AU7RWpp3Lc0Ka2ve1T6rp99SM0YtNTN21GvtH69fS/OupVlhbc27qmcd9TW1pNkb+0gtacaMWmpmtKiTfCDJD5I8m+S2seZYTpItSb6VZHeSp5LcOvZM00iyLsnjSR4Ye5alJHl7kh1Jnk6yJ8n7x55pKUk+Ofk5eDLJV5OcNfZMxxsl6iTrgC8CHwQuBK5PcuEYs0xhAfhUVV0IXAL85Sqe9Vi3AnvGHmIKXwC+UVXvBt7LKp45ySbg48B8Vb0HWAdcN+5UJxrrSH0x8GxVPV9VR4B7gGtGmmVJVbW/qh6bfP9zFn/oNo071dKSbAauBu4ce5alJHkbcDnwJYCqOlJVPxl3qmXNAWcnmQM2AC+PPM8Jxop6E/DiMbf3scpDAUiyFbgI2DnuJMv6PPBpYIB/7XymtgEHgK9MXircmWTj2EOdTFW9BHwW2AvsB35aVQ+PO9WJfKNsSknOAe4FPlFVPxt7npNJ8iHgR1X16NizTGEOeB9wR1VdBBwEVvP7K+eyeEa5DXgnsDHJjeNOdaKxon4J2HLM7c2TbatSkvUsBn13Vd039jzLuAz4cJIfsviy5ookd4070kntA/ZV1S/PfHawGPlqdRXwQlUdqKqjwH3ApSPPdIKxon4EeFeSbUnOZPHNhv
tHmmVJScLia749VfW5sedZTlV9pqo2V9VWFv9ev1lVq+5oAlBVrwAvJrlgsulKYPeIIy1nL3BJkg2Tn4srWYVv7M2N8aRVtZDko8BDLL6D+OWqemqMWaZwGXAT8P0kT0y2/XVVPTjiTJ18DLh78j/354GPjDzPSVXVziQ7gMdY/FTkcVbhJaNeJio14xtlUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjP/C52Cw8vnl5XSAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"Niter = 100\n",
"matrix_shape = (10, 10)\n",
"\n",
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.float32), matrix_shape)\n",
"in_shape_2 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1, in_shape_2])\n",
"\n",
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
"X = xops.GetTupleElement(intuple, 0)\n",
"O = xops.GetTupleElement(intuple, 1)\n",
"cntr = xops.GetTupleElement(intuple, 2)\n",
"Q, R = xops.QR(X, True)\n",
"RQ = xops.Dot(R, Q)\n",
"Onew = xops.Dot(O, Q)\n",
"xops.Tuple(bcb, [RQ, Onew, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation -- just a for loop condition\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
"cntr = xops.GetTupleElement(intuple, 2)\n",
"test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"X = xops.Parameter(wcb, 0, in_shape_0)\n",
"O = xops.Parameter(wcb, 1, in_shape_1)\n",
"cntr = xops.Parameter(wcb, 2, in_shape_2)\n",
"tuple_init_carry = xops.Tuple(wcb, [X, O, cntr])\n",
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"X = np.random.random(matrix_shape).astype(np.float32)\n",
"X = (X + X.T) / 2.0\n",
"Omat = np.eye(matrix_shape[0], dtype=np.float32)\n",
"it = np.array(Niter, dtype=np.int32)\n",
"\n",
"device_input_X = cpu_backend.buffer_from_pyval(X)\n",
"device_input_Omat = cpu_backend.buffer_from_pyval(Omat)\n",
"device_input_it = cpu_backend.buffer_from_pyval(it)\n",
"device_out = compiled_computation.execute([device_input_X, device_input_Omat, device_input_it])\n",
"\n",
"host_out = device_out[0].to_py()\n",
"eigh_vals = host_out.diagonal()\n",
"eigh_mat = device_out[1].to_py()\n",
"\n",
"plt.title('D')\n",
"plt.imshow(host_out)\n",
"plt.figure()\n",
"plt.title('U')\n",
"plt.imshow(eigh_mat)\n",
"plt.figure()\n",
"plt.title('U^T A U')\n",
"plt.imshow(np.dot(np.dot(eigh_mat.T, X), eigh_mat))\n",
"print('sorted eigenvalues')\n",
"print(np.sort(eigh_vals))\n",
"print('sorted eigenvalues from numpy')\n",
"print(np.sort(np.linalg.eigh(X)[0]))\n",
"print('sorted error') \n",
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
]
},
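{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, the same unshifted QR iteration can be sketched in plain NumPy. This is a reference sketch, not part of the XLA client API: the helper name `qr_eigh` is hypothetical, and it runs in float64 on the host. The loop body mirrors `bodycomp` above, with `x` in the role of `X` and `o` in the role of `O`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def qr_eigh(a, n_iter=100):\n",
"    # Unshifted QR iteration for a symmetric matrix (hypothetical helper).\n",
"    x = np.array(a, dtype=np.float64)\n",
"    o = np.eye(x.shape[0])\n",
"    for _ in range(n_iter):\n",
"        q, r = np.linalg.qr(x)\n",
"        x = r @ q  # similarity transform: R Q = Q^T (Q R) Q\n",
"        o = o @ q  # accumulate the orthogonal factors\n",
"    return np.diagonal(x), o\n",
"\n",
"rng = np.random.default_rng(0)\n",
"a = rng.random((10, 10))\n",
"a = (a + a.T) / 2.0\n",
"vals, vecs = qr_eigh(a)\n",
"\n",
"# Similarity transforms preserve the trace, so the diagonal sums to trace(a).\n",
"assert np.isclose(vals.sum(), np.trace(a))\n",
"# The accumulated product of Q factors stays orthogonal.\n",
"assert np.allclose(vecs.T @ vecs, np.eye(10))\n",
"```\n",
"\n",
"Each step is a similarity transform, so the eigenvalues are preserved; for a symmetric input the off-diagonal entries typically shrink until the diagonal holds the eigenvalue estimates."
]
},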
{
"cell_type": "markdown",
"metadata": {
"id": "Ee3LMzOvlCuK"
},
"source": [
"## Convolutions\n",
"\n",
"I keep hearing from the AGI folks that we can use convolutions to build artificial life. Let's try it out."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "9xh6yeXKS9Vg"
},
"outputs": [],
"source": [
"# Here we borrow convenience functions from LAX to handle conv dimension numbers.\n",
"from typing import NamedTuple, Sequence\n",
"\n",
"class ConvDimensionNumbers(NamedTuple):\n",
" \"\"\"Describes batch, spatial, and feature dimensions of a convolution.\n",
"\n",
" Args:\n",
" lhs_spec: a tuple of nonnegative integer dimension numbers containing\n",
" `(batch dimension, feature dimension, spatial dimensions...)`.\n",
" rhs_spec: a tuple of nonnegative integer dimension numbers containing\n",
" `(out feature dimension, in feature dimension, spatial dimensions...)`.\n",
" out_spec: a tuple of nonnegative integer dimension numbers containing\n",
" `(batch dimension, feature dimension, spatial dimensions...)`.\n",
" \"\"\"\n",
" lhs_spec: Sequence[int]\n",
" rhs_spec: Sequence[int]\n",
" out_spec: Sequence[int]\n",
"\n",
"def _conv_general_proto(dimension_numbers):\n",
" assert type(dimension_numbers) is ConvDimensionNumbers\n",
" lhs_spec, rhs_spec, out_spec = dimension_numbers\n",
" proto = xc.ConvolutionDimensionNumbers()\n",
" proto.input_batch_dimension = lhs_spec[0]\n",
" proto.input_feature_dimension = lhs_spec[1]\n",
" proto.output_batch_dimension = out_spec[0]\n",
" proto.output_feature_dimension = out_spec[1]\n",
" proto.kernel_output_feature_dimension = rhs_spec[0]\n",
" proto.kernel_input_feature_dimension = rhs_spec[1]\n",
" proto.input_spatial_dimensions.extend(lhs_spec[2:])\n",
" proto.kernel_spatial_dimensions.extend(rhs_spec[2:])\n",
" proto.output_spatial_dimensions.extend(out_spec[2:])\n",
" return proto"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "J8QkirDalBse",
"outputId": "543a03fd-f038-46f2-9a76-a6532b86874e"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEYAAABdCAYAAACo5mNeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAADjUlEQVR4nO3dwW0TURRAUTtKFVRBE4gKqJIKEE1QBWUwWSEFL4ideTMT+56zs7Kw9fRnc/Xn5bwsywkAAACg6OnoHwAAAABwFGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAICs5//98cvTN/+yZoWff76fX382z3Vez9Ms1zHLOZ7zWc7mHLOc4zmf5WzOcTZnOZtzzHKO53zW5Tz/cmMEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyBJGAAAAgCxhBAAAAMgSRgAAAIAsYQQAAADIEkYAAACALGEEAAAAyHo+8st//P71z+evnz4f9EsAAACAIjdGAAAAgCxhBAAAAMgSRgAAAICsQ3eMXO4UsXPkNm/NyzznmOX7md22zBcAANZxYwQAAADIEkYAAACALGEEAAAAyNp1x8jlu/Cs89ZOEa5nT8N2bj2nZn8bu4Xez56m/ZjlOua3HbMF4HRyYwQAAAAIE0YAAACALGEEAAAAyNp1x4j3Nvdl3tezr2U7dorMcjbneO7n2NOwLbuatmO30Dp2Ne3HLN/P7Lb1KPN1YwQAAADIEkYAAACALGEEAAAAyNp1xwjbutf3uT4is5xjlrPMcz9mfT37WrZlp8gcZ3OWZ3/Oo+xp+IjsadrWo+wWcmMEAAAAyBJGAAAAgCxhBAAAAMiyYwQATvfzDuw9MMtZ5jnHLPdl3tezr2U7dorMetSz6cYIAAAAkCWMAAAAAFnCCAAAAJBlxwgAALCaXQ1zzHKOWc561Hm6MQIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGQJIwAAAECWMAIAAABkCSMAAABAljACAAAAZAkjAAAAQJYwAgAAAGSdl2U5+jcAAAAAHMKNEQAAACBLGAEAAACyhBEAAAAgSxgBAAAAsoQRAAAAIEsYAQAAALJeAPxyuoaRade9AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 1080x144 with 13 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"Niter=13\n",
"matrix_shape = (1, 1, 20, 20)\n",
"in_shape_0 = xc.Shape.array_shape(np.dtype(np.int32), matrix_shape)\n",
"in_shape_1 = xc.Shape.array_shape(np.dtype(np.int32), ())\n",
"in_tuple_shape = xc.Shape.tuple_shape([in_shape_0, in_shape_1])\n",
"\n",
"# Body computation -- Conway Update\n",
"bcb = xc.XlaBuilder(\"bodycomp\")\n",
"intuple = xops.Parameter(bcb, 0, in_tuple_shape)\n",
"x = xops.GetTupleElement(intuple, 0)\n",
"cntr = xops.GetTupleElement(intuple, 1)\n",
"# convs require floating-point type\n",
"xf = xops.ConvertElementType(x, xc.DTYPE_TO_XLA_ELEMENT_TYPE['float32'])\n",
"stamp = xops.Constant(bcb, np.ones((1,1,3,3), dtype=np.float32))\n",
"conv_dim_num_proto = _conv_general_proto(ConvDimensionNumbers(lhs_spec=(0,1,2,3), rhs_spec=(0,1,2,3), out_spec=(0,1,2,3)))\n",
"convd = xops.ConvGeneralDilated(xf, stamp, [1, 1], [(1, 1), (1, 1)], (), (), conv_dim_num_proto)\n",
"# # logic ops require integer types\n",
"convd = xops.ConvertElementType(convd, xc.DTYPE_TO_XLA_ELEMENT_TYPE['int32'])\n",
"bool_x = xops.Eq(x, xops.Constant(bcb, np.int32(1)))\n",
"# core update rule\n",
"res = xops.Or(\n",
" # birth rule\n",
" xops.And(xops.Not(bool_x), xops.Eq(convd, xops.Constant(bcb, np.int32(3)))),\n",
" # survival rule\n",
" xops.And(bool_x, xops.Or(\n",
" # these are +1 the normal numbers since conv-sum counts self\n",
" xops.Eq(convd, xops.Constant(bcb, np.int32(4))),\n",
" xops.Eq(convd, xops.Constant(bcb, np.int32(3))))\n",
" )\n",
")\n",
"# Convert output back to int type for type constancy\n",
"int_res = xops.ConvertElementType(res, xc.DTYPE_TO_XLA_ELEMENT_TYPE['int32'])\n",
"xops.Tuple(bcb, [int_res, xops.Sub(cntr, xops.Constant(bcb, np.int32(1)))])\n",
"body_computation = bcb.Build()\n",
"\n",
"# Test computation -- just a for loop condition\n",
"tcb = xc.XlaBuilder(\"testcomp\")\n",
"intuple = xops.Parameter(tcb, 0, in_tuple_shape)\n",
"cntr = xops.GetTupleElement(intuple, 1)\n",
"test = xops.Gt(cntr, xops.Constant(tcb, np.int32(0)))\n",
"test_computation = tcb.Build()\n",
"\n",
"# While computation:\n",
"wcb = xc.XlaBuilder(\"whilecomp\")\n",
"x = xops.Parameter(wcb, 0, in_shape_0)\n",
"cntr = xops.Parameter(wcb, 1, in_shape_1)\n",
"tuple_init_carry = xops.Tuple(wcb, [x, cntr])\n",
"xops.While(test_computation, body_computation, tuple_init_carry)\n",
"while_computation = wcb.Build()\n",
"\n",
"# Now compile and execute:\n",
"cpu_backend = xc.get_local_backend(\"cpu\")\n",
"\n",
"# compile graph based on shape\n",
"compiled_computation = cpu_backend.compile(while_computation)\n",
"\n",
"# Set up initial state\n",
"X = np.zeros(matrix_shape, dtype=np.int32)\n",
"X[0,0, 5:8, 5:8] = np.array([[0,1,0],[0,0,1],[1,1,1]])\n",
"\n",
"# Evolve\n",
"movie = np.zeros((Niter,)+matrix_shape[-2:], dtype=np.int32)\n",
"for it in range(Niter):\n",
" itr = np.array(it, dtype=np.int32)\n",
" device_input_x = cpu_backend.buffer_from_pyval(X)\n",
" device_input_it = cpu_backend.buffer_from_pyval(itr)\n",
" device_out = compiled_computation.execute([device_input_x, device_input_it])\n",
" movie[it] = device_out[0].to_py()[0,0]\n",
"\n",
"# Plot\n",
"fig = plt.figure(figsize=(15,2))\n",
"gs = gridspec.GridSpec(1,Niter)\n",
"for i in range(Niter):\n",
" ax1 = plt.subplot(gs[:, i])\n",
" ax1.axis('off')\n",
" ax1.imshow(movie[i])\n",
"plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, hspace=0.0, wspace=0.05)"
]
},
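{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check on the update rule above, here is a minimal NumPy sketch of one Game of Life step. It is an illustration, not part of the XLA client API: the helper name `life_step` is hypothetical, and it assumes a 2D 0/1 board with zero padding at the edges, matching the `(1, 1)` padding passed to `ConvGeneralDilated`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def life_step(board):\n",
"    # One Game of Life step on a 2D 0/1 integer array (hypothetical helper).\n",
"    h, w = board.shape\n",
"    padded = np.pad(board, 1)\n",
"    # 3x3 box sum including the cell itself, like the all-ones conv kernel.\n",
"    total = sum(padded[i:i + h, j:j + w] for i in range(3) for j in range(3))\n",
"    alive = board == 1\n",
"    born = ~alive & (total == 3)\n",
"    survives = alive & ((total == 3) | (total == 4))\n",
"    return (born | survives).astype(np.int32)\n",
"\n",
"glider = np.array([[0, 1, 0], [0, 0, 1], [1, 1, 1]], dtype=np.int32)\n",
"board = np.zeros((20, 20), dtype=np.int32)\n",
"board[5:8, 5:8] = glider\n",
"for _ in range(4):\n",
"    board = life_step(board)\n",
"\n",
"# A glider keeps its 5 live cells and moves one cell diagonally every 4 steps.\n",
"assert board.sum() == 5\n",
"assert np.array_equal(board[6:9, 6:9], glider)\n",
"```\n",
"\n",
"Because the box sum includes the center cell, the survival test uses 3 and 4 instead of the textbook 2 and 3, which is the same offset used in the XLA version."
]
},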
{
"cell_type": "markdown",
"metadata": {
"id": "9-0PJlqv237S"
},
"source": [
"## Fin \n",
"\n",
"There's much more to XLA, but this hopefully highlights how easy it is to play with via the python client!"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "XLA in Python.ipnb",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}