mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 19:06:07 +00:00
527 lines
15 KiB
Plaintext
527 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "l7Sh70ascFNs"
|
|
},
|
|
"source": [
|
|
"# Solving the wave equation on cloud TPUs\n",
|
|
"\n",
|
|
"[_Stephan Hoyer_](https://twitter.com/shoyer)\n",
|
|
"\n",
|
|
"In this notebook, we solve the 2D [wave equation](https://en.wikipedia.org/wiki/Wave_equation):\n",
|
|
"$$\n",
|
|
"\\frac{\\partial^2 u}{\\partial t^2} = c^2 \\nabla^2 u\n",
|
|
"$$\n",
|
|
"\n",
|
|
"We use a simple [finite difference](https://en.wikipedia.org/wiki/Finite_difference_method) formulation with [leapfrog time integration](https://en.wikipedia.org/wiki/Leapfrog_integration).\n",
|
|
"\n",
|
|
"Note: It is natural to express finite difference methods as convolutions, but here we intentionally avoid convolutions in favor of array indexing/arithmetic. This is because the \"batch\" and \"feature\" dimensions in TPU convolutions are padded to multiples of either 8 or 128, but in our case both of these dimensions are effectively of size 1.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "xAd8PvW6ceyk"
|
|
},
|
|
"source": [
|
|
"## Set up the required environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "n6L8zIaAIrqj"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Grab other packages for this demo.\n",
|
|
"!pip install -U -q Pillow moviepy proglog scikit-image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "UHkzcAMhcpm_"
|
|
},
|
|
"source": [
|
|
"## Simulation code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "4NJQ8Q5M99wy"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from functools import partial\n",
|
|
"import jax\n",
|
|
"from jax import lax\n",
|
|
"from jax import tree_util\n",
|
|
"import jax.numpy as jnp\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import skimage.filters\n",
|
|
"import proglog\n",
|
|
"from moviepy.editor import ImageSequenceClip\n",
|
|
"\n",
|
|
"device_count = jax.device_count()\n",
|
|
"\n",
|
|
"# Spatial partitioning via halo exchange\n",
|
|
"\n",
|
|
"def send_right(x, axis_name):\n",
|
|
" # Note: if some devices are omitted from the permutation, lax.ppermute\n",
|
|
" # provides zeros instead. This gives us an easy way to apply Dirichlet\n",
|
|
" # boundary conditions.\n",
|
|
" left_perm = [(i, (i + 1) % device_count) for i in range(device_count - 1)]\n",
|
|
" return lax.ppermute(x, perm=left_perm, axis_name=axis_name)\n",
|
|
"\n",
|
|
"def send_left(x, axis_name):\n",
|
|
" left_perm = [((i + 1) % device_count, i) for i in range(device_count - 1)]\n",
|
|
" return lax.ppermute(x, perm=left_perm, axis_name=axis_name)\n",
|
|
"\n",
|
|
"def axis_slice(ndim, index, axis):\n",
|
|
" slices = [slice(None)] * ndim\n",
|
|
" slices[axis] = index\n",
|
|
" return tuple(slices)\n",
|
|
"\n",
|
|
"def slice_along_axis(array, index, axis):\n",
|
|
" return array[axis_slice(array.ndim, index, axis)]\n",
|
|
"\n",
|
|
"def tree_vectorize(func):\n",
|
|
" def wrapper(x, *args, **kwargs):\n",
|
|
" return tree_util.tree_map(lambda x: func(x, *args, **kwargs), x)\n",
|
|
" return wrapper\n",
|
|
"\n",
|
|
"@tree_vectorize\n",
|
|
"def halo_exchange_padding(array, padding=1, axis=0, axis_name='x'):\n",
|
|
" if not padding > 0:\n",
|
|
" raise ValueError(f'invalid padding: {padding}')\n",
|
|
" array = jnp.array(array)\n",
|
|
" if array.ndim == 0:\n",
|
|
" return array\n",
|
|
" left = slice_along_axis(array, slice(None, padding), axis)\n",
|
|
" right = slice_along_axis(array, slice(-padding, None), axis)\n",
|
|
" right, left = send_left(left, axis_name), send_right(right, axis_name)\n",
|
|
" return jnp.concatenate([left, array, right], axis)\n",
|
|
"\n",
|
|
"@tree_vectorize\n",
|
|
"def halo_exchange_inplace(array, padding=1, axis=0, axis_name='x'):\n",
|
|
" left = slice_along_axis(array, slice(padding, 2*padding), axis)\n",
|
|
" right = slice_along_axis(array, slice(-2*padding, -padding), axis)\n",
|
|
" right, left = send_left(left, axis_name), send_right(right, axis_name)\n",
|
|
" array = array.at[axis_slice(array.ndim, slice(None, padding), axis)].set(left)\n",
|
|
" array = array.at[axis_slice(array.ndim, slice(-padding, None), axis)].set(right)\n",
|
|
" return array\n",
|
|
"\n",
|
|
"# Reshaping inputs/outputs for pmap\n",
|
|
"\n",
|
|
"def split_with_reshape(array, num_splits, *, split_axis=0, tile_id_axis=None):\n",
|
|
" if tile_id_axis is None:\n",
|
|
" tile_id_axis = split_axis\n",
|
|
" tile_size, remainder = divmod(array.shape[split_axis], num_splits)\n",
|
|
" if remainder:\n",
|
|
" raise ValueError('num_splits must equally divide the dimension size')\n",
|
|
" new_shape = list(array.shape)\n",
|
|
" new_shape[split_axis] = tile_size\n",
|
|
" new_shape.insert(split_axis, num_splits)\n",
|
|
" return jnp.moveaxis(jnp.reshape(array, new_shape), split_axis, tile_id_axis)\n",
|
|
"\n",
|
|
"def stack_with_reshape(array, *, split_axis=0, tile_id_axis=None):\n",
|
|
" if tile_id_axis is None:\n",
|
|
" tile_id_axis = split_axis\n",
|
|
" array = jnp.moveaxis(array, tile_id_axis, split_axis)\n",
|
|
" new_shape = array.shape[:split_axis] + (-1,) + array.shape[split_axis+2:]\n",
|
|
" return jnp.reshape(array, new_shape)\n",
|
|
"\n",
|
|
"def shard(func):\n",
|
|
" def wrapper(state):\n",
|
|
" sharded_state = tree_util.tree_map(\n",
|
|
" lambda x: split_with_reshape(x, device_count), state)\n",
|
|
" sharded_result = func(sharded_state)\n",
|
|
" result = tree_util.tree_map(stack_with_reshape, sharded_result)\n",
|
|
" return result\n",
|
|
" return wrapper\n",
|
|
"\n",
|
|
"# Physics\n",
|
|
"\n",
|
|
"def shift(array, offset, axis):\n",
|
|
" index = slice(offset, None) if offset >= 0 else slice(None, offset)\n",
|
|
" sliced = slice_along_axis(array, index, axis)\n",
|
|
" padding = [(0, 0)] * array.ndim\n",
|
|
" padding[axis] = (-min(offset, 0), max(offset, 0))\n",
|
|
" return jnp.pad(sliced, padding, mode='constant', constant_values=0)\n",
|
|
"\n",
|
|
"def laplacian(array, step=1):\n",
|
|
" left = shift(array, +1, axis=0)\n",
|
|
" right = shift(array, -1, axis=0)\n",
|
|
" up = shift(array, +1, axis=1)\n",
|
|
" down = shift(array, -1, axis=1)\n",
|
|
" convolved = (left + right + up + down - 4 * array)\n",
|
|
" if step != 1:\n",
|
|
" convolved *= (1 / step ** 2)\n",
|
|
" return convolved\n",
|
|
"\n",
|
|
"def scalar_wave_equation(u, c=1, dx=1):\n",
|
|
" return c ** 2 * laplacian(u, dx)\n",
|
|
"\n",
|
|
"@jax.jit\n",
|
|
"def leapfrog_step(state, dt=0.5, c=1):\n",
|
|
" # https://en.wikipedia.org/wiki/Leapfrog_integration\n",
|
|
" u, u_t = state\n",
|
|
" u_tt = scalar_wave_equation(u, c)\n",
|
|
" u_t = u_t + u_tt * dt\n",
|
|
" u = u + u_t * dt\n",
|
|
" return (u, u_t)\n",
|
|
"\n",
|
|
"# Time stepping\n",
|
|
"\n",
|
|
"def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):\n",
|
|
" return lax.fori_loop(0, count, lambda i, s: leapfrog_step(s, dt, c), state)\n",
|
|
"\n",
|
|
"def multi_step_pmap(state, count, dt=1/jnp.sqrt(2), c=1, exchange_interval=1,\n",
|
|
" save_interval=1):\n",
|
|
"\n",
|
|
" def exchange_and_multi_step(state_padded):\n",
|
|
" c_padded = halo_exchange_padding(c, exchange_interval)\n",
|
|
" evolved = multi_step(state_padded, exchange_interval, dt, c_padded)\n",
|
|
" return halo_exchange_inplace(evolved, exchange_interval)\n",
|
|
"\n",
|
|
" @shard\n",
|
|
" @partial(jax.pmap, axis_name='x')\n",
|
|
" def simulate_until_output(state):\n",
|
|
" stop = save_interval // exchange_interval\n",
|
|
" state_padded = halo_exchange_padding(state, exchange_interval)\n",
|
|
" advanced = lax.fori_loop(\n",
|
|
" 0, stop, lambda i, s: exchange_and_multi_step(s), state_padded)\n",
|
|
" xi = exchange_interval\n",
|
|
" return tree_util.tree_map(lambda array: array[xi:-xi, ...], advanced)\n",
|
|
"\n",
|
|
" results = [state]\n",
|
|
" for _ in range(count // save_interval):\n",
|
|
" state = simulate_until_output(state)\n",
|
|
" tree_util.tree_map(lambda x: x.copy_to_host_async(), state)\n",
|
|
" results.append(state)\n",
|
|
" results = jax.device_get(results)\n",
|
|
" return tree_util.tree_map(lambda *xs: np.stack([np.array(x) for x in xs]), *results)\n",
|
|
"\n",
|
|
"multi_step_jit = jax.jit(multi_step)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "_kMyXbkEeCu3"
|
|
},
|
|
"source": [
|
|
"## Initial conditions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "usWkf2UgAYc5"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = jnp.linspace(0, 8, num=8*1024, endpoint=False)\n",
|
|
"y = jnp.linspace(0, 1, num=1*1024, endpoint=False)\n",
|
|
"x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')\n",
|
|
"\n",
|
|
"# NOTE: smooth initial conditions are important, so we aren't exciting\n",
|
|
"# arbitrarily high frequencies (that cannot be resolved)\n",
|
|
"u = skimage.filters.gaussian(\n",
|
|
" ((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) < 0.1 ** 2,\n",
|
|
" sigma=1)\n",
|
|
"\n",
|
|
"# u = jnp.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)\n",
|
|
"\n",
|
|
"# u = skimage.filters.gaussian(\n",
|
|
"# (x_mesh > 1/3) & (x_mesh < 1/2) & (y_mesh > 1/3) & (y_mesh < 1/2),\n",
|
|
"# sigma=5)\n",
|
|
"\n",
|
|
"v = jnp.zeros_like(u)\n",
|
|
"c = 1 # could also use a 2D array matching the mesh shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "jWtIvDTfCx12"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"u.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "RVdPIRKmeNX3"
|
|
},
|
|
"source": [
|
|
"## Test scaling from 1 to 8 chips"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "SPG6jvYSCaKQ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%%time\n",
|
|
"# single TPU chip\n",
|
|
"u_final, _ = multi_step_jit((u, v), count=2**13, c=c, dt=0.5)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "MpsDzNyC6OI0"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%%time\n",
|
|
"# 8x TPU chips, 4x more steps in roughly half the time!\n",
|
|
"u_final, _ = multi_step_pmap(\n",
|
|
" (u, v), count=2**15, c=c, dt=0.5, exchange_interval=4, save_interval=2**15)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "fPCwJ51FeBLR"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"18.3 / (10.3 / 4) # near linear scaling (8x would be perfect)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "jYMn69z8eQ7O"
|
|
},
|
|
"source": [
|
|
"## Save a bunch of outputs for a movie"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Uns4X34EPYgF"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%%time\n",
|
|
"# save more outputs for a movie -- this is slow!\n",
|
|
"u_final, _ = multi_step_pmap(\n",
|
|
" (u, v), count=2**15, c=c, dt=0.2, exchange_interval=4, save_interval=2**10)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Hsjuk22Nbe9Z"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"u_final.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "LXqmz-CCdQjt"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"u_final.nbytes / 1e9"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "cPvXrPUCPPtt"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.figure(figsize=(18, 6))\n",
|
|
"plt.axis('off')\n",
|
|
"plt.imshow(u_final[-1].T, cmap='RdBu');"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "uqxGmttmgSsN"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"fig, axes = plt.subplots(9, 1, figsize=(14, 14))\n",
|
|
"[ax.axis('off') for ax in axes]\n",
|
|
"axes[0].imshow(u_final[0].T, cmap='RdBu', aspect='equal', vmin=-1, vmax=1)\n",
|
|
"for i in range(8):\n",
|
|
" axes[i+1].imshow(u_final[4*i+1].T / abs(u_final[4*i+1]).max(), cmap='RdBu', aspect='equal', vmin=-1, vmax=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "hNZrt590p6zr"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib.cm\n",
|
|
"import matplotlib.colors\n",
|
|
"from PIL import Image\n",
|
|
"\n",
|
|
"def make_images(data, cmap='RdBu', vmax=None):\n",
|
|
" images = []\n",
|
|
" for frame in data:\n",
|
|
" if vmax is None:\n",
|
|
" this_vmax = np.max(abs(frame))\n",
|
|
" else:\n",
|
|
" this_vmax = vmax\n",
|
|
" norm = matplotlib.colors.Normalize(vmin=-this_vmax, vmax=this_vmax)\n",
|
|
" mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)\n",
|
|
" rgba = mappable.to_rgba(frame, bytes=True)\n",
|
|
" image = Image.fromarray(rgba, mode='RGBA')\n",
|
|
" images.append(image)\n",
|
|
" return images\n",
|
|
"\n",
|
|
"def save_movie(images, path, duration=100, loop=0, **kwargs):\n",
|
|
" images[0].save(path, save_all=True, append_images=images[1:],\n",
|
|
" duration=duration, loop=loop, **kwargs)\n",
|
|
"\n",
|
|
"images = make_images(u_final[::, ::8, ::8].transpose(0, 2, 1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "ZWEZWehFjboa"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Show Movie\n",
|
|
"proglog.default_bar_logger = partial(proglog.default_bar_logger, None)\n",
|
|
"ImageSequenceClip([np.array(im) for im in images], fps=25).ipython_display()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "whG0OOepezuX"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Save GIF.\n",
|
|
"save_movie(images,'wave_movie.gif', duration=[2000]+[200]*(len(images)-2)+[2000])\n",
|
|
"# The movie sometimes takes a second before showing up in the file system.\n",
|
|
"import time; time.sleep(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 0,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "E1V7aS63vRAT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Download animation.\n",
|
|
"try:\n",
|
|
" from google.colab import files\n",
|
|
"except ImportError:\n",
|
|
" pass\n",
|
|
"else:\n",
|
|
" files.download('wave_movie.gif')"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "TPU",
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"name": "JAX TPU wave equation",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|