{ "cells": [
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ebUMqK9mGIDm" }, "source": [ "## The basics: interactive NumPy on GPU and TPU\n", "\n", "---\n", "\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "27TqNtiQF97X" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from jax import random" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "cRWoxSCNGU4o" }, "outputs": [], "source": [ "key = random.key(0)\n", "key, subkey = random.split(key)\n", "# Sample with the subkey; `key` is kept only for future splits.\n", "x = random.normal(subkey, (5000, 5000))\n", "\n", "print(x.shape)\n", "print(x.dtype)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "diPllsvgGfSA" }, "outputs": [], "source": [ "y = jnp.dot(x, x)\n", "print(y[0, 0])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "8-psauxnGiRk" }, "outputs": [], "source": [ "x" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "-2FMQ8UeoTJ8" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Trailing `;` suppresses the list-of-Line2D repr.\n", "plt.plot(x[0]);" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(jnp.dot(x, x.T))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "z4VX5PkMHJIu" }, "outputs": [], "source": [ "print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "ORZ9Odu85BCJ" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "x_cpu = np.array(x)\n", "%timeit -n 5 -r 2 np.dot(x_cpu, x_cpu)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "5BKh0eeAGvO5" }, "outputs": [],
"source": [ "%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "fm4Q2zpFHUAu" }, "source": [ "## Automatic differentiation" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "MCIQbyUYHWn1" }, "outputs": [], "source": [ "from jax import grad" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "kfqZpKYsHo4j" }, "outputs": [], "source": [ "def f(x):\n", "    # Piecewise: 2x^3 for x > 0, 3x otherwise.\n", "    return 2 * x ** 3 if x > 0 else 3 * x" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "K_26_odPHqLJ" }, "outputs": [], "source": [ "key = random.key(0)\n", "x = random.normal(key, ())\n", "\n", "# x and -x land on opposite branches of f.\n", "f_prime = grad(f)\n", "print(f_prime(x))\n", "print(f_prime(-x))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "q5V3A6loHrhS" }, "outputs": [], "source": [ "f_second = grad(grad(f))\n", "f_third = grad(f_second)\n", "\n", "print(f_second(-x))\n", "print(f_third(-x))" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bmxAPFC0I8b0" }, "source": [ "Other JAX autodiff highlights:\n", "\n", "* Forward- and reverse-mode, totally composable\n", "* Fast Jacobians and Hessians\n", "* Complex number support (holomorphic and non-holomorphic)\n", "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TRkxaVLJKNre" }, "source": [ "## End-to-end compilation with XLA using `jit`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "bKo4rX9-KSW7" }, "outputs": [], "source": [ "from jax import jit" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "94iIgZSfKWh8" }, "outputs": [], "source": [ "key = random.key(0)\n", "x = random.normal(key, (5000, 5000))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "Ybuz8Ag9KXMd" }, "outputs": [], "source": [ "def f(x):\n", "    y = x\n", "    for _ in range(10):\n", "        y = y - 0.1 * y + 3.\n", "    return y[:100, :100]\n", "\n", "f(x)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "Y9dx5ifSKaGJ" }, "outputs": [], "source": [ "g = jit(f)\n", "g(x)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "UtsS67BvKYkC" }, "outputs": [], "source": [ "%timeit f(x).block_until_ready()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "-vfcaSo9KbvR" }, "outputs": [], "source": [ "%timeit -n 100 g(x).block_until_ready()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "E3BQF1_AKeLn" }, "outputs": [], "source": [ "grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Tmf1NT2Wqv5p" }, "source": [ "## Parallelization over multiple accelerators with `pmap`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "t6RRAFn1CEln" }, "outputs": [], "source": [ "jax.device_count()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {},
"colab_type": "code", "id": "tEK1I6Duqunw" }, "outputs": [], "source": [ "from jax import pmap" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "S-iCNfeGqzkY" }, "outputs": [], "source": [ "# Size the mapped axis from the device count instead of hard-coding 8.\n", "y = pmap(lambda x: x ** 2)(jnp.arange(jax.device_count()))\n", "print(y)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "xgutf5JPP3wi" }, "outputs": [], "source": [ "y" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "uvDL2_bCq7kq" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(y);" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xf5N9ZRirJhL" }, "source": [ "### Collective communication operations" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "from jax.lax import psum\n", "\n", "@partial(pmap, axis_name='i')\n", "def normalize(x):\n", "    # psum adds up x across all devices mapped over axis 'i'.\n", "    total = psum(x, 'i')\n", "    return x / total, total\n", "\n", "# One float32 element per device; with a single device the total is 0, so the result is nan.\n", "normalized, total = normalize(jnp.arange(jax.device_count(), dtype=jnp.float32))\n", "\n", "print(f\"normalized:\\n{normalized}\\n\")\n", "print(\"total:\", total)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jC-KIMQ1q-lK" }, "source": [ "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Automatic parallelization with `jit` and shardings\n", "\n", "The image below is partitioned over a 4x2 device mesh built from the first 8 devices, so this section needs at least 8 devices." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from jax import lax\n", "\n", "conv = lambda image, kernel: lax.conv(image, kernel, (1, 1), 'SAME')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "image = jnp.ones((1, 8, 2000, 1000)).astype(np.float32)\n", "# Seeded JAX PRNG (rather than NumPy's global RNG) keeps the kernel reproducible.\n", "kernel = random.uniform(random.key(0), (8, 8, 5, 5), dtype=jnp.float32)\n", "\n", "np.set_printoptions(edgeitems=1)\n", "conv(image, kernel)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit conv(image, kernel).block_until_ready()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# `sharded_jit` is no longer part of JAX; `jax.jit` with shardings replaces it.\n", "# P(None, None, 'x', 'y') splits dim 2 four ways and dim 3 two ways, like the old P(1, 1, 4, 2).\n", "mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))\n", "image_sharding = NamedSharding(mesh, P(None, None, 'x', 'y'))\n", "replicated = NamedSharding(mesh, P())\n", "\n", "sharded_conv = jax.jit(conv,\n", "                       in_shardings=(image_sharding, replicated),\n", "                       out_shardings=image_sharding)\n", "\n", "# Place the inputs once so the timing below measures only the sharded conv.\n", "sharded_image = jax.device_put(image, image_sharding)\n", "sharded_kernel = jax.device_put(kernel, replicated)\n", "\n", "sharded_conv(sharded_image, sharded_kernel)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit -n 10 sharded_conv(sharded_image, sharded_kernel).block_until_ready()" ] }
], "metadata": { "accelerator": "TPU", "colab": { "collapsed_sections": [ "AvXl1WDPKjmV" ], "name": "JAX demo.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }