mirror of
https://github.com/ROCm/jax.git
synced 2025-04-27 17:36:06 +00:00
853 lines
47 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
"name": "XLA in Python.ipynb",
|
|
"version": "0.3.2",
|
|
"provenance": [],
|
|
"collapsed_sections": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.2"
|
|
}
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "sAgUgR5Mzzz2"
|
|
},
|
|
"source": [
|
|
"# XLA in Python\n",
|
|
"\n",
|
|
"<img style=\"height:100px;\" src=\"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/compiler/xla/g3doc/images/xlalogo.png\"> <img style=\"height:100px;\" src=\"https://upload.wikimedia.org/wikipedia/commons/c/c3/Python-logo-notext.svg\">\n",
|
|
"\n",
|
|
"_Anselm Levskaya_ \n",
|
|
"\n",
"XLA is the compiler that JAX uses, and the compiler that TF uses for TPUs and will soon use for all devices, so it's worth some study. However, it's not exactly easy to play with XLA computations directly using the raw C++ interface. JAX exposes the underlying XLA computation builder API through a Python wrapper, which makes the XLA compute model accessible for experimenting and prototyping.\n",
|
|
"\n",
"XLA computations are built as computation graphs in HLO IR, which is then lowered to device-specific LLO (CPU, GPU, TPU, etc.).\n",
|
|
"\n",
"As end users, we interact with the computational primitives offered to us by the HLO spec."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "EZK5RseuvZkr"
|
|
},
|
|
"source": [
|
|
"## References \n",
|
|
"\n",
"__xla__: the doc that defines what's in HLO. Note that the doc is incomplete and omits some ops.\n",
|
|
"\n",
|
|
"https://www.tensorflow.org/xla/operation_semantics\n",
|
|
"\n",
"More details on the ops can be found in the source code:\n",
|
|
"\n",
|
|
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
|
|
"\n",
"__python xla client__: this is the XLA Python client for JAX, and what we're using here.\n",
|
|
"\n",
|
|
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n",
|
|
"\n",
|
|
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n",
|
|
"\n",
"__jax__: you can see how JAX interacts with the XLA compute layer for execution and JITing in these files.\n",
|
|
"\n",
|
|
"https://github.com/google/jax/blob/master/jax/lax.py\n",
|
|
"\n",
|
|
"https://github.com/google/jax/blob/master/jax/lib/xla_bridge.py\n",
|
|
"\n",
|
|
"https://github.com/google/jax/blob/master/jax/interpreters/xla.py"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "3XR2NGmrzBGe"
|
|
},
|
|
"source": [
|
|
"## Colab Setup and Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "Ogo2SBd3u18P",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"# We only need to import JAX's xla_client, not all of JAX.\n",
|
|
"from jaxlib import xla_client\n",
|
|
"\n",
|
|
"# Plotting\n",
|
|
"import matplotlib as mpl\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"from matplotlib import gridspec\n",
|
|
"from matplotlib import rcParams\n",
|
|
"rcParams['image.interpolation'] = 'nearest'\n",
|
|
"rcParams['image.cmap'] = 'viridis'\n",
|
|
"rcParams['axes.grid'] = False"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "0cf7swaobc5l"
|
|
},
|
|
"source": [
|
|
"## Convenience Functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "5I50k0rhbg6W",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# Here we borrow convenience functions from JAX to convert numpy shape/dtypes\n",
|
|
"# to XLA appropriate shape/dtypes\n",
"def canonicalize_dtype(dtype):\n",
"  \"\"\"We restrict ourselves to 32bit types for this demo.\"\"\"\n",
"  _dtype_to_32bit_dtype = {\n",
"      str(np.dtype('int64')): np.dtype('int32'),\n",
"      str(np.dtype('uint64')): np.dtype('uint32'),\n",
"      str(np.dtype('float64')): np.dtype('float32'),\n",
"      str(np.dtype('complex128')): np.dtype('complex64'),\n",
"  }\n",
"  dtype = np.dtype(dtype)\n",
"  return _dtype_to_32bit_dtype.get(str(dtype), dtype)\n",
|
|
"\n",
"def shape_of(value):\n",
"  \"\"\"Given a Python or XLA value, return its canonicalized XLA Shape.\"\"\"\n",
"  if hasattr(value, 'shape') and hasattr(value, 'dtype'):\n",
"    return xla_client.Shape.array_shape(canonicalize_dtype(value.dtype),\n",
"                                        value.shape)\n",
"  elif np.isscalar(value):\n",
"    return shape_of(np.asarray(value))\n",
"  elif isinstance(value, (tuple, list)):\n",
"    return xla_client.Shape.tuple_shape(tuple(shape_of(elt) for elt in value))\n",
"  else:\n",
"    raise TypeError('Unexpected type: {}'.format(type(value)))\n",
|
|
"\n",
"def to_xla_type(dtype):\n",
"  \"Convert to integer xla type, for use with ConvertElementType, etc.\"\n",
"  if isinstance(dtype, str):\n",
"    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype]\n",
"  elif isinstance(dtype, type):\n",
"    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[np.dtype(dtype).name]\n",
"  elif isinstance(dtype, np.dtype):\n",
"    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype.name]\n",
"  else:\n",
"    raise TypeError('Unexpected type: {}'.format(type(dtype)))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "odmjXyhMuNJ5"
|
|
},
|
|
"source": [
|
|
"## Simple Computations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "UYUtxVzMYIiv",
|
|
"outputId": "bd8aa18e-26d9-4df4-ebc3-20026119de17",
|
|
"colab": {
|
|
"height": 33
|
|
}
|
|
},
|
|
"source": [
|
|
"# make a computation builder\n",
|
|
"c = xla_client.ComputationBuilder(\"simple_scalar\")\n",
|
|
"\n",
|
|
"# define a parameter shape and parameter\n",
|
|
"param_shape = xla_client.Shape.array_shape(np.dtype(np.float32), ())\n",
|
|
"x = c.ParameterWithShape(param_shape)\n",
|
|
"\n",
|
|
"# define computation graph\n",
|
|
"y = c.Sin(x)\n",
|
|
"\n",
|
|
"# build computation graph\n",
|
|
"# Keep in mind that incorrectly constructed graphs can cause \n",
|
|
"# your notebook kernel to crash!\n",
|
|
"computation = c.Build()\n",
|
|
"\n",
|
|
"# compile graph based on shape\n",
|
|
"compiled_computation = computation.Compile([param_shape,])\n",
|
|
"\n",
|
|
"# define a host variable with above parameter shape\n",
|
|
"host_input = np.array(3.0, dtype=np.float32)\n",
|
|
"\n",
|
|
"# place host variable on device and execute\n",
|
|
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
|
|
"device_out = compiled_computation.Execute([device_input ,])\n",
|
|
"\n",
"# retrieve the result\n",
|
|
"device_out.to_py()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"text/plain": [
|
|
"array(0.14112, dtype=float32)"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 3
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "rIA-IVMVvQs2",
|
|
"outputId": "ce88ec6f-d2ea-4ec2-80b4-ddd1afd36957",
|
|
"colab": {
|
|
"height": 33
|
|
}
|
|
},
|
|
"source": [
|
|
"# same as above with vector type:\n",
|
|
"\n",
|
|
"c = xla_client.ComputationBuilder(\"simple_vector\")\n",
|
|
"param_shape = xla_client.Shape.array_shape(np.dtype(np.float32), (3,))\n",
|
|
"x = c.ParameterWithShape(param_shape)\n",
|
|
"\n",
|
|
"# can also use this function to define a shape from an example:\n",
|
|
"#x = c.ParameterFromNumpy(np.array([0.0, 0.0, 0.0], dtype=np.float32))\n",
|
|
"\n",
|
|
"# which is the same as using our convenience function above:\n",
"#x = c.ParameterWithShape(shape_of(np.array([0.0, 0.0, 0.0],\n",
"#                                            dtype=np.float32)))\n",
|
|
"\n",
|
|
"# chain steps by reference:\n",
|
|
"y = c.Sin(x)\n",
|
|
"z = c.Abs(y)\n",
|
|
"computation = c.Build()\n",
|
|
"compiled_computation = computation.Compile([param_shape,])\n",
|
|
"\n",
|
|
"host_input = np.array([3.0, 4.0, 5.0], dtype=np.float32)\n",
|
|
"\n",
|
|
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
|
|
"device_out = compiled_computation.Execute([device_input ,])\n",
|
|
"\n",
"# retrieve the result\n",
|
|
"device_out.to_py()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"text/plain": [
|
|
"array([0.14112 , 0.7568025, 0.9589243], dtype=float32)"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 4
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "F8kWlLaVuQ1b"
|
|
},
|
|
"source": [
|
|
"## Simple While Loop"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "MDQP1qW515Ao",
|
|
"outputId": "4da894b5-2b0e-455e-a720-3bdadc57d164",
|
|
"colab": {
|
|
"height": 33
|
|
}
|
|
},
|
|
"source": [
|
|
"# trivial while loop, decrement until 0\n",
"# x = 5\n",
"# while x > 0:\n",
"#   x = x - 1\n",
|
|
"#\n",
|
|
"in_shape = shape_of(5)\n",
|
|
"\n",
|
|
"# body computation:\n",
|
|
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
|
|
"x = bcb.ParameterWithShape(in_shape)\n",
|
|
"const1 = bcb.Constant(np.int32(1))\n",
|
|
"y = bcb.Sub(x, const1)\n",
|
|
"body_computation = bcb.Build()\n",
|
|
"\n",
|
|
"# test computation:\n",
|
|
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
|
"x = tcb.ParameterWithShape(in_shape)\n",
|
|
"const0 = tcb.Constant(np.int32(0))\n",
|
|
"y = tcb.Gt(x, const0)\n",
|
|
"test_computation = tcb.Build()\n",
|
|
"\n",
|
|
"# while computation:\n",
|
|
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
|
|
"x = wcb.ParameterWithShape(in_shape)\n",
|
|
"wcb.While(test_computation, body_computation, x)\n",
|
|
"while_computation = wcb.Build()\n",
|
|
"\n",
|
|
"# Now compile and execute:\n",
|
|
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
|
"\n",
|
|
"host_input = np.array(5, dtype=np.int32)\n",
|
|
"\n",
|
|
"device_input = xla_client.LocalBuffer.from_pyval(host_input)\n",
|
|
"device_out = compiled_computation.Execute([device_input ,])\n",
|
|
"\n",
"# retrieve the result\n",
|
|
"device_out.to_py()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"text/plain": [
|
|
"array(0, dtype=int32)"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 5
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "7UOnXlY8slI6"
|
|
},
|
|
"source": [
"## While Loops with Tuples - Newton's Method for sqrt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "HEWz-vzd6QPR",
|
|
"outputId": "6ef10855-232d-4701-a442-0e2667b2fd97",
|
|
"colab": {
|
|
"height": 33
|
|
}
|
|
},
|
|
"source": [
|
|
"Xsqr = 2\n",
|
|
"guess = 1.0\n",
|
|
"converged_delta = 0.001\n",
|
|
"maxit = 1000\n",
|
|
"\n",
|
|
"in_shape = shape_of((1.0, 1.0, 1))\n",
|
|
"\n",
|
|
"# body computation:\n",
|
|
"# x_{i+1} = x_{i} - (x_i**2 - y) / (2 * x_i)\n",
|
|
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
|
|
"intuple = bcb.ParameterWithShape(in_shape)\n",
|
|
"y = bcb.GetTupleElement(intuple, 0)\n",
|
|
"x = bcb.GetTupleElement(intuple, 1)\n",
|
|
"guard_cntr = bcb.GetTupleElement(intuple, 2)\n",
|
|
"new_x = bcb.Sub(x, bcb.Div(bcb.Sub(bcb.Mul(x, x), y), bcb.Add(x, x)))\n",
|
|
"result = bcb.Tuple(y, new_x, bcb.Sub(guard_cntr, bcb.Constant(np.int32(1))))\n",
|
|
"body_computation = bcb.Build()\n",
|
|
"\n",
|
|
"# test computation -- convergence and max iteration test\n",
|
|
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
|
"intuple = tcb.ParameterWithShape(in_shape)\n",
|
|
"y = tcb.GetTupleElement(intuple, 0)\n",
|
|
"x = tcb.GetTupleElement(intuple, 1)\n",
|
|
"guard_cntr = tcb.GetTupleElement(intuple, 2)\n",
|
|
"criterion = tcb.Abs(tcb.Sub(tcb.Mul(x, x), y))\n",
"# keep looping while not yet converged and iterations remain\n",
"test = tcb.And(tcb.Gt(criterion, tcb.Constant(np.float32(converged_delta))),\n",
"               tcb.Gt(guard_cntr, tcb.Constant(np.int32(0))))\n",
|
|
"test_computation = tcb.Build()\n",
|
|
"\n",
|
|
"# while computation:\n",
|
|
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
|
|
"intuple = wcb.ParameterWithShape(in_shape)\n",
|
|
"wcb.While(test_computation, body_computation, intuple)\n",
|
|
"while_computation = wcb.Build()\n",
|
|
"\n",
|
|
"# Now compile and execute:\n",
|
|
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
|
"\n",
|
|
"y = np.array(Xsqr, dtype=np.float32)\n",
|
|
"x = np.array(guess, dtype=np.float32)\n",
|
|
"maxit = np.array(maxit, dtype=np.int32)\n",
|
|
"\n",
|
|
"device_input = xla_client.LocalBuffer.from_pyval((y, x, maxit))\n",
|
|
"device_out = compiled_computation.Execute([device_input ,])\n",
|
|
"\n",
|
|
"host_out = device_out.to_py()\n",
|
|
"print(\"square root of {y} is {x}\".format(y=y, x=host_out[1]))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"square root of 2.0 is 1.4142156839370728\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "yETVIzTInFYr"
|
|
},
|
|
"source": [
"## Calculate Symmetric Eigenvalues"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "AiyR1e2NubKa"
|
|
},
|
|
"source": [
"Let's exploit the XLA QR implementation to compute the eigenvalues of symmetric matrices.\n",
"\n",
"This is the naive QR algorithm: it has no acceleration for converging closely spaced eigenvalues and applies no permutation to sort the eigenvalues by magnitude."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "wjxDPbqCcuXT",
|
|
"outputId": "9683e40b-3c5f-4f3e-c971-0613b182c68c",
|
|
"colab": {
|
|
"height": 455
|
|
}
|
|
},
|
|
"source": [
|
|
"Niter = 200\n",
|
|
"matrix_shape = (10, 10)\n",
|
|
"in_shape = shape_of(\n",
|
|
" (np.zeros(matrix_shape, dtype=np.float32), 1)\n",
|
|
")\n",
"# NB: in_shape is the same as the manually constructed:\n",
"# xla_client.Shape.tuple_shape(\n",
"#   (xla_client.Shape.array_shape(np.dtype(np.float32), matrix_shape),\n",
"#    xla_client.Shape.array_shape(np.dtype(np.int32), ()))\n",
"# )\n",
|
|
"\n",
|
|
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",
|
|
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
|
|
"intuple = bcb.ParameterWithShape(in_shape)\n",
|
|
"x = bcb.GetTupleElement(intuple, 0)\n",
|
|
"cntr = bcb.GetTupleElement(intuple, 1)\n",
|
|
"QR = bcb.QR(x)\n",
|
|
"Q = bcb.GetTupleElement(QR, 0)\n",
|
|
"R = bcb.GetTupleElement(QR, 1)\n",
|
|
"RQ = bcb.Dot(R, Q)\n",
|
|
"bcb.Tuple(RQ, bcb.Sub(cntr, bcb.Constant(np.int32(1))))\n",
|
|
"body_computation = bcb.Build()\n",
|
|
"\n",
|
|
"# test computation -- just a for loop condition\n",
|
|
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
|
"intuple = tcb.ParameterWithShape(in_shape)\n",
|
|
"cntr = tcb.GetTupleElement(intuple, 1)\n",
|
|
"test = tcb.Gt(cntr, tcb.Constant(np.int32(0)))\n",
|
|
"test_computation = tcb.Build()\n",
|
|
"\n",
|
|
"# while computation:\n",
|
|
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
|
|
"intuple = wcb.ParameterWithShape(in_shape)\n",
|
|
"wcb.While(test_computation, body_computation, intuple)\n",
|
|
"while_computation = wcb.Build()\n",
|
|
"\n",
|
|
"# Now compile and execute:\n",
|
|
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
|
"\n",
|
|
"X = np.random.random(matrix_shape).astype(np.float32)\n",
|
|
"X = (X + X.T) / 2.0\n",
|
|
"it = np.array(Niter, dtype=np.int32)\n",
|
|
"\n",
|
|
"device_in = xla_client.LocalBuffer.from_pyval((X, it))\n",
|
|
"device_out = compiled_computation.Execute([device_in,])\n",
|
|
"\n",
|
|
"host_out = device_out.to_py()\n",
|
|
"eigh_vals = host_out[0].diagonal()\n",
|
|
"\n",
|
|
"plt.title('D')\n",
|
|
"plt.imshow(host_out[0])\n",
|
|
"print('sorted eigenvalues')\n",
|
|
"print(np.sort(eigh_vals))\n",
|
|
"print('sorted eigenvalues from numpy')\n",
|
|
"print(np.sort(np.linalg.eigh(X)[0]))\n",
|
|
"print('sorted error') \n",
|
|
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"sorted eigenvalues\n",
|
|
"[-1.190547 -0.91282177 -0.32339668 -0.14050038 -0.09441247 0.08265306\n",
|
|
" 0.49015656 0.731502 1.0677357 5.3513203 ]\n",
|
|
"sorted eigenvalues from numpy\n",
|
|
"[-1.1905469 -0.9128221 -0.32339665 -0.14050038 -0.09441243 0.08265309\n",
|
|
" 0.49015662 0.7315014 1.0677353 5.351319 ]\n",
|
|
"sorted error\n",
|
|
"[-1.1920929e-07 3.5762787e-07 -2.9802322e-08 0.0000000e+00\n",
|
|
" -3.7252903e-08 -2.9802322e-08 -5.9604645e-08 5.9604645e-07\n",
|
|
" 3.5762787e-07 1.4305115e-06]\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": [],
|
|
"needs_background": "light"
|
|
}
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "FpggTihknAOw"
|
|
},
|
|
"source": [
"## Calculate Full Symmetric Eigensystem"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Qos4ankYuj1T"
|
|
},
|
|
"source": [
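"We can also calculate the eigenbasis by accumulating the Qs: each step gives X_{i+1} = R_i Q_i = Q_i^T X_i Q_i, so the running product O = Q_0 Q_1 ... Q_i converges to the matrix of eigenvectors."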
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "Kp3A-aAiZk0g",
|
|
"outputId": "ebdc1ecc-c9e1-4e95-b989-9645f8648ee0",
|
|
"colab": {
|
|
"height": 1000
|
|
}
|
|
},
|
|
"source": [
|
|
"Niter = 100\n",
|
|
"matrix_shape = (10, 10)\n",
|
|
"in_shape = shape_of(\n",
|
|
" (np.zeros(matrix_shape, dtype=np.float32), \n",
|
|
" np.eye(matrix_shape[0]),\n",
|
|
" 1)\n",
|
|
")\n",
|
|
"\n",
|
|
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",
|
|
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
|
|
"intuple = bcb.ParameterWithShape(in_shape)\n",
|
|
"X = bcb.GetTupleElement(intuple, 0)\n",
|
|
"O = bcb.GetTupleElement(intuple, 1)\n",
|
|
"cntr = bcb.GetTupleElement(intuple, 2)\n",
|
|
"QR = bcb.QR(X)\n",
|
|
"Q = bcb.GetTupleElement(QR, 0)\n",
|
|
"R = bcb.GetTupleElement(QR, 1)\n",
|
|
"RQ = bcb.Dot(R, Q)\n",
|
|
"Onew = bcb.Dot(O, Q)\n",
|
|
"bcb.Tuple(RQ, Onew, bcb.Sub(cntr, bcb.Constant(np.int32(1))))\n",
|
|
"body_computation = bcb.Build()\n",
|
|
"\n",
|
|
"# test computation -- just a for loop condition\n",
|
|
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
|
"intuple = tcb.ParameterWithShape(in_shape)\n",
|
|
"cntr = tcb.GetTupleElement(intuple, 2)\n",
|
|
"test = tcb.Gt(cntr, tcb.Constant(np.int32(0)))\n",
|
|
"test_computation = tcb.Build()\n",
|
|
"\n",
|
|
"# while computation:\n",
|
|
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
|
|
"intuple = wcb.ParameterWithShape(in_shape)\n",
|
|
"wcb.While(test_computation, body_computation, intuple)\n",
|
|
"while_computation = wcb.Build()\n",
|
|
"\n",
|
|
"# Now compile and execute:\n",
|
|
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
|
"\n",
|
|
"X = np.random.random(matrix_shape).astype(np.float32)\n",
|
|
"X = (X + X.T) / 2.0\n",
|
|
"Omat = np.eye(matrix_shape[0], dtype=np.float32)\n",
|
|
"it = np.array(Niter, dtype=np.int32)\n",
|
|
"\n",
|
|
"device_in = xla_client.LocalBuffer.from_pyval((X, Omat, it))\n",
|
|
"device_out = compiled_computation.Execute([device_in,])\n",
|
|
"\n",
|
|
"host_out = device_out.to_py()\n",
|
|
"eigh_vals = host_out[0].diagonal()\n",
|
|
"eigh_mat = host_out[1]\n",
|
|
"\n",
|
|
"plt.title('D')\n",
|
|
"plt.imshow(host_out[0])\n",
|
|
"plt.figure()\n",
|
|
"plt.title('U')\n",
|
|
"plt.imshow(eigh_mat)\n",
|
|
"plt.figure()\n",
|
|
"plt.title('U^T A U')\n",
|
|
"plt.imshow(np.dot(np.dot(eigh_mat.T, X), eigh_mat))\n",
|
|
"print('sorted eigenvalues')\n",
|
|
"print(np.sort(eigh_vals))\n",
|
|
"print('sorted eigenvalues from numpy')\n",
|
|
"print(np.sort(np.linalg.eigh(X)[0]))\n",
|
|
"print('sorted error') \n",
|
|
"print(np.sort(eigh_vals) - np.sort(np.linalg.eigh(X)[0]))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"sorted eigenvalues\n",
|
|
"[-0.94551486 -0.63820213 -0.57944936 -0.28589356 -0.05510262 0.16862962\n",
|
|
" 0.4192178 0.4671099 0.88734317 4.990509 ]\n",
|
|
"sorted eigenvalues from numpy\n",
|
|
"[-0.9455159 -0.63820285 -0.5794492 -0.28589386 -0.05510259 0.16862962\n",
|
|
" 0.41921794 0.46710995 0.88734376 4.9905105 ]\n",
|
|
"sorted error\n",
|
|
"[ 1.0132790e-06 7.1525574e-07 -1.7881393e-07 2.9802322e-07\n",
|
|
" -2.9802322e-08 0.0000000e+00 -1.4901161e-07 -5.9604645e-08\n",
|
|
" -5.9604645e-07 -1.4305115e-06]\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": [],
|
|
"needs_background": "light"
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": [],
|
|
"needs_background": "light"
|
|
}
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAC71JREFUeJzt3X+sX3V9x/Hni7YEWhZUYFtsm7VujoW5bJgrQUn4A1iCw0my+QcECPMfsmwqOhODyxKW/bkYo384Z4O6PyDiUogjjAgm6h9mGaP8SKCtkqYwKD+0bhOlim3pe3/cr0nX2Xu/5Z7Tc+87z0dyk/s99/R835D77Dnf7z3301QVkno6Y+oBJI3HwKXGDFxqzMClxgxcaszApcYMvJkk5ye5PcllU8+i6Rn4GpCkkvzWCdv+NsmdJ2zbBPwr8IfA/UkuPu5rNyR5dfbxsyTHjnv86hLPnST7k+wZak6dPgbeRJINwD3AHuBy4M+B+5L8JkBV3VVV51TVOcB7gRd/8Xi27WQuB34VeFuSd437X6GhrZ96AK1ckgD/BDwD/EUt3p741SQ/ZzHyK6rq+2/w8DcD/wKcPfv8kQFG1mli4A3Mgr7hl2z/GvC1N3rcJBuBDwDXsRj4F5L8VVUdfqPH1OnlJbqW8ifAz4GHWHxtvwG4ZtKJdEoMfG14ncW4jrcBODLy894M/HNVHa2q11h8jX/zEvtPNadOwkv0teE5YBuw97ht24Gnx3rCJFuAK4BLkvzpbPNG4Kwk51fVD1fDnFqaZ/C14avA3yTZkuSMJFcBfwzsHPE5b2IxzAuBP5h9/DZwALh+Fc2pJRj42vB3wL8B3wH+B/h74IaqemrE57wZ+Ieqevn4D+AfOfll+hRzaglxwQepL8/gUmMGLjVm4FJjBi41NsrPwc9/y7ratvXE+x1W7ul95w1+TGkteu3wjzh85FCW22+UwLdt3cB/PLh18ONefe1Ngx9TWov+/akvzLWfl+hSYwYuNWbgUmMGLjVm4FJjBi41NlfgSa5O8r0k+5LcNvZQkoaxbOBJ1gGfY3ElzouA65NcNPZgklZunjP4JcC+qto/W2zvbuDacceSNIR5At8MPH/c4wOzbf9HkluS7Eqy6+B/vT7UfJJWYLA32apqR1UtVNXCBeetG+qwklZgnsBfAI6/sXzLbJukVW6ewB8B3p5ke5IzWVwE/75xx5I0hGV/m6yqjib5EPAgsA74UlXtHn0ySSs216+LVtUDwAMjzyJpYN7JJjVm4FJjBi41ZuBSYwYuNTbKootP7ztvlAUS65EnBz8mQN71e6McV5qaZ3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqbFRVlUdy1irn/73754z+DHfsvvVwY8pnSrP4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjywaeZGuSbyXZk2R3kltPx2CSVm6eG12OAh+vqseS/ArwaJJvVNWekWeTtELLnsGr6qWqemz2+U+AvcDmsQeTtHKn9Bo8yTbgYuDhX/K1W5LsSrLryNFDw0wnaUXmDjzJOcA9wEer6scnfr2qdlTVQlUtbFi/acgZJb1BcwWeZAOLcd9VVfeOO5KkoczzLnqALwJ7q+rT448kaSjznMEvA24CrkjyxOzjj0aeS9IAlv0xWVV9B8hpmEXSwLyTTWrMwKXGDFxqzMClxtbUootjGWOBxDP2vzD4MQGOvc27hDU/z+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmOuqjqSsVY/PePZl0c57rFtvz7KcTUtz+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY3MHnmRdkseT3D/mQJKGcypn8F
uBvWMNIml4cwWeZAtwDXDHuONIGtK8Z/DPAJ8Ajp1shyS3JNmVZNeRo4cGGU7SyiwbeJL3AT+oqkeX2q+qdlTVQlUtbFi/abABJb1x85zBLwPen+RZ4G7giiR3jjqVpEEsG3hVfbKqtlTVNuA64JtVdePok0laMX8OLjV2Sr8PXlXfBr49yiSSBucZXGrMwKXGDFxqzMClxgxcasxVVdeYsVY//dmvnT3Kcc/+/s9GOa7m4xlcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMVVUFjLf66WsXnDXKcc86+Noox+3GM7jUmIFLjRm41JiBS40ZuNSYgUuNGbjU2FyBJ3lTkp1Jvptkb5J3jz2YpJWb90aXzwJfr6oPJDkT2DjiTJIGsmzgSc4FLgf+DKCqDgOHxx1L0hDmuUTfDhwEvpzk8SR3JNl04k5JbkmyK8muI0cPDT6opFM3T+DrgXcCn6+qi4FDwG0n7lRVO6pqoaoWNqz/f/1LmsA8gR8ADlTVw7PHO1kMXtIqt2zgVfUy8HySC2ebrgT2jDqVpEHM+y76h4G7Zu+g7wc+ON5IkoYyV+BV9QSwMPIskgbmnWxSYwYuNWbgUmMGLjVm4FJjrqqqUY21+ukZr/x08GMeO7ff71B5BpcaM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMRdd1Jo0xgKJr28cJ4d1Pz06ynHn4RlcaszApcYMXGrMwKXGDFxqzMClxgxcamyuwJN8LMnuJE8l+UqSs8YeTNLKLRt4ks3AR4CFqnoHsA64buzBJK3cvJfo64Gzk6wHNgIvjjeSpKEsG3hVvQB8CngOeAl4paoeOnG/JLck2ZVk15Gjh4afVNIpm+cS/c3AtcB24K3ApiQ3nrhfVe2oqoWqWtiwftPwk0o6ZfNcol8FPFNVB6vqCHAv8J5xx5I0hHkCfw64NMnGJAGuBPaOO5akIczzGvxhYCfwGPDk7M/sGHkuSQOY6xdgq+p24PaRZ5E0MO9kkxozcKkxA5caM3CpMQOXGnNVVWlmtNVPjx0b4aA1116ewaXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxlI13+qMp3TQ5CDwn3Psej7ww8EHGM9amnctzQpra97VMOtvVNUFy+00SuDzSrKrqhYmG+AUraV519KssLbmXUuzeokuNWbgUmNTB75j4uc/VWtp3rU0K6ytedfMrJO+Bpc0rqnP4JJGZOBSY5MFnuTqJN9Lsi/JbVPNsZwkW5N8K8meJLuT3Dr1TPNIsi7J40nun3qWpSR5U5KdSb6bZG+Sd08901KSfGz2ffBUkq8kOWvqmZYySeBJ1gGfA94LXARcn+SiKWaZw1Hg41V1EXAp8JereNbj3QrsnXqIOXwW+HpV/Q7w+6zimZNsBj4CLFTVO4B1wHXTTrW0qc7glwD7qmp/VR0G7gaunWiWJVXVS1X12Ozzn7D4Dbh52qmWlmQLcA1wx9SzLCXJucDlwBcBqupwVf1o2qmWtR44O8l6YCPw4sTzLGmqwDcDzx/3+ACrPBqAJNuAi4GHp51kWZ8BPgGM8S/PD2k7cBD48uzlxB1JNk091MlU1QvAp4DngJeAV6rqoWmnWppvss0pyTnAPcBHq+rHU89zMkneB/ygqh6depY5rAfeCXy+qi4GDgGr+f2YN7N4pbkdeCuwKcmN0061tKkCfwHYetzjLbNtq1KSDSzGfVdV3Tv1PMu4DHh/kmdZfOlzRZI7px3ppA4AB6rqF1dEO1kMfrW6Cnimqg5W1RHgXuA9E8+0pKkCfwR4e5LtSc5k8Y2K+yaaZUlJwuJrxL
1V9emp51lOVX2yqrZU1TYW/79+s6pW5Vmmql4Gnk9y4WzTlcCeCUdaznPApUk2zr4vrmQVvykIi5dIp11VHU3yIeBBFt+J/FJV7Z5iljlcBtwEPJnkidm2v66qByacqZMPA3fN/qLfD3xw4nlOqqoeTrITeIzFn648ziq/bdVbVaXGfJNNaszApcYMXGrMwKXGDFxqzMClxgxcaux/AWPSvgMr5OSGAAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": [],
|
|
"needs_background": "light"
|
|
}
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Ee3LMzOvlCuK"
|
|
},
|
|
"source": [
|
|
"## Convolutions\n",
|
|
"\n",
|
|
"I keep hearing from the AGI folks that we can use convolutions to build artificial life. Let's try it out."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "J8QkirDalBse",
|
|
"outputId": "73c53980-8dbd-497b-fe56-7e606a29c19f",
|
|
"colab": {
|
|
"height": 132
|
|
}
|
|
},
|
|
"source": [
|
|
"Niter=13\n",
|
|
"matrix_shape = (1,1, 20, 20)\n",
|
|
"in_shape = shape_of(\n",
|
|
" (np.zeros(matrix_shape, dtype=np.int32), 1)\n",
|
|
")\n",
|
|
"\n",
|
|
"# Body computation -- Conway Update\n",
|
|
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
|
|
"intuple = bcb.ParameterWithShape(in_shape)\n",
|
|
"x = bcb.GetTupleElement(intuple, 0)\n",
|
|
"cntr = bcb.GetTupleElement(intuple, 1)\n",
|
|
"# convs require floating-point type\n",
|
|
"xf = bcb.ConvertElementType(x, to_xla_type('float32'))\n",
|
|
"stamp = bcb.Constant(np.ones((1,1,3,3), dtype=np.float32))\n",
|
|
"convd = bcb.Conv(xf, stamp, np.array([1, 1]), xla_client.PaddingType.SAME)\n",
|
|
"# logic ops require integer types\n",
|
|
"convd = bcb.ConvertElementType(convd, to_xla_type('int32'))\n",
|
|
"bool_x = bcb.Eq(x, bcb.ConstantS32Scalar(1))\n",
|
|
"# core update rule\n",
|
|
"res = bcb.Or(\n",
|
|
" # birth rule\n",
|
|
" bcb.And(bcb.Not(bool_x), bcb.Eq(convd, bcb.ConstantS32Scalar(3))),\n",
|
|
" # survival rule\n",
|
|
" bcb.And(bool_x, bcb.Or(\n",
|
|
" # thresholds are one more than the usual 2 and 3 because the 3x3 conv sum counts the cell itself\n",
|
|
" bcb.Eq(convd, bcb.ConstantS32Scalar(4)),\n",
|
|
" bcb.Eq(convd, bcb.ConstantS32Scalar(3)))\n",
|
|
" )\n",
|
|
")\n",
|
|
"# Convert output back to int32 so the loop-carried tuple keeps the same type\n",
|
|
"int_res = bcb.ConvertElementType(res, to_xla_type('int32'))\n",
|
|
"bcb.Tuple(int_res, bcb.Sub(cntr, bcb.ConstantS32Scalar(1)))\n",
|
|
"body_computation = bcb.Build()\n",
|
|
"\n",
|
|
"# Test computation -- just a for loop condition\n",
|
|
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
|
"intuple = tcb.ParameterWithShape(in_shape)\n",
|
|
"cntr = tcb.GetTupleElement(intuple, 1)\n",
|
|
"test = tcb.Gt(cntr, tcb.ConstantS32Scalar(0))\n",
|
|
"test_computation = tcb.Build()\n",
|
|
"\n",
|
|
"# While computation:\n",
|
|
"wcb = xla_client.ComputationBuilder(\"whilecomp\")\n",
|
|
"intuple = wcb.ParameterWithShape(in_shape)\n",
|
|
"wcb.While(test_computation, body_computation, intuple)\n",
|
|
"while_computation = wcb.Build()\n",
|
|
"\n",
|
|
"# Now compile and execute:\n",
|
|
"compiled_computation = while_computation.Compile([in_shape,])\n",
|
|
"\n",
|
|
"# Set up initial state\n",
|
|
"X = np.zeros(matrix_shape, dtype=np.int32)\n",
|
|
"X[0,0, 5:8, 5:8] = np.array([[0,1,0],[0,0,1],[1,1,1]])\n",
|
|
"\n",
|
|
"# Evolve\n",
|
|
"movie = np.zeros((Niter,)+matrix_shape[-2:], dtype=np.int32)\n",
|
|
"for it in range(Niter):\n",
|
|
" itr = np.array(it, dtype=np.int32)\n",
|
|
" device_in = xla_client.LocalBuffer.from_pyval((X, itr))\n",
|
|
" device_out = compiled_computation.Execute([device_in,])\n",
|
|
" movie[it] = device_out.to_py()[0][0,0]\n",
|
|
"\n",
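"# Cross-check, a minimal sketch that is not part of the original notebook:\n",
"# replay the same rule in plain NumPy and compare against each XLA frame.\n",
"# `np_life_step` and `ref` are hypothetical names introduced for illustration.\n",
"def np_life_step(grid):\n",
"  h, w = grid.shape\n",
"  padded = np.pad(grid, 1, mode='constant')  # zero border, mirroring SAME padding\n",
"  # 3x3 box sum, which includes the center cell like the conv above\n",
"  total = sum(padded[i:i + h, j:j + w] for i in range(3) for j in range(3))\n",
"  alive = grid == 1\n",
"  born = ~alive & (total == 3)\n",
"  stays = alive & ((total == 3) | (total == 4))\n",
"  return (born | stays).astype(np.int32)\n",
"\n",
"ref = X[0, 0]\n",
"for it in range(Niter):\n",
"  # movie[it] holds the board after `it` updates\n",
"  assert np.array_equal(movie[it], ref)\n",
"  ref = np_life_step(ref)\n",
"\n",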
|
|
"# Plot\n",
|
|
"fig = plt.figure(figsize=(15,2))\n",
|
|
"gs = gridspec.GridSpec(1,Niter)\n",
|
|
"for i in range(Niter):\n",
|
|
" ax1 = plt.subplot(gs[:, i])\n",
|
|
" ax1.axis('off')\n",
|
|
" ax1.imshow(movie[i])\n",
|
|
"plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, hspace=0.0, wspace=0.05)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": [
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABFoAAABwCAYAAAAuRhTQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABBBJREFUeJzt3MFt20AQQFHKUBWpwk0EqSBVpoIgTbiKlBH6FMDRxZL4lxI3753sk4jB8vKxnNO6rgsAAAAA2708+gEAAAAAZiG0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIic9/yxry/f1z1/bza//vw4/f3bLLf5OMtlMc+tnM2OWXa85y1ns+NstpzNjll2vOctZ7Njlp3L9/wjN1oAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACInB/9AKWfv9/++f/bl9cHPQkAAADwP3KjBQAAACAitAAAAABEhBYAAACAyFQ7Wi53stjZcr3PZmWWLfO8n9mNY7YAALCdGy0AAAAAEaEFAAAAICK0AAAAAEQOvaPlcp8A9/tsJwu3setinFvPqtlfz26mbey62o9ZbmN+45gtAMviRgsAAABARmgBAAAAiAgtAAAAAJFD72jx3et+zPo2dt6MYydLx7lsee879lyMZdfVOHYz3c+eq32Z5/3MbpyZZutGCwAAAEBEaAEAAACICC0AAAAAkUPvaGGcI38P94zMs2OWHbPcl3lfz76bsexk6TibHe99a6ZdF8/GnqtxZtrN5EYLAAAAQERoAQAAAIgILQAAAAARO1oAYIAjfUf87MyyZZ4ds9yPWd/Gzptx7GTpzHwu3WgBAAAAiAgtAAAAABGhBQAAACBiRwsAAPBU7LlomWfHLDszz9KNFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACICC0AAAAAEaEFAAAAICK0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAip3VdH/0MAAAAAFNwowUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACICC0AAAAAEaEFAAAAICK0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIi8A8vjuqwsx0TPAAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 1080x144 with 13 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": [],
|
|
"needs_background": "light"
|
|
}
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "9-0PJlqv237S"
|
|
},
|
|
"source": [
|
|
"## Fin \n",
|
|
"\n",
|
|
"There's much more to XLA, but this hopefully highlights how easy it is to play with via the Python client!"
|
|
]
|
|
}
|
|
]
|
|
}
|