Explicit sharding docs

This commit is contained in:
Dougal 2025-03-13 11:54:52 -04:00
parent 8fbe3b1333
commit e8f43d1cef
4 changed files with 1132 additions and 0 deletions

View File

@ -12,6 +12,7 @@ operations.
:maxdepth: 1
notebooks/Distributed_arrays_and_automatic_parallelization
notebooks/explicit-sharding
notebooks/shard_map
multi_process
distributed_data_loading

View File

@ -222,6 +222,7 @@ nb_execution_excludepatterns = [
'jep/9407-type-promotion.*',
# TODO(jakevdp): enable execution on the following if possible:
'notebooks/Distributed_arrays_and_automatic_parallelization.*',
'notebooks/explicit-sharding.*',
'notebooks/autodiff_remat.*',
# Fails on readthedocs with Kernel Died
'notebooks/convolutions.ipynb',

View File

@ -0,0 +1,713 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ZVJCNxUcVkkm"
},
"source": [
"# Explicit sharding (a.k.a. \"sharding in types\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ATLBMlw3VcCJ"
},
"source": [
"JAX's traditional automatic sharding leaves sharding decisions to the compiler.\n",
"You can provide hints to the compiler using\n",
"`jax.lax.with_sharding_constraint` but for the most part you're supposed to be\n",
"focussed on the math while the compiler worries about sharding.\n",
"\n",
"But what if you have a strong opinion about how you want your program sharded?\n",
"With enough calls to `with_sharding_constraint` you can probably guide the\n",
"compiler's hand to make it do what you want. But \"compiler tickling\" is\n",
"famously not a fun programming model. Where should you put the sharding\n",
"constraints? You could put them on every single intermediate but that's a lot\n",
"of work and it's also easy to make mistakes that way because there's no way to\n",
"check that the shardings make sense together. More commonly, people add just\n",
"enough sharding annotations to constrain the compiler. But this is a slow\n",
"iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n",
"do (it's a whole-program optimization) so all you can do is add annotations,\n",
"inspect XLA's sharding choices to see what happened, and repeat.\n",
"\n",
"To fix this we've come up with a different style of sharding programming we\n",
"call \"explicit sharding\" or \"sharding in types\". The idea is that sharding\n",
"propagation happens at the JAX level at trace time. Each JAX operation has a\n",
"sharding rule that takes the shardings of the op's arguments and produces a\n",
"sharding for the op's result. For most operations these rules are simple and\n",
"obvious because there's only one reasonable choice. But for some operations it's\n",
"unclear how to shard the result. In that case we ask the programmer\n",
"to provide an `out_sharding` argument explicitly and we throw a (trace-time)\n",
"error otherwise. Since the shardings are propagated at trace time they can\n",
"also be _queried_ at trace time too. In the rest of this doc we'll describe\n",
"how to use explicit sharding mode. Note that this is a new feature so we\n",
"expect there to be bugs and unimplemented cases. Please let us know when you\n",
"find something that doesn't work!"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hVi6mApuVw3r",
"outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
},
"outputs": [],
"source": [
"import jax\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from jax.sharding import PartitionSpec as P, AxisType, set_mesh\n",
"from jax.experimental.shard import reshard, auto_axes\n",
"from jax._src.mesh import get_abstract_mesh\n",
"\n",
"jax.config.update('jax_num_cpu_devices', 8)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oU5O6yOLWqbP"
},
"source": [
"## Setting up an explicit mesh\n",
"\n",
"The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n",
"the JAX-level _type_ of a value includes a description of how the value is sharded.\n",
"We can query the JAX-level type of any JAX value (or Numpy array, or Python\n",
"scalar) using `jax.typeof`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mzDIDvj7Vw0k",
"outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX-level type of some_array: ShapedArray(int32[8])\n"
]
}
],
"source": [
"some_array = np.arange(8)\n",
"print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TZzp_1sXW061"
},
"source": [
"Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n",
"is almost _defined_ as \"the information about a value we have access to while\n",
"under a jit)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IyPx_-IBVwxr",
"outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX-level type of x during tracing: ShapedArray(int32[8])\n"
]
},
{
"data": {
"text/plain": [
"Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@jax.jit\n",
"def foo(x):\n",
" print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n",
" return x + x\n",
"\n",
"foo(some_array)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c3gNPzfZW45K"
},
"source": [
"These types show the shape and dtype of array but they don't appear to\n",
"show sharding. (Actually, they _did_ show sharding, but the shardings were\n",
"trivial. See \"Concrete array shardings\", below.) To start seeing some\n",
"interesting shardings we need to set up an explicit-sharding mesh. We use\n",
"`set_mesh` to set it as the current mesh for the remainder of this notebook.\n",
"(If you only want to set the mesh for some particular scope and return to the previous\n",
"mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NO2ulM_QW7a8",
"outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n"
]
}
],
"source": [
"mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n",
" axis_types=(AxisType.Explicit, AxisType.Explicit))\n",
"set_mesh(mesh)\n",
"\n",
"print(f\"Current mesh is: {get_abstract_mesh()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V7bVz6tzW_Eb"
},
"source": [
"Now we can create some sharded arrays using `reshard`:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1-TzmA0AXCAf",
"outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"replicated_array type: ShapedArray(int32[4,2])\n",
"sharded_array type: ShapedArray(int32[4@X,2])\n"
]
}
],
"source": [
"replicated_array = np.arange(8).reshape(4, 2)\n",
"sharded_array = reshard(replicated_array, P(\"X\", None))\n",
"\n",
"print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n",
"print(f\"sharded_array type: {jax.typeof(sharded_array)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B0jBBXtgXBxr"
},
"source": [
"We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n",
"is sharded along mesh axis 'X'. The array is replicated along all other mesh\n",
"axes\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N8yMauHAXKtX"
},
"source": [
"These shardings associated with JAX-level types propagate through operations. For example:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Gy7ABds3XND3",
"outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"arg0 sharding: ShapedArray(int32[4@X,1])\n",
"arg1 sharding: ShapedArray(int32[1,8@Y])\n",
"result sharding: ShapedArray(int32[4@X,8@Y])\n"
]
}
],
"source": [
"arg0 = reshard(np.arange(4).reshape(4, 1), P(\"X\", None))\n",
"arg1 = reshard(np.arange(8).reshape(1, 8), P(None, \"Y\"))\n",
"\n",
"result = arg0 + arg1\n",
"\n",
"print(f\"arg0 sharding: {jax.typeof(arg0)}\")\n",
"print(f\"arg1 sharding: {jax.typeof(arg1)}\")\n",
"print(f\"result sharding: {jax.typeof(result)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lwsygUmVXPCk"
},
"source": [
"We can do the same type querying under a jit:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "grCcotr-XQjY",
"outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x sharding: ShapedArray(int32[4@X,1])\n",
"y sharding: ShapedArray(int32[1,8@Y])\n",
"ans sharding: ShapedArray(int32[4@X,8@Y])\n"
]
},
{
"data": {
"text/plain": [
"Array([[ 0, 1, 2, 3, 4, 5, 6, 7],\n",
" [ 1, 2, 3, 4, 5, 6, 7, 8],\n",
" [ 2, 3, 4, 5, 6, 7, 8, 9],\n",
" [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@jax.jit\n",
"def add_arrays(x, y):\n",
" ans = x + y\n",
" print(f\"x sharding: {jax.typeof(x)}\")\n",
" print(f\"y sharding: {jax.typeof(y)}\")\n",
" print(f\"ans sharding: {jax.typeof(ans)}\")\n",
" return ans\n",
"\n",
"add_arrays(arg0, arg1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lVd6a5ufXZoH"
},
"source": [
"That's the gist of it. Shardings propagate deterministically at trace time and\n",
"we can query them at trace time."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ETtwK3LCXSkd"
},
"source": [
"## Sharding rules and operations with ambiguous sharding\n",
"\n",
"Each op has a sharding rule which specifies its output sharding given its\n",
"input shardings. A sharding rule may also throw a (trace-time) error. Each op\n",
"is free to implement whatever sharding rule it likes, but the usual pattern is\n",
"the following: For each output axis we identify zero of more corresponding\n",
"input axes. The output axis is then\n",
"sharded according to the “consensus” sharding of the corresponding input axes. i.e., it's\n",
"`None` if the input shardings are all `None`, and it's the common non-None input sharding\n",
"if theres exactly one of them, or an error (requiring an explicit out_sharding=... kwarg) otherwise."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "an8-Fq1uXehp"
},
"source": [
"This procedure is done on an axis-by-axis basis. When its done, we might end\n",
"up with an array sharding that mentions a mesh axis more than once, which is\n",
"illegal. In that case we raise a (trace-time) sharding error and ask for an\n",
"explicit out_sharding.\n",
"\n",
"Here are some example sharding rules:\n",
" * nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole\n",
" cloth so they dont have input shardings to propagate. Their output is\n",
" unsharded by default unless overridden by the out_sharding kwarg.\n",
" * unary elementwise ops like `sin`, `exp`: The output is sharded the same as the\n",
" input.\n",
" * binary ops (`+`, `-`, `*` etc.): Axis shardings of “zipped” dimensions\n",
" must match (or be `None`). “Outer product” dimensions (dimensions that\n",
" appear in only one argument) are sharded as they are in the input. If the\n",
" result ends up mentioning a mesh axis more than once it's an error.\n",
" * `reshape.` Reshape is a particularly tricky op. An output axis can map to more\n",
" than one input axis (when reshape is used to merge axes) or just a part\n",
" of an input axis (when reshape is used to split axes). Our usual rules\n",
" dont apply. Instead we treat reshape as follows. We strip away singleton\n",
" axes (these cant be sharded anyway. Then\n",
" we decide whether the reshape is a “split” (splitting a single axis into\n",
" two or more adjacent axes), a “merge” (merging two or more adjacent axes\n",
" into a single one) or something else. If we have a split or merge case in\n",
" which the split/merged axes are sharded as None then we shard the\n",
" resulting split/merged axes as None and the other axes according to their\n",
" corresponding input axis shardings. In all other cases we throw an error\n",
" and require the user to provide an `out_shardings` argument."
]
},
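{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the nullary-op rule concrete, here is a small sketch. It assumes the\n",
"`out_sharding` kwarg described in the list above; the variable names are just\n",
"for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch of the nullary-op rule: creation ops are unsharded by default,\n",
"# and a sharding can be requested via the `out_sharding` kwarg.\n",
"unsharded_zeros = jnp.zeros((4, 8))\n",
"sharded_zeros = jnp.zeros((4, 8), out_sharding=P(\"X\", \"Y\"))\n",
"print(f\"unsharded type: {jax.typeof(unsharded_zeros)}\")\n",
"print(f\"sharded type:   {jax.typeof(sharded_zeros)}\")"
]
},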
{
"cell_type": "markdown",
"metadata": {
"id": "jZMp6w48Xmd7"
},
"source": [
"## JAX transformations and higher-order functions\n",
"\n",
"The staged-out representation of JAX programs is explicitly typed. (We call\n",
"the types “avals” but thats not important.) In explicit-sharding mode, the\n",
"sharding is part of that type. This means that shardings need to match\n",
"wherever types need to match. For example, the two sides of a `lax.cond` need to\n",
"have results with matching shardings. And the carry of `lax.scan` needs to have the\n",
"same sharding at the input and the output of the scan body. And when you\n",
"contruct a jaxpr without concrete arguments using `make_jaxpr` you need to\n",
"provide shardings too. Certain JAX transformations perform type-level\n",
"operations. Automatic differentation constructs a tangent type for each primal\n",
"type in the original computation (e.g. `TangentOf(float) == float`,\n",
"`TangentOf(int) == float0`). With sharding in the types, this means that tangent\n",
"values are sharded in the same way as their primal values. Vmap and scan also\n",
"do type-level operations, they lift an array shape to a rank-augmented version\n",
"of that shape. That extra array axis needs a sharding. We can infer it from the\n",
"arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan\n",
"needs an explicit sharding argument just as it needs an explicit length\n",
"argument."
]
},
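{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, here is a small sketch of the autodiff point above: the gradient\n",
"of a scalar loss of a sharded input should carry the same sharding in its type\n",
"as the input (elementwise multiply and `sum` are among the ops with simple,\n",
"obvious sharding rules). The function and variable names are just for\n",
"illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Tangents are sharded like their primals: the gradient of a scalar loss of\n",
"# a sharded input should have the same sharding in its type as the input.\n",
"def sum_of_squares(x):\n",
"  return jnp.sum(x * x)\n",
"\n",
"primal = reshard(jnp.arange(8.0), P(\"X\"))\n",
"primal_grad = jax.grad(sum_of_squares)(primal)\n",
"print(f\"primal type:      {jax.typeof(primal)}\")\n",
"print(f\"primal_grad type: {jax.typeof(primal_grad)}\")"
]
},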
{
"cell_type": "markdown",
"metadata": {
"id": "ERJx4p0tXoS3"
},
"source": [
"## Working around unimplemented sharding rules using `auto_sharding`\n",
"\n",
"The implementation of explicit sharding is still a work-in-progress and there\n",
"are plenty of ops that are missing sharding rules. For example, `scatter` and\n",
"`gather` (i.e. indexing ops).\n",
"\n",
"Normally we wouldn't suggest using a feature with so many unimplemented cases,\n",
"but in this instance there's a reasonable fallback you can use: `auto_axes`.\n",
"The idea is that you can temporarily drop into a context where the mesh axes\n",
"are \"auto\" rather than \"explicit\". You explicitly specify how you intend the\n",
"final result of the `auto_axes` to be sharded as it gets returned to the calling context.\n",
"\n",
"This works as a fallback for ops with unimplemented sharding rules. It also\n",
"works when you want to override the sharding-in-types type system. For\n",
"example, suppose we want to add a `f32[4@X, 4]` to a `f32[4, 4@X]`. Our\n",
"sharding rule for addition would throw an error: the result would need to be\n",
"`f32[4@X, 4@X]`, which tries uses a mesh axis twice, which is illegal. But say you\n",
"want to perform the operation anyway, and you want the result to be sharded along\n",
"the first axis only, like `f32[4@X, 4]`. You can do this as follows:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fpFEaMBcXsJG",
"outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ERROR!\n",
"add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]\n",
"=== try again with auto_axes ===\n",
"We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
"Result type: ShapedArray(int32[4@X,4])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result type: ShapedArray(int32[4@X,4])\n"
]
}
],
"source": [
"some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", None))\n",
"some_y = reshard(np.arange(16).reshape(4, 4), P(None, \"X\"))\n",
"\n",
"try:\n",
" some_x + some_y\n",
"except Exception as e:\n",
" print(\"ERROR!\")\n",
" print(e)\n",
"\n",
"print(\"=== try again with auto_axes ===\")\n",
"\n",
"@auto_axes\n",
"def add_with_out_sharding_kwarg(x, y):\n",
" print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n",
" return x + y\n",
"\n",
"result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P(\"X\", None))\n",
"print(f\"Result type: {jax.typeof(result)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8-_zDr-AXvb6"
},
"source": [
"## Using a mixture of sharding modes\n",
"\n",
"JAX now has three styles of parallelism:\n",
"\n",
" * *Automatic sharding* is where you treat all the devices as a single logical\n",
" machine and write a \"global view\" array program for that machine. The\n",
" compiler decides how to partition the data and computation across the\n",
" available devices. You can give hints to the compiler using\n",
" `with_sharding_constraint`.\n",
" * *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n",
" you're writing a global-view program. The difference is that the sharding\n",
" of each array is part of the array's JAX-level type making it an explicit\n",
" part of the programming model. These shardings are propagated at the JAX\n",
" level and queryable at trace time. It's still the compiler's responsibility\n",
" to turn the whole-array program into per-device programs (turning `jnp.sum`\n",
" into `psum` for example) but the compiler is heavily constrained by the\n",
" user-supplied shardings.\n",
" * *Manual Sharding* (`shard_map`) is where you write a program from the\n",
" perspective of a single device. Communication between devices happens via\n",
" explicit collective operations like psum.\n",
"\n",
"A summary table:\n",
"\n",
"| Mode | Explicit sharding? | Explicit Collectives? |\n",
"|---|---|---|\n",
"| Auto | No | No |\n",
"| Explicit (new) | Yes | No |\n",
"| Manual | Yes | Yes |\n",
"\n",
"The current mesh tells us which sharding mode we're in. We can query it with\n",
"`get_abstract_mesh`:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "geptWrdYX0OM",
"outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n"
]
}
],
"source": [
"print(f\"Current mesh is: {get_abstract_mesh()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AQQjzUeGX4P6"
},
"source": [
"Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit\n",
"mode. Notice that the sharding mode is associated with a mesh _axis_, not the\n",
"mesh as a whole. We can actually mix sharding modes by having a different\n",
"sharding mode for each mesh axis. Shardings (on JAX-level types) can only\n",
"mention _explicit_ mesh axes and collective operations like `psum` can only\n",
"mention _manual_ mesh axes."
]
},
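{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, here is a sketch of a mixed mesh, using the scoped context\n",
"manager `jax.sharding.use_mesh` mentioned earlier. It assumes `AxisType.Auto`\n",
"(the mode that `auto_axes` switched us into above); only the `Explicit` axis\n",
"\"X\" may appear in type-level shardings, while \"Y\" is left to the compiler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch of mixing sharding modes per mesh axis: \"X\" is Explicit, \"Y\" is Auto.\n",
"mixed_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n",
"                           axis_types=(AxisType.Explicit, AxisType.Auto))\n",
"with jax.sharding.use_mesh(mixed_mesh):\n",
"  print(f\"Current mesh is: {get_abstract_mesh()}\")\n",
"  arr = reshard(np.arange(8), P(\"X\"))  # OK: \"X\" is an Explicit axis\n",
"  print(f\"arr type: {jax.typeof(arr)}\")"
]
},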
{
"cell_type": "markdown",
"metadata": {
"id": "AQQjzUeGX4P6"
},
"source": [
"## Concrete array shardings can mention `Auto` mesh axis\n",
"\n",
"You can query the sharding of a concrete array `x` with `x.sharding`. You\n",
"might expect the result to be the same as the sharding associated with the\n",
"value's type, `jax.typeof(x).sharding`. It might not be! The concrete array sharding, `x.sharding`, describes the sharding along\n",
"both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler\n",
"eventually chose. Whereas the type-specificed sharding,\n",
"`jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh\n",
"axes. The `Auto` axes are deliberately hidden from the type because they're\n",
"the purview of the compiler. We can think of the concrete array sharding being consistent with, but more specific than,\n",
"the type-specified sharding. For example:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ivLl6bxmX7EZ",
"outputId": "6d7b7fce-68b6-47f1-b214-d62bda8d7b6e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit)) ===\n",
"Concrete value sharding: PartitionSpec('X',)\n",
"Type-specified sharding: PartitionSpec('X',)\n",
"=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto)) ===\n",
"Concrete value sharding: PartitionSpec('X',)\n",
"Type-specified sharding: PartitionSpec(None,)\n"
]
},
{
"data": {
"text/plain": [
"Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n",
" -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n",
" -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def compare_shardings(x):\n",
" print(f\"=== with mesh: {get_abstract_mesh()} ===\")\n",
" print(f\"Concrete value sharding: {x.sharding.spec}\")\n",
" print(f\"Type-specified sharding: {jax.typeof(x).sharding.spec}\")\n",
"\n",
"my_array = jnp.sin(reshard(np.arange(8), P(\"X\")))\n",
"compare_shardings(my_array)\n",
"\n",
"@auto_axes\n",
"def check_in_auto_context(x):\n",
" compare_shardings(x)\n",
" return x\n",
"\n",
"check_in_auto_context(my_array, out_shardings=P(\"X\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MRFccsi5X8so"
},
"source": [
"Notice that at the top level, where we're currently in a fully `Explicit` mesh\n",
"context, the concrete array sharding and type-specified sharding agree. But\n",
"under the `auto_axes` decorator we're in a fully `Auto` mesh context and the\n",
"two shardings disagree: the type-specified sharding is `P(None)` whereas the\n",
"concrete array sharding is `P(\"X\")` (though it could be anything! It's up to\n",
"the compiler)."
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -0,0 +1,417 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---
+++ {"id": "ZVJCNxUcVkkm"}
# Explicit sharding (a.k.a. "sharding in types")
+++ {"id": "ATLBMlw3VcCJ"}
JAX's traditional automatic sharding leaves sharding decisions to the compiler.
You can provide hints to the compiler using
`jax.lax.with_sharding_constraint` but for the most part you're supposed to be
focused on the math while the compiler worries about sharding.
But what if you have a strong opinion about how you want your program sharded?
With enough calls to `with_sharding_constraint` you can probably guide the
compiler's hand to make it do what you want. But "compiler tickling" is
famously not a fun programming model. Where should you put the sharding
constraints? You could put them on every single intermediate but that's a lot
of work and it's also easy to make mistakes that way because there's no way to
check that the shardings make sense together. More commonly, people add just
enough sharding annotations to constrain the compiler. But this is a slow
iterative process. It's hard to know ahead of time what XLA's gSPMD pass will
do (it's a whole-program optimization) so all you can do is add annotations,
inspect XLA's sharding choices to see what happened, and repeat.
To fix this we've come up with a different style of sharding programming we
call "explicit sharding" or "sharding in types". The idea is that sharding
propagation happens at the JAX level at trace time. Each JAX operation has a
sharding rule that takes the shardings of the op's arguments and produces a
sharding for the op's result. For most operations these rules are simple and
obvious because there's only one reasonable choice. But for some operations it's
unclear how to shard the result. In that case we ask the programmer
to provide an `out_sharding` argument explicitly and we throw a (trace-time)
error otherwise. Since the shardings are propagated at trace time they can
also be _queried_ at trace time. In the rest of this doc we'll describe
how to use explicit sharding mode. Note that this is a new feature so we
expect there to be bugs and unimplemented cases. Please let us know when you
find something that doesn't work!
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: hVi6mApuVw3r
outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
---
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh
from jax.experimental.shard import reshard, auto_axes
from jax._src.mesh import get_abstract_mesh
jax.config.update('jax_num_cpu_devices', 8)
```
+++ {"id": "oU5O6yOLWqbP"}
## Setting up an explicit mesh
The main idea behind explicit sharding (a.k.a. sharding-in-types) is that
the JAX-level _type_ of a value includes a description of how the value is sharded.
We can query the JAX-level type of any JAX value (or Numpy array, or Python
scalar) using `jax.typeof`:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: mzDIDvj7Vw0k
outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434
---
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
```
+++ {"id": "TZzp_1sXW061"}
Importantly, we can query the type even while tracing under a `jit` (the JAX-level type
is almost _defined_ as "the information about a value we have access to while
under a jit").
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: IyPx_-IBVwxr
outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499
---
@jax.jit
def foo(x):
print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
return x + x
foo(some_array)
```
+++ {"id": "c3gNPzfZW45K"}
These types show the shape and dtype of an array but they don't appear to
show sharding. (Actually, they _did_ show sharding, but the shardings were
trivial. See "Concrete array shardings", below.) To start seeing some
interesting shardings we need to set up an explicit-sharding mesh. We use
`set_mesh` to set it as the current mesh for the remainder of this notebook.
(If you only want to set the mesh for some particular scope and return to the previous
mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.)
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: NO2ulM_QW7a8
outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb
---
mesh = jax.make_mesh((2, 4), ("X", "Y"),
axis_types=(AxisType.Explicit, AxisType.Explicit))
set_mesh(mesh)
print(f"Current mesh is: {get_abstract_mesh()}")
```
+++ {"id": "V7bVz6tzW_Eb"}
Now we can create some sharded arrays using `reshard`:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: 1-TzmA0AXCAf
outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71
---
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))
print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
```
+++ {"id": "B0jBBXtgXBxr"}
We should read the type `int32[4@X, 2]` as "a 4-by-2 array of 32-bit integers whose first dimension
is sharded along mesh axis 'X'. The array is replicated along all other mesh
axes."
+++ {"id": "N8yMauHAXKtX"}
These shardings associated with JAX-level types propagate through operations. For example:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: Gy7ABds3XND3
outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b
---
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
result = arg0 + arg1
print(f"arg0 sharding: {jax.typeof(arg0)}")
print(f"arg1 sharding: {jax.typeof(arg1)}")
print(f"result sharding: {jax.typeof(result)}")
```
+++ {"id": "lwsygUmVXPCk"}
We can do the same type querying under a jit:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: grCcotr-XQjY
outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a
---
@jax.jit
def add_arrays(x, y):
ans = x + y
print(f"x sharding: {jax.typeof(x)}")
print(f"y sharding: {jax.typeof(y)}")
print(f"ans sharding: {jax.typeof(ans)}")
return ans
add_arrays(arg0, arg1)
```
+++ {"id": "lVd6a5ufXZoH"}
That's the gist of it. Shardings propagate deterministically at trace time and
we can query them at trace time.
+++ {"id": "ETtwK3LCXSkd"}
## Sharding rules and operations with ambiguous sharding
Each op has a sharding rule which specifies its output sharding given its
input shardings. A sharding rule may also throw a (trace-time) error. Each op
is free to implement whatever sharding rule it likes, but the usual pattern is
the following: for each output axis we identify zero or more corresponding
input axes. The output axis is then sharded according to the “consensus”
sharding of the corresponding input axes: it's `None` if the input shardings
are all `None`, it's the common non-`None` input sharding if there's exactly
one of them, and it's an error (requiring an explicit `out_sharding=...` kwarg) otherwise.
+++ {"id": "an8-Fq1uXehp"}
This procedure is done on an axis-by-axis basis. When it's done, we might end
up with an array sharding that mentions a mesh axis more than once, which is
illegal. In that case we raise a (trace-time) sharding error and ask for an
explicit `out_sharding`.
Here are some example sharding rules:
* nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole
cloth so they don't have input shardings to propagate. Their output is
unsharded by default unless overridden by the `out_sharding` kwarg (see the sketch
after this list).
* unary elementwise ops like `sin`, `exp`: The output is sharded the same as the
input.
* binary ops (`+`, `-`, `*` etc.): Axis shardings of “zipped” dimensions
must match (or be `None`). “Outer product” dimensions (dimensions that
appear in only one argument) are sharded as they are in the input. If the
result ends up mentioning a mesh axis more than once it's an error.
* `reshape`: Reshape is a particularly tricky op. An output axis can map to more
than one input axis (when reshape is used to merge axes) or to just a part
of an input axis (when reshape is used to split axes). Our usual rules
don't apply. Instead we treat reshape as follows. We strip away singleton
axes (these can't be sharded anyway). Then
we decide whether the reshape is a “split” (splitting a single axis into
two or more adjacent axes), a “merge” (merging two or more adjacent axes
into a single one) or something else. If we have a split or merge case in
which the split/merged axes are sharded as `None` then we shard the
resulting split/merged axes as `None` and the other axes according to their
corresponding input axis shardings. In all other cases we throw an error
and require the user to provide an `out_sharding` argument.
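
To make the nullary-op rule concrete, here is a small sketch. It assumes the
`out_sharding` kwarg described in the list above; the variable names are just
for illustration.

```{code-cell} ipython3
# A sketch of the nullary-op rule: creation ops are unsharded by default,
# and a sharding can be requested via the `out_sharding` kwarg.
unsharded_zeros = jnp.zeros((4, 8))
sharded_zeros = jnp.zeros((4, 8), out_sharding=P("X", "Y"))
print(f"unsharded type: {jax.typeof(unsharded_zeros)}")
print(f"sharded type:   {jax.typeof(sharded_zeros)}")
```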
+++ {"id": "jZMp6w48Xmd7"}
## JAX transformations and higher-order functions
The staged-out representation of JAX programs is explicitly typed. (We call
the types “avals” but that's not important.) In explicit-sharding mode, the
sharding is part of that type. This means that shardings need to match
wherever types need to match. For example, the two sides of a `lax.cond` need to
have results with matching shardings. And the carry of `lax.scan` needs to have the
same sharding at the input and the output of the scan body. And when you
construct a jaxpr without concrete arguments using `make_jaxpr` you need to
provide shardings too. Certain JAX transformations perform type-level
operations. Automatic differentiation constructs a tangent type for each primal
type in the original computation (e.g. `TangentOf(float) == float`,
`TangentOf(int) == float0`). With sharding in the types, this means that tangent
values are sharded in the same way as their primal values. Vmap and scan also
do type-level operations: they lift an array shape to a rank-augmented version
of that shape. That extra array axis needs a sharding. We can infer it from the
arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan
needs an explicit sharding argument just as it needs an explicit length
argument.
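
For example, here is a small sketch of the autodiff point above: the gradient
of a scalar loss of a sharded input should carry the same sharding in its type
as the input (elementwise multiply and `sum` are among the ops with simple,
obvious sharding rules). The function and variable names are just for
illustration.

```{code-cell} ipython3
# Tangents are sharded like their primals: the gradient of a scalar loss of
# a sharded input should have the same sharding in its type as the input.
def sum_of_squares(x):
  return jnp.sum(x * x)

primal = reshard(jnp.arange(8.0), P("X"))
primal_grad = jax.grad(sum_of_squares)(primal)
print(f"primal type:      {jax.typeof(primal)}")
print(f"primal_grad type: {jax.typeof(primal_grad)}")
```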
+++ {"id": "ERJx4p0tXoS3"}
## Working around unimplemented sharding rules using `auto_axes`
The implementation of explicit sharding is still a work-in-progress and there
are plenty of ops that are missing sharding rules. For example, `scatter` and
`gather` (i.e. indexing ops).
Normally we wouldn't suggest using a feature with so many unimplemented cases,
but in this instance there's a reasonable fallback you can use: `auto_axes`.
The idea is that you can temporarily drop into a context where the mesh axes
are "auto" rather than "explicit". You explicitly specify how you intend the
final result of the `auto_axes`-decorated function to be sharded as it gets returned to the calling context.
This works as a fallback for ops with unimplemented sharding rules. It also
works when you want to override the sharding-in-types type system. For
example, suppose we want to add an `f32[4@X, 4]` to an `f32[4, 4@X]`. Our
sharding rule for addition would throw an error: the result would need to be
`f32[4@X, 4@X]`, which uses a mesh axis twice, which is illegal. But say you
want to perform the operation anyway, and you want the result to be sharded along
the first axis only, like `f32[4@X, 4]`. You can do this as follows:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: fpFEaMBcXsJG
outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660
---
some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
try:
some_x + some_y
except Exception as e:
print("ERROR!")
print(e)
print("=== try again with auto_axes ===")
@auto_axes
def add_with_out_sharding_kwarg(x, y):
print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}")
return x + y
result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P("X", None))
print(f"Result type: {jax.typeof(result)}")
```
+++ {"id": "8-_zDr-AXvb6"}
## Using a mixture of sharding modes
JAX now has three styles of parallelism:
* *Automatic sharding* is where you treat all the devices as a single logical
machine and write a "global view" array program for that machine. The
compiler decides how to partition the data and computation across the
available devices. You can give hints to the compiler using
`with_sharding_constraint`.
* *Explicit Sharding* (\*new\*) is similar to automatic sharding in that
you're writing a global-view program. The difference is that the sharding
of each array is part of the array's JAX-level type making it an explicit
part of the programming model. These shardings are propagated at the JAX
level and queryable at trace time. It's still the compiler's responsibility
to turn the whole-array program into per-device programs (turning `jnp.sum`
into `psum` for example) but the compiler is heavily constrained by the
user-supplied shardings.
* *Manual Sharding* (`shard_map`) is where you write a program from the
perspective of a single device. Communication between devices happens via
explicit collective operations like `psum`.
A summary table:
| Mode | Explicit sharding? | Explicit Collectives? |
|---|---|---|
| Auto | No | No |
| Explicit (new) | Yes | No |
| Manual | Yes | Yes |
The current mesh tells us which sharding mode we're in. We can query it with
`get_abstract_mesh`:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: geptWrdYX0OM
outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f
---
print(f"Current mesh is: {get_abstract_mesh()}")
```
+++ {"id": "AQQjzUeGX4P6"}
Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit
mode. Notice that the sharding mode is associated with a mesh _axis_, not the
mesh as a whole. We can actually mix sharding modes by having a different
sharding mode for each mesh axis. Shardings (on JAX-level types) can only
mention _explicit_ mesh axes and collective operations like `psum` can only
mention _manual_ mesh axes.
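
For instance, here is a sketch of a mixed mesh, using the scoped context
manager `jax.sharding.use_mesh` mentioned earlier. It assumes `AxisType.Auto`
(the mode that `auto_axes` switched us into above); only the `Explicit` axis
"X" may appear in type-level shardings, while "Y" is left to the compiler.

```{code-cell} ipython3
# A sketch of mixing sharding modes per mesh axis: "X" is Explicit, "Y" is Auto.
mixed_mesh = jax.make_mesh((2, 4), ("X", "Y"),
                           axis_types=(AxisType.Explicit, AxisType.Auto))
with jax.sharding.use_mesh(mixed_mesh):
  print(f"Current mesh is: {get_abstract_mesh()}")
  arr = reshard(np.arange(8), P("X"))  # OK: "X" is an Explicit axis
  print(f"arr type: {jax.typeof(arr)}")
```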
+++ {"id": "AQQjzUeGX4P6"}
## Concrete array shardings can mention `Auto` mesh axes
You can query the sharding of a concrete array `x` with `x.sharding`. You
might expect the result to be the same as the sharding associated with the
value's type, `jax.typeof(x).sharding`. It might not be! The concrete array sharding, `x.sharding`, describes the sharding along
both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler
eventually chose. By contrast, the type-specified sharding,
`jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh
axes. The `Auto` axes are deliberately hidden from the type because they're
the purview of the compiler. We can think of the concrete array sharding as consistent with, but more specific than,
the type-specified sharding. For example:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: ivLl6bxmX7EZ
outputId: 6d7b7fce-68b6-47f1-b214-d62bda8d7b6e
---
def compare_shardings(x):
print(f"=== with mesh: {get_abstract_mesh()} ===")
print(f"Concrete value sharding: {x.sharding.spec}")
print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}")
my_array = jnp.sin(reshard(np.arange(8), P("X")))
compare_shardings(my_array)
@auto_axes
def check_in_auto_context(x):
compare_shardings(x)
return x
check_in_auto_context(my_array, out_shardings=P("X"))
```
+++ {"id": "MRFccsi5X8so"}
Notice that at the top level, where we're currently in a fully `Explicit` mesh
context, the concrete array sharding and type-specified sharding agree. But
under the `auto_axes` decorator we're in a fully `Auto` mesh context and the
two shardings disagree: the type-specified sharding is `P(None)` whereas the
concrete array sharding is `P("X")` (though it could be anything! It's up to
the compiler).