mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
444 lines
14 KiB
Plaintext
444 lines
14 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "Lorenz ODE Solver",
|
|
"provenance": [],
|
|
"collapsed_sections": []
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"accelerator": "TPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ntE40GybB_cn",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Lorenz ODE Solver in JAX\n",
|
|
"Alex Alemi"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "fyoHa_blbI71",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "GAFiL4V_kPE8",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"import io\n",
|
|
"import os\n",
|
|
"from functools import partial\n",
|
|
"import numpy as np\n",
|
|
"import jax\n",
|
|
"import jax.numpy as jnp\n",
|
|
"from jax import vmap, jit, grad, ops, lax, config\n",
|
|
"from jax import random as jr\n",
|
|
"\n",
|
|
"import matplotlib as mpl\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import matplotlib.cm as cm\n",
|
|
"from IPython.display import display_png\n",
|
|
"\n",
|
|
"mpl.rcParams['savefig.pad_inches'] = 0\n",
|
|
"# matplotlib 3.6 renamed the 'seaborn-*' styles to 'seaborn-v0_8-*'; 3.8 removed the old names.\n",
|
|
"plt.style.use('seaborn-v0_8-dark' if 'seaborn-v0_8-dark' in plt.style.available else 'seaborn-dark')\n",
|
|
"%matplotlib inline"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "vruUCSlrU_L7",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Plotting Utilities\n",
|
|
"\n",
|
|
"These helpers provide fast line plotting with better antialiasing than typical matplotlib plotting routines."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "aTVqxdEQLZwM",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"@jit\n",
|
|
"def drawline(im, x0, y0, x1, y1):\n",
|
|
"    \"\"\"An implementation of Wu's antialiased line algorithm.\n",
|
|
"\n",
|
|
"    Adds the antialiased segment from (x0, y0) to (x1, y1) into `im`, where\n",
|
|
"    pixel (x, y) is `im[x, y]`; overlapping segments accumulate intensity.\n",
|
|
"    Contributions that fall outside `im` are discarded.\n",
|
|
"\n",
|
|
"    This functional version was adapted from here:\n",
|
|
"    https://en.wikipedia.org/wiki/Xiaolin_Wu's_line_algorithm\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"      im: 2-D float image to draw into.\n",
|
|
"      x0, y0, x1, y1: endpoint coordinates in (fractional) pixel units.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"      A new image with the line added (`im` itself is not modified).\n",
|
|
"    \"\"\"\n",
|
|
"    ipart = lambda x: jnp.floor(x).astype('int32')\n",
|
|
"    round_ = lambda x: ipart(x + 0.5)\n",
|
|
"    fpart = lambda x: x - jnp.floor(x)\n",
|
|
"    rfpart = lambda x: 1 - fpart(x)\n",
|
|
"\n",
|
|
"    # Step along the major axis: for steep lines swap the roles of x and y.\n",
|
|
"    steep = jnp.abs(y1 - y0) > jnp.abs(x1 - x0)\n",
|
|
"    x0, y0 = jnp.where(steep, y0, x0), jnp.where(steep, x0, y0)\n",
|
|
"    x1, y1 = jnp.where(steep, y1, x1), jnp.where(steep, x1, y1)\n",
|
|
"\n",
|
|
"    # Always walk in increasing major-axis order.\n",
|
|
"    backwards = x0 > x1\n",
|
|
"    x0, x1 = jnp.where(backwards, x1, x0), jnp.where(backwards, x0, x1)\n",
|
|
"    y0, y1 = jnp.where(backwards, y1, y0), jnp.where(backwards, y0, y1)\n",
|
|
"\n",
|
|
"    def plot(im, x, y, c):\n",
|
|
"        # (x, y) are (major, minor) coordinates; undo the steep swap to get the\n",
|
|
"        # image index. Out-of-range pixels get weight 0 because negative indices\n",
|
|
"        # would otherwise wrap around to the opposite edge of the image.\n",
|
|
"        row = jnp.where(steep, y, x)\n",
|
|
"        col = jnp.where(steep, x, y)\n",
|
|
"        inside = (row >= 0) & (row < im.shape[0]) & (col >= 0) & (col < im.shape[1])\n",
|
|
"        return im.at[row, col].add(jnp.where(inside, c, 0.0), mode='drop')\n",
|
|
"\n",
|
|
"    dx = x1 - x0\n",
|
|
"    dy = y1 - y0\n",
|
|
"    gradient = jnp.where(dx == 0.0, 1.0, dy / dx)\n",
|
|
"\n",
|
|
"    def endpoint(im, x, y, xgap):\n",
|
|
"        # Plot the two pixels straddling an endpoint, weighted by `xgap` (its\n",
|
|
"        # coverage along the major axis); also return its column and minor coord.\n",
|
|
"        xend = round_(x)\n",
|
|
"        yend = y + gradient * (xend - x)\n",
|
|
"        ypxl = ipart(yend)\n",
|
|
"        im = plot(im, xend, ypxl, rfpart(yend) * xgap)\n",
|
|
"        im = plot(im, xend, ypxl + 1, fpart(yend) * xgap)\n",
|
|
"        return im, xend, yend\n",
|
|
"\n",
|
|
"    im, xpxl1, yend1 = endpoint(im, x0, y0, rfpart(x0 + 0.5))\n",
|
|
"    im, xpxl2, _ = endpoint(im, x1, y1, fpart(x1 + 0.5))\n",
|
|
"\n",
|
|
"    # Interior columns; `intery` is the line's exact minor coordinate at column x.\n",
|
|
"    def body_fun(x, carry):\n",
|
|
"        im, intery = carry\n",
|
|
"        im = plot(im, x, ipart(intery), rfpart(intery))\n",
|
|
"        im = plot(im, x, ipart(intery) + 1, fpart(intery))\n",
|
|
"        return im, intery + gradient\n",
|
|
"\n",
|
|
"    im, _ = lax.fori_loop(xpxl1 + 1, xpxl2, body_fun, (im, yend1 + gradient))\n",
|
|
"    return im\n",
|
|
"\n",
|
|
"def img_adjust(data):\n",
|
|
"    \"\"\"Histogram-equalize `data`: map each value through the binned CDF of all values.\"\"\"\n",
|
|
"    arr = np.array(data)\n",
|
|
"    counts, edges = np.histogram(arr.flat, bins=256 * 256)\n",
|
|
"    centers = 0.5 * (edges[1:] + edges[:-1])\n",
|
|
"    cdf = np.cumsum(counts) / float(counts.sum())\n",
|
|
"    return np.interp(arr.flat, centers, cdf).reshape(arr.shape)\n",
|
|
"\n",
|
|
"def imify(arr, vmin=None, vmax=None, cmap=None, origin=None):\n",
|
|
"    \"\"\"Convert a 2-D array into an RGBA uint8 image.\n",
|
|
"\n",
|
|
"    The array is first histogram-equalized with `img_adjust`, then mapped\n",
|
|
"    through `cmap` using color limits `vmin`/`vmax` (see ScalarMappable.set_clim).\n",
|
|
"    When `origin` is None the matplotlib rcParam `image.origin` is used; for\n",
|
|
"    origin 'lower' the rows are flipped vertically before colormapping.\n",
|
|
"    \"\"\"\n",
|
|
"    arr = img_adjust(arr)\n",
|
|
"    sm = cm.ScalarMappable(cmap=cmap)\n",
|
|
"    sm.set_clim(vmin, vmax)\n",
|
|
"    if origin is None:\n",
|
|
"        origin = mpl.rcParams[\"image.origin\"]\n",
|
|
"    if origin == \"lower\":\n",
|
|
"        arr = arr[::-1]\n",
|
|
"    rgba = sm.to_rgba(arr, bytes=True)\n",
|
|
"    return rgba\n",
|
|
"\n",
|
|
"def plot_image(array, **kwargs):\n",
|
|
"    \"\"\"Render `array` with `imify(array, **kwargs)` and display it inline as a PNG.\"\"\"\n",
|
|
"    rgba = imify(array, **kwargs)\n",
|
|
"    with io.BytesIO() as buf:\n",
|
|
"        plt.imsave(buf, rgba, format=\"png\")\n",
|
|
"        png_bytes = buf.getvalue()\n",
|
|
"    display_png(png_bytes, raw=True)\n",
|
|
"\n",
|
|
"def pack_images(images, rows, cols):\n",
|
|
"    \"\"\"Tile a batch of images into one (rows * H, cols * W, C) mosaic.\n",
|
|
"\n",
|
|
"    `images` has shape (..., H, W, C); all leading axes are flattened into a\n",
|
|
"    single batch. Only the first rows * cols images are used, and rows/cols\n",
|
|
"    are reduced when the batch is too small to fill the grid.\n",
|
|
"    \"\"\"\n",
|
|
"    h, w, c = np.shape(images)[-3:]\n",
|
|
"    flat = np.reshape(images, (-1, h, w, c))\n",
|
|
"    n = flat.shape[0]\n",
|
|
"    rows = np.minimum(rows, n)\n",
|
|
"    cols = np.minimum(n // rows, cols)\n",
|
|
"    grid = flat[:rows * cols].reshape(rows, cols, h, w, c)\n",
|
|
"    return grid.transpose(0, 2, 1, 3, 4).reshape(rows * h, cols * w, c)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "FFkdRUDR9cWD",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Lorentz Dynamics\n",
|
|
"\n",
|
|
"Implement the Lorenz attractor."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "aoSvqedskd0W",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# Lorenz system parameters (the values used in Lorenz's 1963 paper).\n",
|
|
"sigma = 10.\n",
|
|
"beta = 8./3\n",
|
|
"rho = 28.\n",
|
|
"\n",
|
|
"@jit\n",
|
|
"def f(state, t):\n",
|
|
"    \"\"\"Lorenz vector field d(x, y, z)/dt at `state`; `t` is unused (autonomous system).\"\"\"\n",
|
|
"    x, y, z = state\n",
|
|
"    dxdt = sigma * (y - x)\n",
|
|
"    dydt = x * (rho - z) - y\n",
|
|
"    dzdt = x * y - beta * z\n",
|
|
"    return jnp.array([dxdt, dydt, dzdt])"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "tanYn8Cx9hUb",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Runge-Kutta Integrator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "ejuN_R7Km28v",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"@jit\n",
|
|
"def rk4(ys, dt, N):\n",
|
|
"    \"\"\"Integrate dy/dt = f(y, t) with the classic 4th-order Runge-Kutta method.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"      ys: array whose row 0 holds the initial state; rows 1 .. N-1 are\n",
|
|
"        overwritten with the states at times dt, 2 * dt, ..., (N-1) * dt.\n",
|
|
"      dt: step size.\n",
|
|
"      N: number of states to produce, including the initial one (may be a\n",
|
|
"        traced value, e.g. when rk4 is pmapped).\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"      A new array equal to `ys` with rows 1 .. N-1 filled in.\n",
|
|
"    \"\"\"\n",
|
|
"    def step(i, ys):\n",
|
|
"        # Advance row i-1, the state at time t = (i-1) * dt, to row i.\n",
|
|
"        h = dt\n",
|
|
"        t = dt * (i - 1)\n",
|
|
"        y = ys[i - 1]\n",
|
|
"        k1 = h * f(y, t)\n",
|
|
"        k2 = h * f(y + k1/2., t + h/2.)\n",
|
|
"        k3 = h * f(y + k2/2., t + h/2.)\n",
|
|
"        k4 = h * f(y + k3, t + h)\n",
|
|
"        return ys.at[i].set(y + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4))\n",
|
|
"    return lax.fori_loop(1, N, step, ys)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "i2UIxo3Z9PZ2",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Solve and plot a single ODE solution using the jitted solver and plotter"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "XvROzDrukzH_",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"N = 40000\n",
|
|
"dt = 0.004  # RK4 step size\n",
|
|
"\n",
|
|
"# set initial condition\n",
|
|
"state0 = jnp.array([1., 1., 1.])\n",
|
|
"ys = jnp.zeros((N,) + state0.shape).at[0].set(state0)\n",
|
|
"\n",
|
|
"# integrate N - 1 RK4 steps (N states including the initial condition)\n",
|
|
"ys = rk4(ys, dt, N).block_until_ready()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "t4k3UrtbM4jy",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# plotting size and region:\n",
|
|
"xlim, zlim = (-20, 20), (0, 50)\n",
|
|
"xN, zN = 800, 600\n",
|
|
"\n",
|
|
"# fast, jitted plotting function\n",
|
|
"@partial(jax.jit, static_argnums=(2,3,4,5))\n",
|
|
"def jplotter(xs, zs, xlim, zlim, xN, zN):\n",
|
|
"    \"\"\"Rasterize the polyline through (xs[i], zs[i]) into an (xN, zN) image.\n",
|
|
"\n",
|
|
"    `xlim`/`zlim` map data coordinates onto the pixel ranges [0, xN] and\n",
|
|
"    [0, zN]. They and the image size are static arguments, so each distinct\n",
|
|
"    value compiles a new version of this function.\n",
|
|
"    \"\"\"\n",
|
|
"    im = jnp.zeros((xN, zN))\n",
|
|
"    xpixels = (xs - xlim[0])/(1.0 * (xlim[1] - xlim[0])) * xN\n",
|
|
"    zpixels = (zs - zlim[0])/(1.0 * (zlim[1] - zlim[0])) * zN\n",
|
|
"    def body_fun(i, im):\n",
|
|
"        # draw the segment from point i-1 to point i\n",
|
|
"        return drawline(im, xpixels[i-1], zpixels[i-1], xpixels[i], zpixels[i])\n",
|
|
"    return lax.fori_loop(1, xpixels.shape[0], body_fun, im)\n",
|
|
"\n",
|
|
"im = jplotter(ys[...,0], ys[...,2], xlim, zlim, xN, zN)\n",
|
|
"# im is indexed [x, z]; flip z and transpose so that, with the default 'upper'\n",
|
|
"# image origin, z increases upward on screen\n",
|
|
"plot_image(im[:,::-1].T, cmap='magma')"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "JWkKc-mh7m9x",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Parallel ODE Solutions with `pmap`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "tlc8Y_pfOERv",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"N_dev = jax.device_count()\n",
|
|
"N = 4000\n",
|
|
"\n",
|
|
"# set some initial conditions for each replicate:\n",
|
|
"# x, y uniform in [-18, 18] and z uniform in [9, 11]\n",
|
|
"state0 = jr.uniform(jr.key(1),\n",
|
|
"                    minval=-1., maxval=1.,\n",
|
|
"                    shape=(N_dev, 3))\n",
|
|
"state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n",
|
|
"ys = jnp.zeros((N_dev, N, 3)).at[:, 0].set(state0)\n",
|
|
"\n",
|
|
"# solve each replicate in parallel using `pmap` of rk4 solver:\n",
|
|
"ys = jax.pmap(rk4)(ys,\n",
|
|
"                   0.004 * jnp.ones(N_dev),\n",
|
|
"                   N * jnp.ones(N_dev, dtype=np.int32)\n",
|
|
"                  ).block_until_ready()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "_NdalA1qy1Fp",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# parallel plotter using lexical closure and pmap'd core plotting function\n",
|
|
"def pplotter(_xs, _zs, xlim, zlim, xN, zN):\n",
|
|
"    \"\"\"Rasterize one trajectory per device in parallel with `jax.pmap`.\n",
|
|
"\n",
|
|
"    `_xs` and `_zs` carry one trajectory per entry of their leading axis (at\n",
|
|
"    most one per device, as pmap requires); returns an (N_dev, xN, zN) stack\n",
|
|
"    of images. The plot geometry is captured by closure, so every call builds\n",
|
|
"    a fresh pmap'd function.\n",
|
|
"    \"\"\"\n",
|
|
"    N_dev = _xs.shape[0]\n",
|
|
"    im = jnp.zeros((N_dev, xN, zN))\n",
|
|
"    @jax.pmap\n",
|
|
"    def plotfn(im, xs, zs):\n",
|
|
"        xpixels = (xs - xlim[0])/(1.0 * (xlim[1] - xlim[0])) * xN\n",
|
|
"        zpixels = (zs - zlim[0])/(1.0 * (zlim[1] - zlim[0])) * zN\n",
|
|
"        def body_fun(i, im):\n",
|
|
"            # draw the segment from point i-1 to point i\n",
|
|
"            return drawline(im, xpixels[i-1], zpixels[i-1], xpixels[i], zpixels[i])\n",
|
|
"        return lax.fori_loop(1, xpixels.shape[0], body_fun, im)\n",
|
|
"    return plotfn(im, _xs, _zs)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "vhZyGqHUYkKK",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"xlim, zlim = (-20, 20), (0, 50)\n",
|
|
"xN, zN = 200, 150\n",
|
|
"\n",
|
|
"# above, plot ODE traces separately: one small image per replicate,\n",
|
|
"# tiled into a grid of at most 4 rows x 2 columns\n",
|
|
"ims = pplotter(ys[..., 0], ys[..., 2], xlim, zlim, xN, zN)\n",
|
|
"im = pack_images(ims[..., None], 4, 2)[..., 0]\n",
|
|
"plot_image(im[:, ::-1].T, cmap='magma')\n",
|
|
"\n",
|
|
"# below, plot combined ODE traces: render every replicate at 4x resolution\n",
|
|
"# and sum them into a single image\n",
|
|
"ims = pplotter(ys[..., 0], ys[..., 2], xlim, zlim, 4 * xN, 4 * zN)\n",
|
|
"plot_image(ims.sum(axis=0)[:, ::-1].T, cmap='magma')"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "S6c5GWHBbkEX",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
""
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
}
|
|
]
|
|
}
|