mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
1011 lines
23 KiB
Plaintext
1011 lines
23 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "JAX demo.ipynb",
|
|
"provenance": [],
|
|
"collapsed_sections": [
|
|
"AvXl1WDPKjmV"
|
|
]
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"accelerator": "TPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "hLEyhfMqmnrt",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## Colab JAX TPU Setup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "5CTEVmyKmkfp",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# Make sure the Colab Runtime is set to Accelerator: TPU.\n",
|
|
"import requests\n",
|
|
"import os\n",
|
|
"if 'TPU_DRIVER_MODE' not in globals():\n",
|
|
" url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'\n",
|
|
" resp = requests.post(url)\n",
|
|
" TPU_DRIVER_MODE = 1\n",
|
|
"\n",
|
|
"# The following is required to use TPU Driver as JAX's backend.\n",
|
|
"from jax.config import config\n",
|
|
"config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
|
|
"config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
|
|
"print(config.FLAGS.jax_backend_target)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ebUMqK9mGIDm",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## The basics: interactive NumPy on GPU and TPU\n",
|
|
"\n",
|
|
"---\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "27TqNtiQF97X",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"import jax\n",
|
|
"import jax.numpy as np\n",
|
|
"from jax import random\n",
|
|
"\n",
|
|
"key = random.PRNGKey(0)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "cRWoxSCNGU4o",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"key, subkey = random.split(key)\n",
|
|
"x = random.normal(subkey, (5000, 5000))\n",
|
|
"\n",
|
|
"print(x.shape)\n",
|
|
"print(x.dtype)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "diPllsvgGfSA",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"y = np.dot(x, x)\n",
|
|
"print(y[0, 0])"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "8-psauxnGiRk",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"x"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "-2FMQ8UeoTJ8",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"plt.plot(x[0])"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "DRnwCKFuGk8P",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"np.dot(x, x.T)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "z4VX5PkMHJIu",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"print(np.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "ORZ9Odu85BCJ",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"import numpy as onp\n",
|
|
"\n",
|
|
"x_cpu = onp.array(x)\n",
|
|
"%timeit -n 1 -r 1 onp.dot(x_cpu, x_cpu)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "5BKh0eeAGvO5",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"%timeit -n 5 -r 5 np.dot(x, x).block_until_ready()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "fm4Q2zpFHUAu",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## Automatic differentiation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "MCIQbyUYHWn1",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from jax import grad"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "kfqZpKYsHo4j",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"def f(x):\n",
|
|
"  if x > 0:\n",
|
|
"    return 2 * x ** 3\n",
|
|
"  else:\n",
|
|
"    return 3 * x"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "K_26_odPHqLJ",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"key = random.PRNGKey(0)\n",
|
|
"x = random.normal(key, ())\n",
|
|
"\n",
|
|
"print(grad(f)(x))\n",
|
|
"print(grad(f)(-x))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "q5V3A6loHrhS",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"print(grad(grad(f))(-x))\n",
|
|
"print(grad(grad(grad(f)))(-x))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "ba4WY4ArHv8I",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"def predict(params, inputs):\n",
|
|
"  for W, b in params:\n",
|
|
"    outputs = np.dot(inputs, W) + b\n",
|
|
"    inputs = np.tanh(outputs)\n",
|
|
"  return outputs\n",
|
|
"\n",
|
|
"def loss(params, batch):\n",
|
|
"  inputs, targets = batch\n",
|
|
"  predictions = predict(params, inputs)\n",
|
|
"  return np.sum((predictions - targets)**2)\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"def init_layer(key, n_in, n_out):\n",
|
|
"  k1, k2 = random.split(key)\n",
|
|
"  W = random.normal(k1, (n_in, n_out))\n",
|
|
"  b = random.normal(k2, (n_out,))\n",
|
|
"  return W, b\n",
|
|
"\n",
|
|
"layer_sizes = [5, 2, 3]\n",
|
|
"\n",
|
|
"key = random.PRNGKey(0)\n",
|
|
"key, *keys = random.split(key, len(layer_sizes))\n",
|
|
"params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
|
|
"\n",
|
|
"key, *keys = random.split(key, 3)\n",
|
|
"inputs = random.normal(keys[0], (8, 5))\n",
|
|
"targets = random.normal(keys[1], (8, 3))\n",
|
|
"batch = (inputs, targets)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "LiTBibJdHz4K",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"print(loss(params, batch))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "a3KFpwH3H4Cl",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"step_size = 1e-2\n",
|
|
"\n",
|
|
"for _ in range(20):\n",
|
|
" grads = grad(loss)(params, batch)\n",
|
|
" params = [(W - step_size * dW, b - step_size * db)\n",
|
|
" for (W, b), (dW, db) in zip(params, grads)]"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "YLltDr0GH7LX",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"print(loss(params, batch))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "bmxAPFC0I8b0",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"Other JAX autodiff highlights:\n",
|
|
"\n",
|
|
"* Forward- and reverse-mode, totally composable\n",
|
|
"* Fast Jacobians and Hessians\n",
|
|
"* Complex number support (holomorphic and non-holomorphic)\n",
|
|
"* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n",
|
|
"\n",
|
|
"\n",
|
|
"For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "TRkxaVLJKNre",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## End-to-end compilation with XLA using `jit`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "bKo4rX9-KSW7",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from jax import jit"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "94iIgZSfKWh8",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"key = random.PRNGKey(0)\n",
|
|
"x = random.normal(key, (5000, 5000))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "Ybuz8Ag9KXMd",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"def f(x):\n",
|
|
"  y = x\n",
|
|
"  for _ in range(10):\n",
|
|
"    y = y - 0.1 * y + 3.\n",
|
|
"  return y[:100, :100]\n",
|
|
"\n",
|
|
"f(x)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "Y9dx5ifSKaGJ",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"g = jit(f)\n",
|
|
"g(x)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "UtsS67BvKYkC",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"%timeit f(x).block_until_ready()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "-vfcaSo9KbvR",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"%timeit g(x).block_until_ready()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "E3BQF1_AKeLn",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"grad(jit(grad(jit(grad(np.tanh)))))(1.0)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "AvXl1WDPKjmV",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"### Constraints that come with using `jit`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "mCtwRF18KnsE",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"def f(x):\n",
|
|
"  if x > 0:\n",
|
|
"    return 2 * x ** 2\n",
|
|
"  else:\n",
|
|
"    return 3 * x\n",
|
|
"\n",
|
|
"g = jit(f)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "_82tY-ZSKqv4",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"f(2)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "TjSAFc-iKrcB",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"g(2)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "RhizP9pjKsug",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"def f(x, n):\n",
|
|
"  i = 0\n",
|
|
"  while i < n:\n",
|
|
"    x = x * x\n",
|
|
"    i += 1\n",
|
|
"  return x\n",
|
|
"\n",
|
|
"g = jit(f)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "Wn6haTmUK-Q8",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"f(np.array([1., 2., 3.]), 5)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "HwBy1I04K-81",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"g(np.array([1., 2., 3.]), 5)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "XmaTryZaK_3M",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"g = jit(f, static_argnums=(1,))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "HcWjxVktV4fa",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"g(np.array([1., 2., 3.]), 5)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "0M_-pJe7LOcO",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## Vectorization with `vmap`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "8XIot_ndLRH1",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from jax import vmap"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "tRvCZn2wBkXP",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"print(vmap(lambda x: x**2)(np.arange(8)))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "icfsXizI_rkD",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from jax import make_jaxpr\n",
|
|
"\n",
|
|
"make_jaxpr(np.dot)(np.ones(8), np.ones(8))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "uQm4cvAbA6M3",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"make_jaxpr(vmap(np.dot))(np.ones((10, 8)), np.ones((10, 8)))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "NeiFfCHEBLsU",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"make_jaxpr(vmap(vmap(np.dot)))(np.ones((10, 10, 8)), np.ones((10, 10, 8)))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "csX71fkSCZrp",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"perex_grads = vmap(grad(loss), in_axes=(None, 0))\n",
|
|
"make_jaxpr(perex_grads)(params, batch)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Tmf1NT2Wqv5p",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## Parallel accelerators with `pmap`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "t6RRAFn1CEln",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"jax.devices()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "tEK1I6Duqunw",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from jax import pmap"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "S-iCNfeGqzkY",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"y = pmap(lambda x: x ** 2)(np.arange(8))\n",
|
|
"print(y)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "xgutf5JPP3wi",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"y"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "xxShG3Tdq4Gj",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"z = y / 2\n",
|
|
"print(z)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "uvDL2_bCq7kq",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"plt.plot(y)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "Xg76CmLYq_Q6",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"keys = random.split(random.PRNGKey(0), 8)\n",
|
|
"mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n",
|
|
"result = pmap(np.dot)(mats, mats)\n",
|
|
"print(pmap(np.mean)(result))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "jbw_hRx7rDzX",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"%timeit -n 5 -r 5 pmap(np.dot)(mats, mats).block_until_ready()"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xf5N9ZRirJhL",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"### Collective communication operations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "9i1PfxUvrThh",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from functools import partial\n",
|
|
"from jax.lax import psum\n",
|
|
"\n",
|
|
"@partial(pmap, axis_name='i')\n",
|
|
"def normalize(x):\n",
|
|
" return x / psum(x, 'i')\n",
|
|
"\n",
|
|
"print(normalize(np.arange(8.)))"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "lnvwnlOFrVa-",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"@partial(pmap, axis_name='rows')\n",
|
|
"@partial(pmap, axis_name='cols')\n",
|
|
"def f(x):\n",
|
|
" row_sum = psum(x, 'rows')\n",
|
|
" col_sum = psum(x, 'cols')\n",
|
|
" total_sum = psum(x, ('rows', 'cols'))\n",
|
|
" return row_sum, col_sum, total_sum\n",
|
|
"\n",
|
|
"x = np.arange(8.).reshape((4, 2))\n",
|
|
"a, b, c = f(x)\n",
|
|
"\n",
|
|
"print(\"input:\\n\", x)\n",
|
|
"print(\"row sum:\\n\", a)\n",
|
|
"print(\"col sum:\\n\", b)\n",
|
|
"print(\"total sum:\\n\", c)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "f-FBsWeo1AXE",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<img src=\"https://raw.githubusercontent.com/google/jax/master/cloud_tpu_colabs/images/nested_pmap.png\" width=\"70%\"/>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "jC-KIMQ1q-lK",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "-A-oVDo6rdWA",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"### Compose pmap with other transforms!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "WC_dMIN2rgTZ",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"@pmap\n",
|
|
"def f(x):\n",
|
|
"  y = np.sin(x)\n",
|
|
"  @pmap\n",
|
|
"  def g(z):\n",
|
|
"    return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()\n",
|
|
"  return grad(lambda w: np.sum(g(w)))(x)\n",
|
|
" \n",
|
|
"f(x)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "apuACjPWrixV",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"grad(lambda x: np.sum(f(x)))(x)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WD9xtROsYX4i",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"### Compose everything"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "h65c9AQCWAyn",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from jax import jvp, vjp # forward and reverse-mode\n",
|
|
"\n",
|
|
"curry = lambda f: partial(partial, f)\n",
|
|
"\n",
|
|
"@curry\n",
|
|
"def jacfwd(fun, x):\n",
|
|
" pushfwd = partial(jvp, fun, (x,)) # jvp!\n",
|
|
" std_basis = np.eye(onp.size(x)).reshape((-1,) + np.shape(x)),\n",
|
|
" y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis) # vmap!\n",
|
|
" return jac_flat.reshape(np.shape(y) + np.shape(x))\n",
|
|
"\n",
|
|
"@curry\n",
|
|
"def jacrev(fun, x):\n",
|
|
" y, pullback = vjp(fun, x) # vjp!\n",
|
|
" std_basis = np.eye(onp.size(y)).reshape((-1,) + np.shape(y))\n",
|
|
" jac_flat, = vmap(pullback)(std_basis) # vmap!\n",
|
|
" return jac_flat.reshape(np.shape(y) + np.shape(x))\n",
|
|
"\n",
|
|
"def hessian(fun):\n",
|
|
" return jit(jacfwd(jacrev(fun))) # jit!"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "G9qDX84RWhW7",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"input_hess = hessian(lambda inputs: loss(params, (inputs, targets)))\n",
|
|
"per_example_hess = pmap(input_hess) # pmap!\n",
|
|
"per_example_hess(inputs)"
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "u3ggM_WYZ8QC",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
""
|
|
],
|
|
"execution_count": 0,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |