{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(sharded-computation)=\n", "# Introduction to parallel programming\n", "\n", "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n", "\n", "The tutorial covers three modes of parallel computation:\n", "\n", "- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", "- _Semi-automated parallelism using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`_: You constrain the sharding of selected intermediate values and outputs, and the compiler handles the rest.\n", "- _Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives.\n", "\n", "Using these three approaches to SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n", "\n", "If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is a Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "outputId": "18905ae4-7b5e-4bb9-acb4-d8ab914cb456" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jax\n", "jax.devices()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Key concept: Data sharding\n", "\n", "Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n", "\n", "How does JAX keep track of how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {class}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. 
When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n", "\n", "In the simplest cases, arrays are sharded on a single device, as demonstrated below:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "outputId": "39fdbb79-d5c0-4ea6-8b20-88b2c502a27a" }, "outputs": [ { "data": { "text/plain": [ "{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jax.numpy as jnp\n", "arr = jnp.arange(32.0).reshape(4, 8)\n", "arr.devices()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "outputId": "536f773a-7ef4-4526-c58b-ab4d486bf5a1" }, "outputs": [ { "data": { "text/plain": [ "SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arr.sharding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array. For example, {func}`jax.debug.visualize_array_sharding` displays how the array is stored in memory of a single device:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "outputId": "74a793e9-b13b-4d07-d8ec-7e25c547036d" }, "outputs": [ { "data": { "text/html": [ "
\n", " \n", " \n", " \n", " \n", " TPU 0 \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jax.debug.visualize_array_sharding(arr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create an array with a non-trivial sharding, you can define a {mod}`jax.sharding` specification for the array and pass this to {func}`jax.device_put`.\n", "\n", "Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where {class}`jax.sharding.Mesh` allows for precise device placement:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "outputId": "0b397dba-3ddc-4aca-f002-2beab7e6b8a5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))\n" ] } ], "source": [ "from jax.sharding import PartitionSpec as P\n", "\n", "mesh = jax.make_mesh((2, 4), ('x', 'y'))\n", "sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n", "print(sharding)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Passing this `Sharding` object to {func}`jax.device_put`, you can obtain a sharded array:" ] }, { "cell_type": "code", "execution_count": 6, 
"metadata": { "outputId": "c8ceedba-05ca-4156-e6e4-1e98bb664a66" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", " [16. 17. 18. 19. 20. 21. 22. 23.]\n", " [24. 25. 26. 27. 28. 29. 30. 31.]]\n" ] }, { "data": { "text/html": [ "
\n", " \n", " TPU 0 TPU 1 TPU 2 TPU 3 \n", " \n", " \n", " \n", " \n", " \n", " TPU 6 TPU 7 TPU 4 TPU 5 \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m 
\u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "arr_sharded = jax.device_put(arr, sharding)\n", "\n", "print(arr_sharded)\n", "jax.debug.visualize_array_sharding(arr_sharded)" ] }, { "cell_type": "markdown", "metadata": { "id": "UEObolTqw4pp" }, "source": [ "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", "\n", "The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. 
This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place the array in host memory.\n", "\n", "To create a new sharding that differs from an existing sharding only in its memory kind, you can use the `with_memory_kind` method on the existing sharding." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aKNeOHTJnqmS", "outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pinned_host\n", "device\n" ] } ], "source": [ "s_host = jax.sharding.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n", "s_dev = s_host.with_memory_kind('device')\n", "arr_host = jax.device_put(arr, s_host)\n", "arr_dev = jax.device_put(arr, s_dev)\n", "print(arr_host.sharding.memory_kind)\n", "print(arr_dev.sharding.memory_kind)" ] }, { "cell_type": "markdown", "metadata": { "id": "jDHYnVqHwaST" }, "source": [ "## 1. Automatic parallelism via `jit`\n", "\n", "Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! 
In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile the inter-device communication.\n", "\n", "The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.\n", "In the simplest of cases, those heuristics boil down to *computation follows data*.\n", "\n", "To demonstrate how auto-parallelization works in JAX, the example below uses a {func}`jax.jit`-decorated function. It is a simple element-wise function, so the computation for each shard is performed on the device associated with that shard, and the output is sharded in the same way:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "outputId": "de46f86a-6907-49c8-f36c-ed835e78bc3d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shardings match: True\n" ] } ], "source": [ "@jax.jit\n", "def f_elementwise(x):\n", " return 2 * jnp.sin(x) + 1\n", "\n", "result = f_elementwise(arr_sharded)\n", "\n", "print(\"shardings match:\", result.sharding == arr_sharded.sharding)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.\n", "\n", "Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with {func}`jax.debug.visualize_array_sharding`):" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "outputId": "90c3b997-3653-4a7b-c8ff-12a270f11d02" }, "outputs": [ { "data": { "text/html": [ "
TPU 0,6 TPU 1,7 TPU 2,4 TPU 3,5 \n", " \n", "\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,6\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 1,7\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 2,4\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3,5\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[48. 52. 56. 60. 64. 68. 72. 76.]\n" ] } ], "source": [ "@jax.jit\n", "def f_contract(x):\n", " return x.sum(axis=0)\n", "\n", "result = f_contract(arr_sharded)\n", "jax.debug.visualize_array_sharding(result)\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "Q4N5mrr9i_ki" }, "source": [ "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second two on `1` and `7`, and so on.\n", "\n", "### 1.1 Sharding transformation between memory types\n", "\n", "The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. In particular, the `memory_kind` of the output can be different from that of the input array.\n", "\n", "#### Example 1: Pinned host to device memory\n", "\n", "In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PXu3MhafyRHo", "outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", " [16. 17. 18. 19. 20. 21. 22. 23.]\n", " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", "device\n" ] } ], "source": [ "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", "out_dev = f(arr_host)\n", "print(out_dev)\n", "print(out_dev.sharding.memory_kind)" ] }, { "cell_type": "markdown", "metadata": { "id": "LuYFqpcBySiX" }, "source": [ "#### Example 2: Device to pinned_host memory\n", "\n", "In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qLsgNlKfybRw", "outputId": "a16448b9-7e39-408f-b200-505f65ad4464" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", " [16. 17. 18. 19. 20. 21. 22. 23.]\n", " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", "pinned_host\n" ] } ], "source": [ "g = jax.jit(lambda x: x, out_shardings=s_host)\n", "out_host = g(arr_dev)\n", "print(out_host)\n", "print(out_host.sharding.memory_kind)" ] }, { "cell_type": "markdown", "metadata": { "id": "7BGD31-owaSU" }, "source": [ "## 2. Semi-automated sharding with constraints\n", "\n", "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. 
You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler distributes intermediate values and outputs.\n", "\n", "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially replicated, but rather to be fully sharded across the eight devices:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "outputId": "8468f5c6-76ca-4367-c9f2-93c723687cfd" }, "outputs": [ { "data": { "text/html": [ "
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5 \n", " \n", "\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[48. 52. 56. 60. 64. 68. 72. 
76.]\n" ] } ], "source": [ "@jax.jit\n", "def f_contract_2(x):\n", " out = x.sum(axis=0)\n", " # Shard the 1D result over both mesh axes, so each of the eight devices holds one element.\n", " sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y')))\n", " return jax.lax.with_sharding_constraint(out, sharding)\n", "\n", "result = f_contract_2(arr_sharded)\n", "jax.debug.visualize_array_sharding(result)\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This gives you a function with the particular output sharding you'd like.\n", "\n", "## 3. Manual parallelism with `shard_map`\n", "\n", "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", "\n", "`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:\n", "\n", "- As before, {class}`jax.sharding.Mesh` allows for precise device placement, and its axis names parameter assigns a name to each mesh axis.\n", "- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.\n", "\n", "**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "outputId": "435c32f3-557a-4676-c11b-17e6bab8c1e2" }, "outputs": [ { "data": { "text/plain": [ "Array([ 1. 
, 2.682942 , 2.818595 , 1.28224 , -0.513605 ,\n", " -0.9178486 , 0.44116896, 2.3139732 , 2.9787164 , 1.824237 ,\n", " -0.08804226, -0.99998045, -0.07314599, 1.8403342 , 2.9812148 ,\n", " 2.3005757 , 0.42419332, -0.92279506, -0.50197446, 1.2997544 ,\n", " 2.8258905 , 2.6733112 , 0.98229736, -0.69244075, -0.81115675,\n", " 0.7352965 , 2.525117 , 2.912752 , 1.5418116 , -0.32726777,\n", " -0.97606325, 0.19192469], dtype=float32)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jax.experimental.shard_map import shard_map\n", "mesh = jax.make_mesh((8,), ('x',))\n", "\n", "f_elementwise_sharded = shard_map(\n", " f_elementwise,\n", " mesh=mesh,\n", " in_specs=P('x'),\n", " out_specs=P('x'))\n", "\n", "arr = jnp.arange(32)\n", "f_elementwise_sharded(arr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function you write only \"sees\" a single shard of the data, which you can check by printing the device-local shape:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "outputId": "99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "global shape: x.shape=(32,)\n", "device local shape: x.shape=(4,)\n" ] } ], "source": [ "x = jnp.arange(32)\n", "print(f\"global shape: {x.shape=}\")\n", "\n", "def f(x):\n", " print(f\"device local shape: {x.shape=}\")\n", " return x * 2\n", "\n", "y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because your function only \"sees\" the device-local part of the data, aggregation-like functions require some extra thought.\n", "\n", "For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "outputId": "1e9a45f5-5418-4246-c75b-f9bc6dcbbe72" }, "outputs": [ { "data": { "text/plain": [ "Array([ 6, 22, 38, 54, 70, 86, 102, 118], 
dtype=int32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x):\n", " return jnp.sum(x, keepdims=True)\n", "\n", "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Your function `f` operates separately on each shard, and the resulting summation reflects this.\n", "\n", "If you want to sum across shards, you need to explicitly request it using collective operations like {func}`jax.lax.psum`:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "outputId": "4fd29e80-4fee-42b7-ff80-29f9887ab38d" }, "outputs": [ { "data": { "text/plain": [ "Array(496, dtype=int32)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x):\n", " sum_in_shard = x.sum()\n", " return jax.lax.psum(sum_in_shard, 'x')\n", "\n", "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`).\n", "\n", "## Comparing the three approaches\n", "\n", "With these concepts fresh in our mind, let's compare the three approaches for a simple neural network layer.\n", "\n", "Start by defining your canonical function like this:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "1TdhfTsoiqS1" }, "outputs": [], "source": [ "@jax.jit\n", "def layer(x, weights, bias):\n", " return jax.nn.sigmoid(x @ weights + bias)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "outputId": "f3007fe4-f6f3-454e-e7c5-3638de484c0a" }, "outputs": [ { "data": { "text/plain": [ "Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "rng = 
np.random.default_rng(0)\n", "\n", "x = rng.normal(size=(32,))\n", "weights = rng.normal(size=(32, 4))\n", "bias = rng.normal(size=(4,))\n", "\n", "layer(x, weights, bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n", "\n", "If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "outputId": "80be899e-8dbc-4bfc-acd2-0f3d554a0aa5" }, "outputs": [ { "data": { "text/plain": [ "Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mesh = jax.make_mesh((8,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", "\n", "x_sharded = jax.device_put(x, sharding)\n", "weights_sharded = jax.device_put(weights, sharding)\n", "\n", "layer(x_sharded, weights_sharded, bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "outputId": "bb63e8da-ff4f-4e95-f083-10584882daf4" }, "outputs": [ { "data": { "text/plain": [ "Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@jax.jit\n", "def layer_auto(x, weights, bias):\n", " x = jax.lax.with_sharding_constraint(x, sharding)\n", " weights = jax.lax.with_sharding_constraint(weights, sharding)\n", " return layer(x, weights, bias)\n", "\n", "layer_auto(x, weights, bias) # pass in unsharded inputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, you can do the same thing with `shard_map`, using 
{func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "outputId": "568d1c85-39a7-4dba-f09a-0e4f7c2ea918" }, "outputs": [ { "data": { "text/plain": [ "Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functools import partial\n", "\n", "@jax.jit\n", "@partial(shard_map, mesh=mesh,\n", " in_specs=(P('x'), P('x', None), P(None)),\n", " out_specs=P(None))\n", "def layer_sharded(x, weights, bias):\n", " return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)\n", "\n", "layer_sharded(x, weights, bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Next steps\n", "\n", "This tutorial serves as a brief introduction to sharded and parallel computation in JAX.\n", "\n", "To learn about each SPMD method in depth, check out these docs:\n", "- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n", "- {doc}`../notebooks/shard_map`" ] } ], "metadata": { "accelerator": "TPU", "colab": { "gpuType": "V28", "provenance": [], "toc_visible": true }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }