mirror of https://github.com/ROCm/jax.git

Explicit sharding docs

commit e8f43d1cef (parent 8fbe3b1333)
@@ -12,6 +12,7 @@ operations.
 :maxdepth: 1

 notebooks/Distributed_arrays_and_automatic_parallelization
+notebooks/explicit-sharding
 notebooks/shard_map
 multi_process
 distributed_data_loading
@@ -222,6 +222,7 @@ nb_execution_excludepatterns = [
     'jep/9407-type-promotion.*',
     # TODO(jakevdp): enable execution on the following if possible:
     'notebooks/Distributed_arrays_and_automatic_parallelization.*',
+    'notebooks/explicit-sharding.*',
     'notebooks/autodiff_remat.*',
     # Fails on readthedocs with Kernel Died
     'notebooks/convolutions.ipynb',
713  docs/notebooks/explicit-sharding.ipynb  Normal file
@@ -0,0 +1,713 @@
[Notebook JSON omitted: this .ipynb is the jupytext-paired twin of docs/notebooks/explicit-sharding.md below (formats: ipynb,md:myst), with the same cells plus their executed outputs — e.g. the trace-time failure "add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]".]
417  docs/notebooks/explicit-sharding.md  Normal file
@@ -0,0 +1,417 @@
---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
  jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

+++ {"id": "ZVJCNxUcVkkm"}

# Explicit sharding (a.k.a. "sharding in types")

+++ {"id": "ATLBMlw3VcCJ"}

JAX's traditional automatic sharding leaves sharding decisions to the compiler.
You can provide hints to the compiler using
`jax.lax.with_sharding_constraint` but for the most part you're supposed to be
focused on the math while the compiler worries about sharding.

But what if you have a strong opinion about how you want your program sharded?
With enough calls to `with_sharding_constraint` you can probably guide the
compiler's hand to make it do what you want. But "compiler tickling" is
famously not a fun programming model. Where should you put the sharding
constraints? You could put them on every single intermediate, but that's a lot
of work and it's also easy to make mistakes that way because there's no way to
check that the shardings make sense together. More commonly, people add just
enough sharding annotations to constrain the compiler. But this is a slow,
iterative process. It's hard to know ahead of time what XLA's GSPMD pass will
do (it's a whole-program optimization), so all you can do is add annotations,
inspect XLA's sharding choices to see what happened, and repeat.

To fix this we've come up with a different style of sharding programming we
call "explicit sharding" or "sharding in types". The idea is that sharding
propagation happens at the JAX level at trace time. Each JAX operation has a
sharding rule that takes the shardings of the op's arguments and produces a
sharding for the op's result. For most operations these rules are simple and
obvious because there's only one reasonable choice. But for some operations it's
unclear how to shard the result. In that case we ask the programmer
to provide an `out_sharding` argument explicitly, and we throw a (trace-time)
error otherwise. Since the shardings are propagated at trace time, they can
also be _queried_ at trace time. In the rest of this doc we'll describe
how to use explicit sharding mode. Note that this is a new feature, so we
expect there to be bugs and unimplemented cases. Please let us know when you
find something that doesn't work!

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: hVi6mApuVw3r
outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
---
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh
from jax.experimental.shard import reshard, auto_axes
from jax._src.mesh import get_abstract_mesh

jax.config.update('jax_num_cpu_devices', 8)
```

+++ {"id": "oU5O6yOLWqbP"}

## Setting up an explicit mesh

The main idea behind explicit sharding (a.k.a. sharding-in-types) is that
the JAX-level _type_ of a value includes a description of how the value is sharded.
We can query the JAX-level type of any JAX value (or NumPy array, or Python
scalar) using `jax.typeof`:

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: mzDIDvj7Vw0k
outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434
---
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
```

+++ {"id": "TZzp_1sXW061"}

Importantly, we can query the type even while tracing under a `jit` (the JAX-level type
is almost _defined_ as "the information about a value we have access to while
under a jit").

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: IyPx_-IBVwxr
outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499
---
@jax.jit
def foo(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x + x

foo(some_array)
```

+++ {"id": "c3gNPzfZW45K"}

These types show the shape and dtype of an array, but they don't appear to
show sharding. (Actually, they _did_ show sharding, but the shardings were
trivial. See "Concrete array shardings", below.) To start seeing some
interesting shardings we need to set up an explicit-sharding mesh. We use
`set_mesh` to set it as the current mesh for the remainder of this notebook.
(If you only want to set the mesh for some particular scope and return to the previous
mesh afterwards, you can use the context manager `jax.sharding.use_mesh` instead.)

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: NO2ulM_QW7a8
outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb
---
mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
set_mesh(mesh)

print(f"Current mesh is: {get_abstract_mesh()}")
```
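
+++

As a quick illustration of the scoped alternative mentioned above, here is a
minimal sketch, assuming `jax.sharding.use_mesh` takes the mesh as its only
argument (`scoped_mesh` is a hypothetical second mesh, not part of the rest of
this notebook):

```python
# Hypothetical second mesh, current only inside the `with` block.
scoped_mesh = jax.make_mesh((4, 2), ("X", "Y"),
                            axis_types=(AxisType.Explicit, AxisType.Explicit))

with jax.sharding.use_mesh(scoped_mesh):
  print(get_abstract_mesh())  # the scoped mesh is current here

print(get_abstract_mesh())    # back to the mesh installed by set_mesh above
```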

+++ {"id": "V7bVz6tzW_Eb"}

Now we can create some sharded arrays using `reshard`:

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: 1-TzmA0AXCAf
outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71
---
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
```

+++ {"id": "B0jBBXtgXBxr"}

We should read the type `f32[4@X, 2]` as "a 4-by-2 array of 32-bit floats whose first dimension
is sharded along mesh axis 'X', and which is replicated along all other mesh
axes".

+++ {"id": "N8yMauHAXKtX"}

These shardings associated with JAX-level types propagate through operations. For example:

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: Gy7ABds3XND3
outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b
---
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))

result = arg0 + arg1

print(f"arg0 sharding: {jax.typeof(arg0)}")
print(f"arg1 sharding: {jax.typeof(arg1)}")
print(f"result sharding: {jax.typeof(result)}")
```

+++ {"id": "lwsygUmVXPCk"}

We can do the same type querying under a `jit`:

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: grCcotr-XQjY
outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a
---
@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

add_arrays(arg0, arg1)
```

+++ {"id": "lVd6a5ufXZoH"}

That's the gist of it. Shardings propagate deterministically at trace time and
we can query them at trace time.

+++ {"id": "ETtwK3LCXSkd"}

## Sharding rules and operations with ambiguous sharding

Each op has a sharding rule which specifies its output sharding given its
input shardings. A sharding rule may also throw a (trace-time) error. Each op
is free to implement whatever sharding rule it likes, but the usual pattern is
the following: for each output axis we identify zero or more corresponding
input axes. The output axis is then sharded according to the "consensus"
sharding of the corresponding input axes: it's `None` if the input shardings
are all `None`; it's the common non-`None` input sharding if there's exactly
one of them; otherwise it's an error (requiring an explicit `out_sharding=...`
kwarg).
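
To make that rule concrete, here is a small illustrative sketch of the
per-axis consensus computation in plain Python (the helper `consensus` is
hypothetical, not a JAX API):

```python
def consensus(axis_shardings):
  """Consensus sharding for one output axis, per the rule above."""
  non_none = {s for s in axis_shardings if s is not None}
  if not non_none:
    return None              # all inputs unsharded -> output unsharded
  if len(non_none) == 1:
    return non_none.pop()    # one common mesh axis -> use it
  raise TypeError("ambiguous sharding; pass out_sharding=...")

assert consensus([None, None]) is None
assert consensus(["X", None]) == "X"   # cf. the arg0 + arg1 example above
```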

+++ {"id": "an8-Fq1uXehp"}

This procedure is done on an axis-by-axis basis. When it's done, we might end
up with an array sharding that mentions a mesh axis more than once, which is
illegal. In that case we raise a (trace-time) sharding error and ask for an
explicit `out_sharding`.

Here are some example sharding rules (a sketch of a few of them in action follows the list):

 * nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole
   cloth, so they don't have input shardings to propagate. Their output is
   unsharded by default unless overridden by the `out_sharding` kwarg.
 * unary elementwise ops like `sin`, `exp`: The output is sharded the same as the
   input.
 * binary ops (`+`, `-`, `*`, etc.): Axis shardings of "zipped" dimensions
   must match (or be `None`). "Outer product" dimensions (dimensions that
   appear in only one argument) are sharded as they are in the input. If the
   result ends up mentioning a mesh axis more than once, it's an error.
 * `reshape`: Reshape is a particularly tricky op. An output axis can map to more
   than one input axis (when reshape is used to merge axes) or just a part
   of an input axis (when reshape is used to split axes). Our usual rules
   don't apply. Instead we treat reshape as follows. We strip away singleton
   axes (these can't be sharded anyway). Then
   we decide whether the reshape is a "split" (splitting a single axis into
   two or more adjacent axes), a "merge" (merging two or more adjacent axes
   into a single one) or something else. If we have a split or merge case in
   which the split/merged axes are sharded as `None`, then we shard the
   resulting split/merged axes as `None` and the other axes according to their
   corresponding input axis shardings. In all other cases we throw an error
   and require the user to provide an `out_sharding` argument.
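
Here is a hedged sketch of these rules in action; it assumes, as the list
above says, that creation ops accept an `out_sharding` kwarg in this mode:

```python
z = jnp.zeros((4, 4), out_sharding=P("X", None))  # nullary op, sharding given explicitly
print(jax.typeof(z))                              # f32[4@X,4]

w = jnp.sin(z)                                    # unary elementwise: sharding preserved
print(jax.typeof(w))                              # f32[4@X,4]

try:
  jnp.reshape(z, (16,))  # merges a sharded axis with an unsharded one
except Exception as e:
  print("reshape needs an explicit out_sharding here:", e)
```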

+++ {"id": "jZMp6w48Xmd7"}

## JAX transformations and higher-order functions

The staged-out representation of JAX programs is explicitly typed. (We call
the types "avals" but that's not important.) In explicit-sharding mode, the
sharding is part of that type. This means that shardings need to match
wherever types need to match. For example, the two sides of a `lax.cond` need to
have results with matching shardings. And the carry of `lax.scan` needs to have the
same sharding at the input and the output of the scan body. And when you
construct a jaxpr without concrete arguments using `make_jaxpr` you need to
provide shardings too. Certain JAX transformations perform type-level
operations. Automatic differentiation constructs a tangent type for each primal
type in the original computation (e.g. `TangentOf(float) == float`,
`TangentOf(int) == float0`). With sharding in the types, this means that tangent
values are sharded in the same way as their primal values. Vmap and scan also
do type-level operations: they lift an array shape to a rank-augmented version
of that shape. That extra array axis needs a sharding. We can infer it from the
arguments to the vmap/scan, but they all need to agree. And a nullary vmap/scan
needs an explicit sharding argument just as it needs an explicit length
argument.
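
For instance, here is a minimal sketch of the `lax.scan` constraint mentioned
above: the carry's sharding is part of its type, so the body must return a
carry sharded the same way it came in (the shapes here are illustrative):

```python
from jax import lax

xs = reshard(jnp.arange(32.0).reshape(8, 4), P(None, "Y"))  # f32[8,4@Y]
init = reshard(jnp.zeros(4), P("Y"))                        # f32[4@Y]

def body(carry, x):
  # carry and x are both f32[4@Y]; the sum keeps that sharding,
  # so the carry's type is unchanged across iterations.
  return carry + x, ()

total, _ = lax.scan(body, init, xs)
print(jax.typeof(total))  # f32[4@Y]
```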

+++ {"id": "ERJx4p0tXoS3"}

## Working around unimplemented sharding rules using `auto_axes`

The implementation of explicit sharding is still a work in progress and there
are plenty of ops that are missing sharding rules, for example `scatter` and
`gather` (i.e. indexing ops).

Normally we wouldn't suggest using a feature with so many unimplemented cases,
but in this instance there's a reasonable fallback you can use: `auto_axes`.
The idea is that you can temporarily drop into a context where the mesh axes
are "auto" rather than "explicit". You explicitly specify how you intend the
final result of the `auto_axes` block to be sharded as it gets returned to the calling context.

This works as a fallback for ops with unimplemented sharding rules. It also
works when you want to override the sharding-in-types type system. For
example, suppose we want to add an `f32[4@X, 4]` to an `f32[4, 4@X]`. Our
sharding rule for addition would throw an error: the result would need to be
`f32[4@X, 4@X]`, which uses a mesh axis twice, which is illegal. But say you
want to perform the operation anyway, and you want the result to be sharded along
the first axis only, like `f32[4@X, 4]`. You can do this as follows:

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: fpFEaMBcXsJG
outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660
---
some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))

try:
  some_x + some_y
except Exception as e:
  print("ERROR!")
  print(e)

print("=== try again with auto_axes ===")

@auto_axes
def add_with_out_sharding_kwarg(x, y):
  print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}")
  return x + y

result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P("X", None))
print(f"Result type: {jax.typeof(result)}")
```

+++ {"id": "8-_zDr-AXvb6"}

## Using a mixture of sharding modes

JAX now has three styles of parallelism:

 * *Automatic sharding* is where you treat all the devices as a single logical
   machine and write a "global view" array program for that machine. The
   compiler decides how to partition the data and computation across the
   available devices. You can give hints to the compiler using
   `with_sharding_constraint`.
 * *Explicit sharding* (\*new\*) is similar to automatic sharding in that
   you're writing a global-view program. The difference is that the sharding
   of each array is part of the array's JAX-level type, making it an explicit
   part of the programming model. These shardings are propagated at the JAX
   level and queryable at trace time. It's still the compiler's responsibility
   to turn the whole-array program into per-device programs (turning `jnp.sum`
   into `psum`, for example) but the compiler is heavily constrained by the
   user-supplied shardings.
 * *Manual sharding* (`shard_map`) is where you write a program from the
   perspective of a single device. Communication between devices happens via
   explicit collective operations like `psum` (see the sketch after the table below).

A summary table:

| Mode | Explicit sharding? | Explicit collectives? |
|---|---|---|
| Auto | No | No |
| Explicit (new) | Yes | No |
| Manual | Yes | Yes |
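
For contrast with the explicit style used in this notebook, here is a minimal
manual-mode sketch. It assumes `shard_map` accepts the explicit-mode `mesh`
defined above; inside the body each device sees only its own shard and
communicates via an explicit collective:

```python
from jax.experimental.shard_map import shard_map

# Each device sees its own f32[4] shard of the input; psum sums across "X".
total = shard_map(lambda x: jax.lax.psum(x, "X"),
                  mesh=mesh, in_specs=P("X"), out_specs=P())
print(total(reshard(jnp.arange(8.0), P("X"))))
```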

The current mesh tells us which sharding mode we're in. We can query it with
`get_abstract_mesh`:

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: geptWrdYX0OM
outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f
---
print(f"Current mesh is: {get_abstract_mesh()}")
```

+++ {"id": "AQQjzUeGX4P6"}

Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit
mode. Notice that the sharding mode is associated with a mesh _axis_, not the
mesh as a whole. We can actually mix sharding modes by having a different
sharding mode for each mesh axis. Shardings (on JAX-level types) can only
mention _explicit_ mesh axes, and collective operations like `psum` can only
mention _manual_ mesh axes.
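
A minimal sketch of such a mixed mesh (the axis sizes and names here are
illustrative, and we assume `jax.sharding.use_mesh` scopes it as described
earlier):

```python
# "X" is Explicit (visible in types); "Y" is Auto (left to the compiler).
mixed_mesh = jax.make_mesh((2, 4), ("X", "Y"),
                           axis_types=(AxisType.Explicit, AxisType.Auto))

with jax.sharding.use_mesh(mixed_mesh):
  x = reshard(np.arange(8), P("X"))
  print(jax.typeof(x))  # i32[8@X]; only the explicit axis appears in the type
```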

+++ {"id": "AQQjzUeGX4P6"}

## Concrete array shardings can mention `Auto` mesh axes

You can query the sharding of a concrete array `x` with `x.sharding`. You
might expect the result to be the same as the sharding associated with the
value's type, `jax.typeof(x).sharding`. It might not be! The concrete array
sharding, `x.sharding`, describes the sharding along
both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler
eventually chose. The type-specified sharding,
`jax.typeof(x).sharding`, on the other hand, only describes the sharding along `Explicit` mesh
axes. The `Auto` axes are deliberately hidden from the type because they're
the purview of the compiler. We can think of the concrete array sharding as
being consistent with, but more specific than,
the type-specified sharding. For example:

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
id: ivLl6bxmX7EZ
outputId: 6d7b7fce-68b6-47f1-b214-d62bda8d7b6e
---
def compare_shardings(x):
  print(f"=== with mesh: {get_abstract_mesh()} ===")
  print(f"Concrete value sharding: {x.sharding.spec}")
  print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}")

my_array = jnp.sin(reshard(np.arange(8), P("X")))
compare_shardings(my_array)

@auto_axes
def check_in_auto_context(x):
  compare_shardings(x)
  return x

check_in_auto_context(my_array, out_shardings=P("X"))
```

+++ {"id": "MRFccsi5X8so"}

Notice that at the top level, where we're currently in a fully `Explicit` mesh
context, the concrete array sharding and type-specified sharding agree. But
under the `auto_axes` decorator we're in a fully `Auto` mesh context and the
two shardings disagree: the type-specified sharding is `P(None)` whereas the
concrete array sharding is `P("X")` (though it could be anything! It's up to
the compiler).