{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ebUMqK9mGIDm" }, "source": [ "## The basics: interactive NumPy on GPU and TPU\n", "\n", "---\n", "\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "27TqNtiQF97X" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from jax import random\n", "\n", "key = random.key(0)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "cRWoxSCNGU4o" }, "outputs": [], "source": [ "# Draw from the fresh subkey; `key` stays unconsumed for any later splits.\n", "key, subkey = random.split(key)\n", "x = random.normal(subkey, (5000, 5000))\n", "\n", "print(x.shape)\n", "print(x.dtype)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "diPllsvgGfSA" }, "outputs": [], "source": [ "y = jnp.dot(x, x)\n", "print(y[0, 0])" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "8-psauxnGiRk" }, "outputs": [], "source": [ "x" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "-2FMQ8UeoTJ8" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.plot(x[0])" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "DRnwCKFuGk8P" }, "outputs": [], "source": [ "jnp.dot(x, x.T)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "z4VX5PkMHJIu" }, "outputs": [], "source": [ "print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "ORZ9Odu85BCJ" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "x_cpu = np.array(x)\n", "%timeit -n 1 -r 1 np.dot(x_cpu, x_cpu)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id":
"5BKh0eeAGvO5" }, "outputs": [], "source": [ "%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "fm4Q2zpFHUAu" }, "source": [ "## Automatic differentiation" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "MCIQbyUYHWn1" }, "outputs": [], "source": [ "from jax import grad" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "kfqZpKYsHo4j" }, "outputs": [], "source": [ "def f(x):\n", " if x > 0:\n", " return 2 * x ** 3\n", " else:\n", " return 3 * x" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "K_26_odPHqLJ" }, "outputs": [], "source": [ "key = random.key(0)\n", "x = random.normal(key, ())\n", "\n", "print(grad(f)(x))\n", "print(grad(f)(-x))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "q5V3A6loHrhS" }, "outputs": [], "source": [ "print(grad(grad(f))(-x))\n", "print(grad(grad(grad(f)))(-x))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "ba4WY4ArHv8I" }, "outputs": [], "source": [ "def predict(params, inputs):\n", " for W, b in params:\n", " outputs = jnp.dot(inputs, W) + b\n", " inputs = jnp.tanh(outputs) # inputs to the next layer\n", " return outputs # no activation on last layer\n", "\n", "def loss(params, batch):\n", " inputs, targets = batch\n", " predictions = predict(params, inputs)\n", " return jnp.sum((predictions - targets)**2)\n", "\n", "\n", "\n", "def init_layer(key, n_in, n_out):\n", " k1, k2 = random.split(key)\n", " W = random.normal(k1, (n_in, n_out))\n", " b = random.normal(k2, (n_out,))\n", " return W, b\n", "\n", "layer_sizes = [5, 2, 3]\n", "\n", "key = random.key(0)\n", "key, *keys = random.split(key, len(layer_sizes))\n", "params = list(map(init_layer, keys, layer_sizes[:-1], 
layer_sizes[1:]))\n", "\n", "key, *keys = random.split(key, 3)\n", "inputs = random.normal(keys[0], (8, 5))\n", "targets = random.normal(keys[1], (8, 3))\n", "batch = (inputs, targets)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "LiTBibJdHz4K" }, "outputs": [], "source": [ "print(loss(params, batch))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "a3KFpwH3H4Cl" }, "outputs": [], "source": [ "step_size = 1e-2\n", "\n", "for _ in range(20):\n", " grads = grad(loss)(params, batch)\n", " params = [(W - step_size * dW, b - step_size * db)\n", " for (W, b), (dW, db) in zip(params, grads)]" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "YLltDr0GH7LX" }, "outputs": [], "source": [ "print(loss(params, batch))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bmxAPFC0I8b0" }, "source": [ "Other JAX autodiff highlights:\n", "\n", "* Forward- and reverse-mode, totally composable\n", "* Fast Jacobians and Hessians\n", "* Complex number support (holomorphic and non-holomorphic)\n", "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TRkxaVLJKNre" }, "source": [ "## End-to-end compilation with XLA using `jit`" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "bKo4rX9-KSW7" }, "outputs": [], "source": [ "from jax import jit" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "94iIgZSfKWh8" }, "outputs": [], "source": [ "key = random.key(0)\n", "x = random.normal(key, (5000, 5000))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Ybuz8Ag9KXMd" }, "outputs": [], "source": [ "def f(x):\n", " y = x\n", " for _ in range(10):\n", " y = y - 0.1 * y + 3.\n", " return y[:100, :100]\n", "\n", "f(x)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Y9dx5ifSKaGJ" }, "outputs": [], "source": [ "g = jit(f)\n", "g(x)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "UtsS67BvKYkC" }, "outputs": [], "source": [ "%timeit f(x).block_until_ready()" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "-vfcaSo9KbvR" }, "outputs": [], "source": [ "%timeit g(x).block_until_ready()" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "E3BQF1_AKeLn" }, "outputs": [], "source": [ "grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AvXl1WDPKjmV" }, "source": [ "### Constraints that come with using `jit`" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "mCtwRF18KnsE" }, "outputs": [], "source": [ "def f(x):\n", " if x > 0:\n", " return 2 * x ** 2\n", " else:\n", " return 3 * x\n", "\n", "g = jit(f)" ] }, { "cell_type": "code", "execution_count": 0, 
"metadata": { "colab": {}, "colab_type": "code", "id": "_82tY-ZSKqv4" }, "outputs": [], "source": [ "f(2)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "TjSAFc-iKrcB" }, "outputs": [], "source": [ "try:\n", " g(2)\n", "except Exception as e:\n", " print(e)\n", " pass" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "RhizP9pjKsug" }, "outputs": [], "source": [ "def f(x, n):\n", " i = 0\n", " while i < n:\n", " x = x * x\n", " i += 1\n", " return x\n", "\n", "g = jit(f)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Wn6haTmUK-Q8" }, "outputs": [], "source": [ "f(jnp.array([1., 2., 3.]), 5)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "HwBy1I04K-81" }, "outputs": [], "source": [ "try:\n", " g(jnp.array([1., 2., 3.]), 5)\n", "except Exception as e:\n", " print(e)\n", " pass" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "XmaTryZaK_3M" }, "outputs": [], "source": [ "g = jit(f, static_argnums=(1,))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "HcWjxVktV4fa" }, "outputs": [], "source": [ "g(jnp.array([1., 2., 3.]), 5)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0M_-pJe7LOcO" }, "source": [ "## Vectorization with `vmap`" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "8XIot_ndLRH1" }, "outputs": [], "source": [ "from jax import vmap" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "tRvCZn2wBkXP" }, "outputs": [], "source": [ "print(vmap(lambda x: x**2)(jnp.arange(8)))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": 
"icfsXizI_rkD" }, "outputs": [], "source": [ "from jax import make_jaxpr\n", "\n", "make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "uQm4cvAbA6M3" }, "outputs": [], "source": [ "make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "NeiFfCHEBLsU" }, "outputs": [], "source": [ "make_jaxpr(vmap(vmap(jnp.dot)))(jnp.ones((10, 10, 8)), jnp.ones((10, 10, 8)))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "csX71fkSCZrp" }, "outputs": [], "source": [ "perex_grads = vmap(grad(loss), in_axes=(None, 0))\n", "make_jaxpr(perex_grads)(params, batch)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Tmf1NT2Wqv5p" }, "source": [ "## Parallel accelerators with pmap" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "t6RRAFn1CEln" }, "outputs": [], "source": [ "jax.devices()" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "tEK1I6Duqunw" }, "outputs": [], "source": [ "from jax import pmap" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "S-iCNfeGqzkY" }, "outputs": [], "source": [ "y = pmap(lambda x: x ** 2)(jnp.arange(8))\n", "print(y)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "xgutf5JPP3wi" }, "outputs": [], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "xxShG3Tdq4Gj" }, "outputs": [], "source": [ "z = y / 2\n", "print(z)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "uvDL2_bCq7kq" }, "outputs": [], "source": [ "import 
matplotlib.pyplot as plt\n", "plt.plot(y)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Xg76CmLYq_Q6" }, "outputs": [], "source": [ "keys = random.split(random.key(0), 8)\n", "mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n", "result = pmap(jnp.dot)(mats, mats)\n", "print(pmap(jnp.mean)(result))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "jbw_hRx7rDzX" }, "outputs": [], "source": [ "# Explicit line magic (no reliance on automagic); block_until_ready() waits for the result.\n", "%timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xf5N9ZRirJhL" }, "source": [ "### Collective communication operations" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "9i1PfxUvrThh" }, "outputs": [], "source": [ "from functools import partial\n", "from jax.lax import psum\n", "\n", "@partial(pmap, axis_name='i')\n", "def normalize(x):\n", " return x / psum(x, 'i')\n", "\n", "print(normalize(jnp.arange(8.)))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "lnvwnlOFrVa-" }, "outputs": [], "source": [ "@partial(pmap, axis_name='rows')\n", "@partial(pmap, axis_name='cols')\n", "def f(x):\n", " row_sum = psum(x, 'rows')\n", " col_sum = psum(x, 'cols')\n", " total_sum = psum(x, ('rows', 'cols'))\n", " return row_sum, col_sum, total_sum\n", "\n", "x = jnp.arange(8.).reshape((4, 2))\n", "a, b, c = f(x)\n", "\n", "print(\"input:\\n\", x)\n", "print(\"row sum:\\n\", a)\n", "print(\"col sum:\\n\", b)\n", "print(\"total sum:\\n\", c)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "f-FBsWeo1AXE" }, "source": [ "<img src=\"https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/nested_pmap.png\" width=\"70%\"/>" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jC-KIMQ1q-lK" }, "source": [
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-A-oVDo6rdWA" }, "source": [ "### Compose pmap with other transforms!" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "WC_dMIN2rgTZ" }, "outputs": [], "source": [ "@pmap\n", "def f(x):\n", " y = jnp.sin(x)\n", " @pmap\n", " def g(z):\n", " return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n", " return grad(lambda w: jnp.sum(g(w)))(x)\n", "\n", "f(x)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "apuACjPWrixV" }, "outputs": [], "source": [ "grad(lambda x: jnp.sum(f(x)))(x)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WD9xtROsYX4i" }, "source": [ "### Compose everything" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "h65c9AQCWAyn" }, "outputs": [], "source": [ "from jax import jvp, vjp # forward and reverse-mode\n", "\n", "curry = lambda f: partial(partial, f)\n", "\n", "@curry\n", "def jacfwd(fun, x):\n", " pushfwd = partial(jvp, fun, (x,)) # jvp!\n", " std_basis = jnp.eye(np.size(x)).reshape((-1,) + jnp.shape(x)),\n", " y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis) # vmap!\n", " return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n", "\n", "@curry\n", "def jacrev(fun, x):\n", " y, pullback = vjp(fun, x) # vjp!\n", " std_basis = jnp.eye(np.size(y)).reshape((-1,) + jnp.shape(y))\n", " jac_flat, = vmap(pullback)(std_basis) # vmap!\n", " return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n", "\n", "def hessian(fun):\n", " return jit(jacfwd(jacrev(fun))) # jit!" 
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "G9qDX84RWhW7" }, "outputs": [], "source": [ "input_hess = hessian(lambda inputs: loss(params, (inputs, targets)))\n", "per_example_hess = pmap(input_hess) # pmap!\n", "per_example_hess(inputs)" ] } ], "metadata": { "accelerator": "TPU", "colab": { "collapsed_sections": [ "AvXl1WDPKjmV" ], "name": "JAX demo.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 1 }