Merge pull request #18411 from 8bitmp3:add_kaggle

PiperOrigin-RevId: 579937639
jax authors 2023-11-06 13:20:03 -08:00
commit 5f4d4797b2

cloud_tpu_colabs/Pmap_Cookbook.ipynb

@@ -1,24 +1,9 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Pmap Cookbook",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "_4ware9HrjIk",
"colab_type": "text"
"id": "_4ware9HrjIk"
},
"source": [
"# Pmap CookBook"
@@ -27,35 +12,33 @@
{
"cell_type": "markdown",
"metadata": {
"id": "sk-3cPGIBTq8",
"colab_type": "text"
"id": "sk-3cPGIBTq8"
},
"source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n",
"\n",
"This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n",
"\n",
"To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine or a Cloud TPU.\n",
"**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Google Cloud TPU or a Kaggle TPU VM. The required features are not supported by the Google Colab TPU runtime at this time.\n",
"\n",
"The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Srs8W9F6Jo15",
"colab_type": "code",
"colab": {}
"id": "Srs8W9F6Jo15"
},
"outputs": [],
"source": [
"import jax.numpy as jnp"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hBasY8p1JFId",
"colab_type": "text"
"id": "hBasY8p1JFId"
},
"source": [
"## Basics"
@@ -64,8 +47,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "caPiPIWgM7-W",
"colab_type": "text"
"id": "caPiPIWgM7-W"
},
"source": [
"### Pure maps, with no communication"
@@ -74,8 +56,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "2e_06-OAJNyi",
"colab_type": "text"
"id": "2e_06-OAJNyi"
},
"source": [
"A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):"
@@ -83,36 +64,31 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6gGT77cIImcE",
"colab_type": "code",
"colab": {}
"id": "6gGT77cIImcE"
},
"outputs": [],
"source": [
"from jax import pmap"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-yY3lOFpJIUS",
"colab_type": "code",
"colab": {}
"id": "-yY3lOFpJIUS"
},
"outputs": [],
"source": [
"result = pmap(lambda x: x ** 2)(jnp.arange(7))\n",
"print(result)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PgKNzxKPNEYA",
"colab_type": "text"
"id": "PgKNzxKPNEYA"
},
"source": [
"In terms of what values are computed, `pmap` is similar to `vmap` in that it transforms a function to map over an array axis:"
@@ -120,11 +96,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mmCMQ64QbAbz",
"colab_type": "code",
"colab": {}
"id": "mmCMQ64QbAbz"
},
"outputs": [],
"source": [
"from jax import vmap\n",
"\n",
@@ -133,15 +109,12 @@
"\n",
"print(vmap(jnp.add)(x, y))\n",
"print(pmap(jnp.add)(x, y))"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iZgTmx5pFd6z",
"colab_type": "text"
"id": "iZgTmx5pFd6z"
},
"source": [
"But `pmap` and `vmap` differ in in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
@@ -149,11 +122,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4N1--GgGFe9d",
"colab_type": "code",
"colab": {}
"id": "4N1--GgGFe9d"
},
"outputs": [],
"source": [
"from jax import make_jaxpr\n",
"\n",
@@ -173,20 +146,17 @@
"\n",
"print(\"pmap(f) jaxpr\")\n",
"print(make_jaxpr(pmap(f))(xs, ys))"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BjDnQkzSa_vZ",
"colab_type": "text"
"id": "BjDnQkzSa_vZ"
},
"source": [
"source": [
"Notice that applying `vmap(f)` to these arguments leads to a `dot_general` to express the batch matrix multiplication in a single primitive, while applying `pmap(f)` instead leads to a primitive that calls replicas of the original `f` in parallel.\n",
"\n",
"An important constraint with using `pmap` is that ",
"An important constraint with using `pmap` is that \n",
"the mapped axis size must be less than or equal to the number of XLA devices available (and for nested `pmap` functions, the product of the mapped axis sizes must be less than or equal to the number of XLA devices).\n",
"\n",
"You can use the output of a `pmap` function just like any other value:"
@@ -194,38 +164,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "H4DXQWobOf7V",
"colab_type": "code",
"colab": {}
"id": "H4DXQWobOf7V"
},
"outputs": [],
"source": [
"y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
"z = y / 2\n",
"print(z)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fM1Une9Rfqld",
"colab_type": "code",
"colab": {}
"id": "fM1Une9Rfqld"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(y)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "644UB23YfbW4",
"colab_type": "text"
"id": "644UB23YfbW4"
},
"source": [
"But while the output here acts just like a NumPy ndarray, if you look closely it has a different type:"
@@ -233,22 +198,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "59hnyVOtfavX",
"colab_type": "code",
"colab": {}
"id": "59hnyVOtfavX"
},
"outputs": [],
"source": [
"y"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4brdSdeyf2MP",
"colab_type": "text"
"id": "4brdSdeyf2MP"
},
"source": [
"A sharded `Array` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
@@ -258,36 +220,31 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BSSllkblg9Rn",
"colab_type": "code",
"colab": {}
"id": "BSSllkblg9Rn"
},
"outputs": [],
"source": [
"y / 2"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "efyMSNGahq6f",
"colab_type": "code",
"colab": {}
"id": "efyMSNGahq6f"
},
"outputs": [],
"source": [
"import numpy as np\n",
"np.sin(y)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ba4jwfkbOwXW",
"colab_type": "text"
"id": "Ba4jwfkbOwXW"
},
"source": [
"Thinking about device memory is important to maximize performance by avoiding data transfers, but you can always fall back to treating arraylike values as (read-only) NumPy ndarrays and your code will still work.\n",
@@ -297,11 +254,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rWl68coLJSi7",
"colab_type": "code",
"colab": {}
"id": "rWl68coLJSi7"
},
"outputs": [],
"source": [
"from jax import random\n",
"\n",
@@ -311,44 +268,37 @@
"mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)\n",
"# the stack of matrices is represented logically as a single array\n",
"mats.shape"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nH2gGNgfNOJD",
"colab_type": "code",
"colab": {}
"id": "nH2gGNgfNOJD"
},
"outputs": [],
"source": [
"# run a local matmul on each device in parallel (no data transfer)\n",
"result = pmap(lambda x: jnp.dot(x, x.T))(mats)\n",
"result.shape"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MKTZ59iPNPi5",
"colab_type": "code",
"colab": {}
"id": "MKTZ59iPNPi5"
},
"outputs": [],
"source": [
"# compute the mean on each device in parallel and print the results\n",
"print(pmap(jnp.mean)(result))"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "26iH7sHFiz2l",
"colab_type": "text"
"id": "26iH7sHFiz2l"
},
"source": [
"In this example, the large matrices never had to be moved between devices or back to the host; only one scalar per device was pulled back to the host."
@@ -357,8 +307,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "MdRscR5MONuN",
"colab_type": "text"
"id": "MdRscR5MONuN"
},
"source": [
"### Collective communication operations"
@@ -367,8 +316,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "bFtajUwp5WYx",
"colab_type": "text"
"id": "bFtajUwp5WYx"
},
"source": [
"In addition to expressing pure maps, where no communication happens between the replicated functions, with `pmap` you can also use special collective communication operations.\n",
@@ -378,26 +326,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d5s8rJVUORQ3",
"colab_type": "code",
"colab": {}
"id": "d5s8rJVUORQ3"
},
"outputs": [],
"source": [
"from jax import lax\n",
"\n",
"normalize = lambda x: x / lax.psum(x, axis_name='i')\n",
"result = pmap(normalize, axis_name='i')(jnp.arange(4.))\n",
"print(result)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6jd9DVBQPD-Z",
"colab_type": "text"
"id": "6jd9DVBQPD-Z"
},
"source": [
"To use a collective operation like `lax.psum`, you need to supply an `axis_name` argument to `pmap`. The `axis_name` argument associates a name to the mapped axis so that collective operations can refer to it.\n",
@@ -407,11 +352,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c48qVvlkPF5p",
"colab_type": "code",
"colab": {}
"id": "c48qVvlkPF5p"
},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
@@ -420,15 +365,12 @@
" return x / lax.psum(x, 'i')\n",
"\n",
"print(normalize(jnp.arange(4.)))"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3Pr6n8KkOpmz",
"colab_type": "text"
"id": "3Pr6n8KkOpmz"
},
"source": [
"Axis names are also important for nested use of `pmap`, where collectives can be applied to distinct mapped axes:"
@@ -436,11 +378,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IwoeEd16OrD3",
"colab_type": "code",
"colab": {}
"id": "IwoeEd16OrD3"
},
"outputs": [],
"source": [
"@partial(pmap, axis_name='rows')\n",
"@partial(pmap, axis_name='cols')\n",
@@ -455,15 +397,12 @@
"\n",
"print(a)\n",
"print(a.sum(0))"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bnc-vlKA6hvI",
"colab_type": "text"
"id": "Bnc-vlKA6hvI"
},
"source": [
"When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n",
@@ -475,11 +414,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uazGbMwmf5zO",
"colab_type": "code",
"colab": {}
"id": "uazGbMwmf5zO"
},
"outputs": [],
"source": [
"from jax._src import xla_bridge\n",
"device_count = jax.device_count()\n",
@@ -517,15 +456,12 @@
"for _ in range(20):\n",
" reshaped_board = step(reshaped_board)\n",
" print_board(reshaped_board)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KrkEuY3yO7_M",
"colab_type": "text"
"id": "KrkEuY3yO7_M"
},
"source": [
"## Composing with differentiation"
@@ -534,8 +470,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "dGHE7dfypqqU",
"colab_type": "text"
"id": "dGHE7dfypqqU"
},
"source": [
"As with all things in JAX, you should expect `pmap` to compose with other transformations, including differentiation."
@@ -543,11 +478,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VkS7_RcTO_48",
"colab_type": "code",
"colab": {}
"id": "VkS7_RcTO_48"
},
"outputs": [],
"source": [
"from jax import grad\n",
"\n",
@@ -560,32 +495,41 @@
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
" \n",
"f(x)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4gAJ3QF6PBvi",
"colab_type": "code",
"colab": {}
"id": "4gAJ3QF6PBvi"
},
"outputs": [],
"source": [
"grad(lambda x: jnp.sum(f(x)))(x)"
],
"execution_count": 0,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8mAz9bEfPl2F",
"colab_type": "text"
"id": "8mAz9bEfPl2F"
},
"source": [
"When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the backward pass of the computation is parallelized just like the forward-pass."
]
}
]
],
"metadata": {
"accelerator": "TPU",
"colab": {
"collapsed_sections": [],
"name": "Pmap_Cookbook.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
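For readers without a multi-GPU machine, a Google Cloud TPU, or a Kaggle TPU VM, the notebook's core patterns can still be exercised by simulating several host devices. The following is a minimal sketch, assuming a CPU-only machine and an arbitrary count of 8 devices simulated through the XLA_FLAGS environment variable; it is illustrative only and is not part of the notebook diffed above.

    # Minimal sketch (assumption: CPU-only host, 8 simulated devices via XLA_FLAGS).
    # The XLA_FLAGS setting must happen before JAX is imported.
    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

    import jax
    import jax.numpy as jnp
    from jax import lax, pmap

    print(jax.device_count())  # 8 simulated CPU devices

    # Pure map: one replica of the function runs on each device.
    print(pmap(lambda x: x ** 2)(jnp.arange(8)))

    # Collective communication: normalize across the mapped axis with psum.
    normalize = lambda x: x / lax.psum(x, axis_name="i")
    print(pmap(normalize, axis_name="i")(jnp.arange(4.)))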