From e8f43d1cef048e7ac95d5a9648913bb8f043d63d Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 13 Mar 2025 11:54:52 -0400 Subject: [PATCH] Explicit sharding docs --- docs/advanced_guide.rst | 1 + docs/conf.py | 1 + docs/notebooks/explicit-sharding.ipynb | 713 +++++++++++++++++++++++++ docs/notebooks/explicit-sharding.md | 417 +++++++++++++++ 4 files changed, 1132 insertions(+) create mode 100644 docs/notebooks/explicit-sharding.ipynb create mode 100644 docs/notebooks/explicit-sharding.md diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index 85ed315c9..db2e83ae2 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -12,6 +12,7 @@ operations. :maxdepth: 1 notebooks/Distributed_arrays_and_automatic_parallelization + notebooks/explicit-sharding notebooks/shard_map multi_process distributed_data_loading diff --git a/docs/conf.py b/docs/conf.py index bbc417af1..45964b6d8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -222,6 +222,7 @@ nb_execution_excludepatterns = [ 'jep/9407-type-promotion.*', # TODO(jakevdp): enable execution on the following if possible: 'notebooks/Distributed_arrays_and_automatic_parallelization.*', + 'notebooks/explicit-sharding.*', 'notebooks/autodiff_remat.*', # Fails on readthedocs with Kernel Died 'notebooks/convolutions.ipynb', diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb new file mode 100644 index 000000000..44d46bddb --- /dev/null +++ b/docs/notebooks/explicit-sharding.ipynb @@ -0,0 +1,713 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ZVJCNxUcVkkm" + }, + "source": [ + "# Explicit sharding (a.k.a. \"sharding in types\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ATLBMlw3VcCJ" + }, + "source": [ + "JAX's traditional automatic sharding leaves sharding decisions to the compiler.\n", + "You can provide hints to the compiler using\n", + "`jax.lax.with_sharding_constraint` but for the most part you're supposed to be\n", + "focussed on the math while the compiler worries about sharding.\n", + "\n", + "But what if you have a strong opinion about how you want your program sharded?\n", + "With enough calls to `with_sharding_constraint` you can probably guide the\n", + "compiler's hand to make it do what you want. But \"compiler tickling\" is\n", + "famously not a fun programming model. Where should you put the sharding\n", + "constraints? You could put them on every single intermediate but that's a lot\n", + "of work and it's also easy to make mistakes that way because there's no way to\n", + "check that the shardings make sense together. More commonly, people add just\n", + "enough sharding annotations to constrain the compiler. But this is a slow\n", + "iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n", + "do (it's a whole-program optimization) so all you can do is add annotations,\n", + "inspect XLA's sharding choices to see what happened, and repeat.\n", + "\n", + "To fix this we've come up with a different style of sharding programming we\n", + "call \"explicit sharding\" or \"sharding in types\". The idea is that sharding\n", + "propagation happens at the JAX level at trace time. Each JAX operation has a\n", + "sharding rule that takes the shardings of the op's arguments and produces a\n", + "sharding for the op's result. For most operations these rules are simple and\n", + "obvious because there's only one reasonable choice. But for some operations it's\n", + "unclear how to shard the result. In that case we ask the programmer\n", + "to provide an `out_sharding` argument explicitly and we throw a (trace-time)\n", + "error otherwise. Since the shardings are propagated at trace time they can\n", + "also be _queried_ at trace time too. In the rest of this doc we'll describe\n", + "how to use explicit sharding mode. Note that this is a new feature so we\n", + "expect there to be bugs and unimplemented cases. Please let us know when you\n", + "find something that doesn't work!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hVi6mApuVw3r", + "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf" + }, + "outputs": [], + "source": [ + "import jax\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "from jax.sharding import PartitionSpec as P, AxisType, set_mesh\n", + "from jax.experimental.shard import reshard, auto_axes\n", + "from jax._src.mesh import get_abstract_mesh\n", + "\n", + "jax.config.update('jax_num_cpu_devices', 8)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oU5O6yOLWqbP" + }, + "source": [ + "## Setting up an explicit mesh\n", + "\n", + "The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n", + "the JAX-level _type_ of a value includes a description of how the value is sharded.\n", + "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n", + "scalar) using `jax.typeof`:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mzDIDvj7Vw0k", + "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX-level type of some_array: ShapedArray(int32[8])\n" + ] + } + ], + "source": [ + "some_array = np.arange(8)\n", + "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TZzp_1sXW061" + }, + "source": [ + "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n", + "is almost _defined_ as \"the information about a value we have access to while\n", + "under a jit)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IyPx_-IBVwxr", + "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX-level type of x during tracing: ShapedArray(int32[8])\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@jax.jit\n", + "def foo(x):\n", + " print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n", + " return x + x\n", + "\n", + "foo(some_array)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c3gNPzfZW45K" + }, + "source": [ + "These types show the shape and dtype of array but they don't appear to\n", + "show sharding. (Actually, they _did_ show sharding, but the shardings were\n", + "trivial. See \"Concrete array shardings\", below.) To start seeing some\n", + "interesting shardings we need to set up an explicit-sharding mesh. We use\n", + "`set_mesh` to set it as the current mesh for the remainder of this notebook.\n", + "(If you only want to set the mesh for some particular scope and return to the previous\n", + "mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NO2ulM_QW7a8", + "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n" + ] + } + ], + "source": [ + "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))\n", + "set_mesh(mesh)\n", + "\n", + "print(f\"Current mesh is: {get_abstract_mesh()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V7bVz6tzW_Eb" + }, + "source": [ + "Now we can create some sharded arrays using `reshard`:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1-TzmA0AXCAf", + "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "replicated_array type: ShapedArray(int32[4,2])\n", + "sharded_array type: ShapedArray(int32[4@X,2])\n" + ] + } + ], + "source": [ + "replicated_array = np.arange(8).reshape(4, 2)\n", + "sharded_array = reshard(replicated_array, P(\"X\", None))\n", + "\n", + "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n", + "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B0jBBXtgXBxr" + }, + "source": [ + "We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n", + "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", + "axes\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N8yMauHAXKtX" + }, + "source": [ + "These shardings associated with JAX-level types propagate through operations. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Gy7ABds3XND3", + "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "arg0 sharding: ShapedArray(int32[4@X,1])\n", + "arg1 sharding: ShapedArray(int32[1,8@Y])\n", + "result sharding: ShapedArray(int32[4@X,8@Y])\n" + ] + } + ], + "source": [ + "arg0 = reshard(np.arange(4).reshape(4, 1), P(\"X\", None))\n", + "arg1 = reshard(np.arange(8).reshape(1, 8), P(None, \"Y\"))\n", + "\n", + "result = arg0 + arg1\n", + "\n", + "print(f\"arg0 sharding: {jax.typeof(arg0)}\")\n", + "print(f\"arg1 sharding: {jax.typeof(arg1)}\")\n", + "print(f\"result sharding: {jax.typeof(result)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lwsygUmVXPCk" + }, + "source": [ + "We can do the same type querying under a jit:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "grCcotr-XQjY", + "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x sharding: ShapedArray(int32[4@X,1])\n", + "y sharding: ShapedArray(int32[1,8@Y])\n", + "ans sharding: ShapedArray(int32[4@X,8@Y])\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([[ 0, 1, 2, 3, 4, 5, 6, 7],\n", + " [ 1, 2, 3, 4, 5, 6, 7, 8],\n", + " [ 2, 3, 4, 5, 6, 7, 8, 9],\n", + " [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@jax.jit\n", + "def add_arrays(x, y):\n", + " ans = x + y\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"y sharding: {jax.typeof(y)}\")\n", + " print(f\"ans sharding: {jax.typeof(ans)}\")\n", + " return ans\n", + "\n", + "add_arrays(arg0, arg1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lVd6a5ufXZoH" + }, + "source": [ + "That's the gist of it. Shardings propagate deterministically at trace time and\n", + "we can query them at trace time." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ETtwK3LCXSkd" + }, + "source": [ + "## Sharding rules and operations with ambiguous sharding\n", + "\n", + "Each op has a sharding rule which specifies its output sharding given its\n", + "input shardings. A sharding rule may also throw a (trace-time) error. Each op\n", + "is free to implement whatever sharding rule it likes, but the usual pattern is\n", + "the following: For each output axis we identify zero of more corresponding\n", + "input axes. The output axis is then\n", + "sharded according to the “consensus” sharding of the corresponding input axes. i.e., it's\n", + "`None` if the input shardings are all `None`, and it's the common non-None input sharding\n", + "if there’s exactly one of them, or an error (requiring an explicit out_sharding=... kwarg) otherwise." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "an8-Fq1uXehp" + }, + "source": [ + "This procedure is done on an axis-by-axis basis. When it’s done, we might end\n", + "up with an array sharding that mentions a mesh axis more than once, which is\n", + "illegal. In that case we raise a (trace-time) sharding error and ask for an\n", + "explicit out_sharding.\n", + "\n", + "Here are some example sharding rules:\n", + " * nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole\n", + " cloth so they don’t have input shardings to propagate. Their output is\n", + " unsharded by default unless overridden by the out_sharding kwarg.\n", + " * unary elementwise ops like `sin`, `exp`: The output is sharded the same as the\n", + " input.\n", + " * binary ops (`+`, `-`, `*` etc.): Axis shardings of “zipped” dimensions\n", + " must match (or be `None`). “Outer product” dimensions (dimensions that\n", + " appear in only one argument) are sharded as they are in the input. If the\n", + " result ends up mentioning a mesh axis more than once it's an error.\n", + " * `reshape.` Reshape is a particularly tricky op. An output axis can map to more\n", + " than one input axis (when reshape is used to merge axes) or just a part\n", + " of an input axis (when reshape is used to split axes). Our usual rules\n", + " don’t apply. Instead we treat reshape as follows. We strip away singleton\n", + " axes (these can’t be sharded anyway. Then\n", + " we decide whether the reshape is a “split” (splitting a single axis into\n", + " two or more adjacent axes), a “merge” (merging two or more adjacent axes\n", + " into a single one) or something else. If we have a split or merge case in\n", + " which the split/merged axes are sharded as None then we shard the\n", + " resulting split/merged axes as None and the other axes according to their\n", + " corresponding input axis shardings. In all other cases we throw an error\n", + " and require the user to provide an `out_shardings` argument." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZMp6w48Xmd7" + }, + "source": [ + "## JAX transformations and higher-order functions\n", + "\n", + "The staged-out representation of JAX programs is explicitly typed. (We call\n", + "the types “avals” but that’s not important.) In explicit-sharding mode, the\n", + "sharding is part of that type. This means that shardings need to match\n", + "wherever types need to match. For example, the two sides of a `lax.cond` need to\n", + "have results with matching shardings. And the carry of `lax.scan` needs to have the\n", + "same sharding at the input and the output of the scan body. And when you\n", + "contruct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", + "provide shardings too. Certain JAX transformations perform type-level\n", + "operations. Automatic differentation constructs a tangent type for each primal\n", + "type in the original computation (e.g. `TangentOf(float) == float`,\n", + "`TangentOf(int) == float0`). With sharding in the types, this means that tangent\n", + "values are sharded in the same way as their primal values. Vmap and scan also\n", + "do type-level operations, they lift an array shape to a rank-augmented version\n", + "of that shape. That extra array axis needs a sharding. We can infer it from the\n", + "arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan\n", + "needs an explicit sharding argument just as it needs an explicit length\n", + "argument." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ERJx4p0tXoS3" + }, + "source": [ + "## Working around unimplemented sharding rules using `auto_sharding`\n", + "\n", + "The implementation of explicit sharding is still a work-in-progress and there\n", + "are plenty of ops that are missing sharding rules. For example, `scatter` and\n", + "`gather` (i.e. indexing ops).\n", + "\n", + "Normally we wouldn't suggest using a feature with so many unimplemented cases,\n", + "but in this instance there's a reasonable fallback you can use: `auto_axes`.\n", + "The idea is that you can temporarily drop into a context where the mesh axes\n", + "are \"auto\" rather than \"explicit\". You explicitly specify how you intend the\n", + "final result of the `auto_axes` to be sharded as it gets returned to the calling context.\n", + "\n", + "This works as a fallback for ops with unimplemented sharding rules. It also\n", + "works when you want to override the sharding-in-types type system. For\n", + "example, suppose we want to add a `f32[4@X, 4]` to a `f32[4, 4@X]`. Our\n", + "sharding rule for addition would throw an error: the result would need to be\n", + "`f32[4@X, 4@X]`, which tries uses a mesh axis twice, which is illegal. But say you\n", + "want to perform the operation anyway, and you want the result to be sharded along\n", + "the first axis only, like `f32[4@X, 4]`. You can do this as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fpFEaMBcXsJG", + "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ERROR!\n", + "add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]\n", + "=== try again with auto_axes ===\n", + "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n", + "Result type: ShapedArray(int32[4@X,4])\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result type: ShapedArray(int32[4@X,4])\n" + ] + } + ], + "source": [ + "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", None))\n", + "some_y = reshard(np.arange(16).reshape(4, 4), P(None, \"X\"))\n", + "\n", + "try:\n", + " some_x + some_y\n", + "except Exception as e:\n", + " print(\"ERROR!\")\n", + " print(e)\n", + "\n", + "print(\"=== try again with auto_axes ===\")\n", + "\n", + "@auto_axes\n", + "def add_with_out_sharding_kwarg(x, y):\n", + " print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n", + " return x + y\n", + "\n", + "result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P(\"X\", None))\n", + "print(f\"Result type: {jax.typeof(result)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8-_zDr-AXvb6" + }, + "source": [ + "## Using a mixture of sharding modes\n", + "\n", + "JAX now has three styles of parallelism:\n", + "\n", + " * *Automatic sharding* is where you treat all the devices as a single logical\n", + " machine and write a \"global view\" array program for that machine. The\n", + " compiler decides how to partition the data and computation across the\n", + " available devices. You can give hints to the compiler using\n", + " `with_sharding_constraint`.\n", + " * *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n", + " you're writing a global-view program. The difference is that the sharding\n", + " of each array is part of the array's JAX-level type making it an explicit\n", + " part of the programming model. These shardings are propagated at the JAX\n", + " level and queryable at trace time. It's still the compiler's responsibility\n", + " to turn the whole-array program into per-device programs (turning `jnp.sum`\n", + " into `psum` for example) but the compiler is heavily constrained by the\n", + " user-supplied shardings.\n", + " * *Manual Sharding* (`shard_map`) is where you write a program from the\n", + " perspective of a single device. Communication between devices happens via\n", + " explicit collective operations like psum.\n", + "\n", + "A summary table:\n", + "\n", + "| Mode | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|\n", + "| Auto | No | No |\n", + "| Explicit (new) | Yes | No |\n", + "| Manual | Yes | Yes |\n", + "\n", + "The current mesh tells us which sharding mode we're in. We can query it with\n", + "`get_abstract_mesh`:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "geptWrdYX0OM", + "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n" + ] + } + ], + "source": [ + "print(f\"Current mesh is: {get_abstract_mesh()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AQQjzUeGX4P6" + }, + "source": [ + "Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit\n", + "mode. Notice that the sharding mode is associated with a mesh _axis_, not the\n", + "mesh as a whole. We can actually mix sharding modes by having a different\n", + "sharding mode for each mesh axis. Shardings (on JAX-level types) can only\n", + "mention _explicit_ mesh axes and collective operations like `psum` can only\n", + "mention _manual_ mesh axes." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AQQjzUeGX4P6" + }, + "source": [ + "## Concrete array shardings can mention `Auto` mesh axis\n", + "\n", + "You can query the sharding of a concrete array `x` with `x.sharding`. You\n", + "might expect the result to be the same as the sharding associated with the\n", + "value's type, `jax.typeof(x).sharding`. It might not be! The concrete array sharding, `x.sharding`, describes the sharding along\n", + "both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler\n", + "eventually chose. Whereas the type-specificed sharding,\n", + "`jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh\n", + "axes. The `Auto` axes are deliberately hidden from the type because they're\n", + "the purview of the compiler. We can think of the concrete array sharding being consistent with, but more specific than,\n", + "the type-specified sharding. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ivLl6bxmX7EZ", + "outputId": "6d7b7fce-68b6-47f1-b214-d62bda8d7b6e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit)) ===\n", + "Concrete value sharding: PartitionSpec('X',)\n", + "Type-specified sharding: PartitionSpec('X',)\n", + "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto)) ===\n", + "Concrete value sharding: PartitionSpec('X',)\n", + "Type-specified sharding: PartitionSpec(None,)\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n", + " -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n", + " -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def compare_shardings(x):\n", + " print(f\"=== with mesh: {get_abstract_mesh()} ===\")\n", + " print(f\"Concrete value sharding: {x.sharding.spec}\")\n", + " print(f\"Type-specified sharding: {jax.typeof(x).sharding.spec}\")\n", + "\n", + "my_array = jnp.sin(reshard(np.arange(8), P(\"X\")))\n", + "compare_shardings(my_array)\n", + "\n", + "@auto_axes\n", + "def check_in_auto_context(x):\n", + " compare_shardings(x)\n", + " return x\n", + "\n", + "check_in_auto_context(my_array, out_shardings=P(\"X\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MRFccsi5X8so" + }, + "source": [ + "Notice that at the top level, where we're currently in a fully `Explicit` mesh\n", + "context, the concrete array sharding and type-specified sharding agree. But\n", + "under the `auto_axes` decorator we're in a fully `Auto` mesh context and the\n", + "two shardings disagree: the type-specified sharding is `P(None)` whereas the\n", + "concrete array sharding is `P(\"X\")` (though it could be anything! It's up to\n", + "the compiler)." + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md new file mode 100644 index 000000000..cd1a580bd --- /dev/null +++ b/docs/notebooks/explicit-sharding.md @@ -0,0 +1,417 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + ++++ {"id": "ZVJCNxUcVkkm"} + +# Explicit sharding (a.k.a. "sharding in types") + ++++ {"id": "ATLBMlw3VcCJ"} + +JAX's traditional automatic sharding leaves sharding decisions to the compiler. +You can provide hints to the compiler using +`jax.lax.with_sharding_constraint` but for the most part you're supposed to be +focussed on the math while the compiler worries about sharding. + +But what if you have a strong opinion about how you want your program sharded? +With enough calls to `with_sharding_constraint` you can probably guide the +compiler's hand to make it do what you want. But "compiler tickling" is +famously not a fun programming model. Where should you put the sharding +constraints? You could put them on every single intermediate but that's a lot +of work and it's also easy to make mistakes that way because there's no way to +check that the shardings make sense together. More commonly, people add just +enough sharding annotations to constrain the compiler. But this is a slow +iterative process. It's hard to know ahead of time what XLA's gSPMD pass will +do (it's a whole-program optimization) so all you can do is add annotations, +inspect XLA's sharding choices to see what happened, and repeat. + +To fix this we've come up with a different style of sharding programming we +call "explicit sharding" or "sharding in types". The idea is that sharding +propagation happens at the JAX level at trace time. Each JAX operation has a +sharding rule that takes the shardings of the op's arguments and produces a +sharding for the op's result. For most operations these rules are simple and +obvious because there's only one reasonable choice. But for some operations it's +unclear how to shard the result. In that case we ask the programmer +to provide an `out_sharding` argument explicitly and we throw a (trace-time) +error otherwise. Since the shardings are propagated at trace time they can +also be _queried_ at trace time too. In the rest of this doc we'll describe +how to use explicit sharding mode. Note that this is a new feature so we +expect there to be bugs and unimplemented cases. Please let us know when you +find something that doesn't work! + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: hVi6mApuVw3r +outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf +--- +import jax +import numpy as np +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P, AxisType, set_mesh +from jax.experimental.shard import reshard, auto_axes +from jax._src.mesh import get_abstract_mesh + +jax.config.update('jax_num_cpu_devices', 8) +``` + ++++ {"id": "oU5O6yOLWqbP"} + +## Setting up an explicit mesh + +The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that +the JAX-level _type_ of a value includes a description of how the value is sharded. +We can query the JAX-level type of any JAX value (or Numpy array, or Python +scalar) using `jax.typeof`: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: mzDIDvj7Vw0k +outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434 +--- +some_array = np.arange(8) +print(f"JAX-level type of some_array: {jax.typeof(some_array)}") +``` + ++++ {"id": "TZzp_1sXW061"} + +Importantly, we can query the type even while tracing under a `jit` (the JAX-level type +is almost _defined_ as "the information about a value we have access to while +under a jit). + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: IyPx_-IBVwxr +outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499 +--- +@jax.jit +def foo(x): + print(f"JAX-level type of x during tracing: {jax.typeof(x)}") + return x + x + +foo(some_array) +``` + ++++ {"id": "c3gNPzfZW45K"} + +These types show the shape and dtype of array but they don't appear to +show sharding. (Actually, they _did_ show sharding, but the shardings were +trivial. See "Concrete array shardings", below.) To start seeing some +interesting shardings we need to set up an explicit-sharding mesh. We use +`set_mesh` to set it as the current mesh for the remainder of this notebook. +(If you only want to set the mesh for some particular scope and return to the previous +mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.) + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: NO2ulM_QW7a8 +outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb +--- +mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Explicit, AxisType.Explicit)) +set_mesh(mesh) + +print(f"Current mesh is: {get_abstract_mesh()}") +``` + ++++ {"id": "V7bVz6tzW_Eb"} + +Now we can create some sharded arrays using `reshard`: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: 1-TzmA0AXCAf +outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71 +--- +replicated_array = np.arange(8).reshape(4, 2) +sharded_array = reshard(replicated_array, P("X", None)) + +print(f"replicated_array type: {jax.typeof(replicated_array)}") +print(f"sharded_array type: {jax.typeof(sharded_array)}") +``` + ++++ {"id": "B0jBBXtgXBxr"} + +We should read the type `f32[4@X, 2]` as "a 4-by-2 array of 32-bit floats whose first dimension +is sharded along mesh axis 'X'. The array is replicated along all other mesh +axes" + ++++ {"id": "N8yMauHAXKtX"} + +These shardings associated with JAX-level types propagate through operations. For example: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: Gy7ABds3XND3 +outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b +--- +arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None)) +arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y")) + +result = arg0 + arg1 + +print(f"arg0 sharding: {jax.typeof(arg0)}") +print(f"arg1 sharding: {jax.typeof(arg1)}") +print(f"result sharding: {jax.typeof(result)}") +``` + ++++ {"id": "lwsygUmVXPCk"} + +We can do the same type querying under a jit: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: grCcotr-XQjY +outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a +--- +@jax.jit +def add_arrays(x, y): + ans = x + y + print(f"x sharding: {jax.typeof(x)}") + print(f"y sharding: {jax.typeof(y)}") + print(f"ans sharding: {jax.typeof(ans)}") + return ans + +add_arrays(arg0, arg1) +``` + ++++ {"id": "lVd6a5ufXZoH"} + +That's the gist of it. Shardings propagate deterministically at trace time and +we can query them at trace time. + ++++ {"id": "ETtwK3LCXSkd"} + +## Sharding rules and operations with ambiguous sharding + +Each op has a sharding rule which specifies its output sharding given its +input shardings. A sharding rule may also throw a (trace-time) error. Each op +is free to implement whatever sharding rule it likes, but the usual pattern is +the following: For each output axis we identify zero of more corresponding +input axes. The output axis is then +sharded according to the “consensus” sharding of the corresponding input axes. i.e., it's +`None` if the input shardings are all `None`, and it's the common non-None input sharding +if there’s exactly one of them, or an error (requiring an explicit out_sharding=... kwarg) otherwise. + ++++ {"id": "an8-Fq1uXehp"} + +This procedure is done on an axis-by-axis basis. When it’s done, we might end +up with an array sharding that mentions a mesh axis more than once, which is +illegal. In that case we raise a (trace-time) sharding error and ask for an +explicit out_sharding. + +Here are some example sharding rules: + * nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole + cloth so they don’t have input shardings to propagate. Their output is + unsharded by default unless overridden by the out_sharding kwarg. + * unary elementwise ops like `sin`, `exp`: The output is sharded the same as the + input. + * binary ops (`+`, `-`, `*` etc.): Axis shardings of “zipped” dimensions + must match (or be `None`). “Outer product” dimensions (dimensions that + appear in only one argument) are sharded as they are in the input. If the + result ends up mentioning a mesh axis more than once it's an error. + * `reshape.` Reshape is a particularly tricky op. An output axis can map to more + than one input axis (when reshape is used to merge axes) or just a part + of an input axis (when reshape is used to split axes). Our usual rules + don’t apply. Instead we treat reshape as follows. We strip away singleton + axes (these can’t be sharded anyway. Then + we decide whether the reshape is a “split” (splitting a single axis into + two or more adjacent axes), a “merge” (merging two or more adjacent axes + into a single one) or something else. If we have a split or merge case in + which the split/merged axes are sharded as None then we shard the + resulting split/merged axes as None and the other axes according to their + corresponding input axis shardings. In all other cases we throw an error + and require the user to provide an `out_shardings` argument. + ++++ {"id": "jZMp6w48Xmd7"} + +## JAX transformations and higher-order functions + +The staged-out representation of JAX programs is explicitly typed. (We call +the types “avals” but that’s not important.) In explicit-sharding mode, the +sharding is part of that type. This means that shardings need to match +wherever types need to match. For example, the two sides of a `lax.cond` need to +have results with matching shardings. And the carry of `lax.scan` needs to have the +same sharding at the input and the output of the scan body. And when you +contruct a jaxpr without concrete arguments using `make_jaxpr` you need to +provide shardings too. Certain JAX transformations perform type-level +operations. Automatic differentation constructs a tangent type for each primal +type in the original computation (e.g. `TangentOf(float) == float`, +`TangentOf(int) == float0`). With sharding in the types, this means that tangent +values are sharded in the same way as their primal values. Vmap and scan also +do type-level operations, they lift an array shape to a rank-augmented version +of that shape. That extra array axis needs a sharding. We can infer it from the +arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan +needs an explicit sharding argument just as it needs an explicit length +argument. + ++++ {"id": "ERJx4p0tXoS3"} + +## Working around unimplemented sharding rules using `auto_sharding` + +The implementation of explicit sharding is still a work-in-progress and there +are plenty of ops that are missing sharding rules. For example, `scatter` and +`gather` (i.e. indexing ops). + +Normally we wouldn't suggest using a feature with so many unimplemented cases, +but in this instance there's a reasonable fallback you can use: `auto_axes`. +The idea is that you can temporarily drop into a context where the mesh axes +are "auto" rather than "explicit". You explicitly specify how you intend the +final result of the `auto_axes` to be sharded as it gets returned to the calling context. + +This works as a fallback for ops with unimplemented sharding rules. It also +works when you want to override the sharding-in-types type system. For +example, suppose we want to add a `f32[4@X, 4]` to a `f32[4, 4@X]`. Our +sharding rule for addition would throw an error: the result would need to be +`f32[4@X, 4@X]`, which tries uses a mesh axis twice, which is illegal. But say you +want to perform the operation anyway, and you want the result to be sharded along +the first axis only, like `f32[4@X, 4]`. You can do this as follows: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: fpFEaMBcXsJG +outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660 +--- +some_x = reshard(np.arange(16).reshape(4, 4), P("X", None)) +some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X")) + +try: + some_x + some_y +except Exception as e: + print("ERROR!") + print(e) + +print("=== try again with auto_axes ===") + +@auto_axes +def add_with_out_sharding_kwarg(x, y): + print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}") + return x + y + +result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P("X", None)) +print(f"Result type: {jax.typeof(result)}") +``` + ++++ {"id": "8-_zDr-AXvb6"} + +## Using a mixture of sharding modes + +JAX now has three styles of parallelism: + + * *Automatic sharding* is where you treat all the devices as a single logical + machine and write a "global view" array program for that machine. The + compiler decides how to partition the data and computation across the + available devices. You can give hints to the compiler using + `with_sharding_constraint`. + * *Explicit Sharding* (\*new\*) is similar to automatic sharding in that + you're writing a global-view program. The difference is that the sharding + of each array is part of the array's JAX-level type making it an explicit + part of the programming model. These shardings are propagated at the JAX + level and queryable at trace time. It's still the compiler's responsibility + to turn the whole-array program into per-device programs (turning `jnp.sum` + into `psum` for example) but the compiler is heavily constrained by the + user-supplied shardings. + * *Manual Sharding* (`shard_map`) is where you write a program from the + perspective of a single device. Communication between devices happens via + explicit collective operations like psum. + +A summary table: + +| Mode | Explicit sharding? | Explicit Collectives? | +|---|---|---| +| Auto | No | No | +| Explicit (new) | Yes | No | +| Manual | Yes | Yes | + +The current mesh tells us which sharding mode we're in. We can query it with +`get_abstract_mesh`: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: geptWrdYX0OM +outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f +--- +print(f"Current mesh is: {get_abstract_mesh()}") +``` + ++++ {"id": "AQQjzUeGX4P6"} + +Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit +mode. Notice that the sharding mode is associated with a mesh _axis_, not the +mesh as a whole. We can actually mix sharding modes by having a different +sharding mode for each mesh axis. Shardings (on JAX-level types) can only +mention _explicit_ mesh axes and collective operations like `psum` can only +mention _manual_ mesh axes. + ++++ {"id": "AQQjzUeGX4P6"} + +## Concrete array shardings can mention `Auto` mesh axis + +You can query the sharding of a concrete array `x` with `x.sharding`. You +might expect the result to be the same as the sharding associated with the +value's type, `jax.typeof(x).sharding`. It might not be! The concrete array sharding, `x.sharding`, describes the sharding along +both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler +eventually chose. Whereas the type-specificed sharding, +`jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh +axes. The `Auto` axes are deliberately hidden from the type because they're +the purview of the compiler. We can think of the concrete array sharding being consistent with, but more specific than, +the type-specified sharding. For example: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: ivLl6bxmX7EZ +outputId: 6d7b7fce-68b6-47f1-b214-d62bda8d7b6e +--- +def compare_shardings(x): + print(f"=== with mesh: {get_abstract_mesh()} ===") + print(f"Concrete value sharding: {x.sharding.spec}") + print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}") + +my_array = jnp.sin(reshard(np.arange(8), P("X"))) +compare_shardings(my_array) + +@auto_axes +def check_in_auto_context(x): + compare_shardings(x) + return x + +check_in_auto_context(my_array, out_shardings=P("X")) +``` + ++++ {"id": "MRFccsi5X8so"} + +Notice that at the top level, where we're currently in a fully `Explicit` mesh +context, the concrete array sharding and type-specified sharding agree. But +under the `auto_axes` decorator we're in a fully `Auto` mesh context and the +two shardings disagree: the type-specified sharding is `P(None)` whereas the +concrete array sharding is `P("X")` (though it could be anything! It's up to +the compiler).