diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index 44d46bddb..f8369d337 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -1,713 +1,712 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ZVJCNxUcVkkm" - }, - "source": [ - "# Explicit sharding (a.k.a. \"sharding in types\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ATLBMlw3VcCJ" - }, - "source": [ - "JAX's traditional automatic sharding leaves sharding decisions to the compiler.\n", - "You can provide hints to the compiler using\n", - "`jax.lax.with_sharding_constraint` but for the most part you're supposed to be\n", - "focussed on the math while the compiler worries about sharding.\n", - "\n", - "But what if you have a strong opinion about how you want your program sharded?\n", - "With enough calls to `with_sharding_constraint` you can probably guide the\n", - "compiler's hand to make it do what you want. But \"compiler tickling\" is\n", - "famously not a fun programming model. Where should you put the sharding\n", - "constraints? You could put them on every single intermediate but that's a lot\n", - "of work and it's also easy to make mistakes that way because there's no way to\n", - "check that the shardings make sense together. More commonly, people add just\n", - "enough sharding annotations to constrain the compiler. But this is a slow\n", - "iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n", - "do (it's a whole-program optimization) so all you can do is add annotations,\n", - "inspect XLA's sharding choices to see what happened, and repeat.\n", - "\n", - "To fix this we've come up with a different style of sharding programming we\n", - "call \"explicit sharding\" or \"sharding in types\". The idea is that sharding\n", - "propagation happens at the JAX level at trace time. Each JAX operation has a\n", - "sharding rule that takes the shardings of the op's arguments and produces a\n", - "sharding for the op's result. For most operations these rules are simple and\n", - "obvious because there's only one reasonable choice. But for some operations it's\n", - "unclear how to shard the result. In that case we ask the programmer\n", - "to provide an `out_sharding` argument explicitly and we throw a (trace-time)\n", - "error otherwise. Since the shardings are propagated at trace time they can\n", - "also be _queried_ at trace time too. In the rest of this doc we'll describe\n", - "how to use explicit sharding mode. Note that this is a new feature so we\n", - "expect there to be bugs and unimplemented cases. Please let us know when you\n", - "find something that doesn't work!" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "hVi6mApuVw3r", - "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf" - }, - "outputs": [], - "source": [ - "import jax\n", - "import numpy as np\n", - "import jax.numpy as jnp\n", - "from jax.sharding import PartitionSpec as P, AxisType, set_mesh\n", - "from jax.experimental.shard import reshard, auto_axes\n", - "from jax._src.mesh import get_abstract_mesh\n", - "\n", - "jax.config.update('jax_num_cpu_devices', 8)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oU5O6yOLWqbP" - }, - "source": [ - "## Setting up an explicit mesh\n", - "\n", - "The main idea behind explicit shardings, (a.k.a. 
sharding-in-types), is that\n", - "the JAX-level _type_ of a value includes a description of how the value is sharded.\n", - "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n", - "scalar) using `jax.typeof`:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mzDIDvj7Vw0k", - "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434" - }, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "JAX-level type of some_array: ShapedArray(int32[8])\n" - ] - } - ], - "source": [ - "some_array = np.arange(8)\n", - "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TZzp_1sXW061" - }, - "source": [ - "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n", - "is almost _defined_ as \"the information about a value we have access to while\n", - "under a jit)." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "IyPx_-IBVwxr", - "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "JAX-level type of x during tracing: ShapedArray(int32[8])\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)" + "cell_type": "markdown", + "metadata": { + "id": "ZVJCNxUcVkkm" + }, + "source": [ + "# Explicit sharding (a.k.a. \"sharding in types\")" ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "@jax.jit\n", - "def foo(x):\n", - " print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n", - " return x + x\n", - "\n", - "foo(some_array)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c3gNPzfZW45K" - }, - "source": [ - "These types show the shape and dtype of array but they don't appear to\n", - "show sharding. (Actually, they _did_ show sharding, but the shardings were\n", - "trivial. See \"Concrete array shardings\", below.) To start seeing some\n", - "interesting shardings we need to set up an explicit-sharding mesh. 
We use\n", - "`set_mesh` to set it as the current mesh for the remainder of this notebook.\n", - "(If you only want to set the mesh for some particular scope and return to the previous\n", - "mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "NO2ulM_QW7a8", - "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n" - ] - } - ], - "source": [ - "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", - " axis_types=(AxisType.Explicit, AxisType.Explicit))\n", - "set_mesh(mesh)\n", - "\n", - "print(f\"Current mesh is: {get_abstract_mesh()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "V7bVz6tzW_Eb" - }, - "source": [ - "Now we can create some sharded arrays using `reshard`:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1-TzmA0AXCAf", - "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "replicated_array type: ShapedArray(int32[4,2])\n", - "sharded_array type: ShapedArray(int32[4@X,2])\n" - ] - } - ], - "source": [ - "replicated_array = np.arange(8).reshape(4, 2)\n", - "sharded_array = reshard(replicated_array, P(\"X\", None))\n", - "\n", - "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n", - "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B0jBBXtgXBxr" - }, - "source": [ - "We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n", - "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", - "axes\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N8yMauHAXKtX" - }, - "source": [ - "These shardings associated with JAX-level types propagate through operations. 
For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Gy7ABds3XND3", - "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "arg0 sharding: ShapedArray(int32[4@X,1])\n", - "arg1 sharding: ShapedArray(int32[1,8@Y])\n", - "result sharding: ShapedArray(int32[4@X,8@Y])\n" - ] - } - ], - "source": [ - "arg0 = reshard(np.arange(4).reshape(4, 1), P(\"X\", None))\n", - "arg1 = reshard(np.arange(8).reshape(1, 8), P(None, \"Y\"))\n", - "\n", - "result = arg0 + arg1\n", - "\n", - "print(f\"arg0 sharding: {jax.typeof(arg0)}\")\n", - "print(f\"arg1 sharding: {jax.typeof(arg1)}\")\n", - "print(f\"result sharding: {jax.typeof(result)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lwsygUmVXPCk" - }, - "source": [ - "We can do the same type querying under a jit:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "grCcotr-XQjY", - "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x sharding: ShapedArray(int32[4@X,1])\n", - "y sharding: ShapedArray(int32[1,8@Y])\n", - "ans sharding: ShapedArray(int32[4@X,8@Y])\n" - ] }, { - "data": { - "text/plain": [ - "Array([[ 0, 1, 2, 3, 4, 5, 6, 7],\n", - " [ 1, 2, 3, 4, 5, 6, 7, 8],\n", - " [ 2, 3, 4, 5, 6, 7, 8, 9],\n", - " [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)" + "cell_type": "markdown", + "metadata": { + "id": "ATLBMlw3VcCJ" + }, + "source": [ + "JAX's traditional automatic sharding leaves sharding decisions to the compiler.\n", + "You can provide hints to the compiler using\n", + "`jax.lax.with_sharding_constraint` but for the most part you're supposed to be\n", + "focussed on the math while the compiler worries about sharding.\n", + "\n", + "But what if you have a strong opinion about how you want your program sharded?\n", + "With enough calls to `with_sharding_constraint` you can probably guide the\n", + "compiler's hand to make it do what you want. But \"compiler tickling\" is\n", + "famously not a fun programming model. Where should you put the sharding\n", + "constraints? You could put them on every single intermediate but that's a lot\n", + "of work and it's also easy to make mistakes that way because there's no way to\n", + "check that the shardings make sense together. More commonly, people add just\n", + "enough sharding annotations to constrain the compiler. But this is a slow\n", + "iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n", + "do (it's a whole-program optimization) so all you can do is add annotations,\n", + "inspect XLA's sharding choices to see what happened, and repeat.\n", + "\n", + "To fix this we've come up with a different style of sharding programming we\n", + "call \"explicit sharding\" or \"sharding in types\". The idea is that sharding\n", + "propagation happens at the JAX level at trace time. Each JAX operation has a\n", + "sharding rule that takes the shardings of the op's arguments and produces a\n", + "sharding for the op's result. For most operations these rules are simple and\n", + "obvious because there's only one reasonable choice. But for some operations it's\n", + "unclear how to shard the result. 
In that case we ask the programmer\n", + "to provide an `out_sharding` argument explicitly and we throw a (trace-time)\n", + "error otherwise. Since the shardings are propagated at trace time they can\n", + "also be _queried_ at trace time too. In the rest of this doc we'll describe\n", + "how to use explicit sharding mode. Note that this is a new feature so we\n", + "expect there to be bugs and unimplemented cases. Please let us know when you\n", + "find something that doesn't work!" ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "@jax.jit\n", - "def add_arrays(x, y):\n", - " ans = x + y\n", - " print(f\"x sharding: {jax.typeof(x)}\")\n", - " print(f\"y sharding: {jax.typeof(y)}\")\n", - " print(f\"ans sharding: {jax.typeof(ans)}\")\n", - " return ans\n", - "\n", - "add_arrays(arg0, arg1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lVd6a5ufXZoH" - }, - "source": [ - "That's the gist of it. Shardings propagate deterministically at trace time and\n", - "we can query them at trace time." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ETtwK3LCXSkd" - }, - "source": [ - "## Sharding rules and operations with ambiguous sharding\n", - "\n", - "Each op has a sharding rule which specifies its output sharding given its\n", - "input shardings. A sharding rule may also throw a (trace-time) error. Each op\n", - "is free to implement whatever sharding rule it likes, but the usual pattern is\n", - "the following: For each output axis we identify zero of more corresponding\n", - "input axes. The output axis is then\n", - "sharded according to the “consensus” sharding of the corresponding input axes. i.e., it's\n", - "`None` if the input shardings are all `None`, and it's the common non-None input sharding\n", - "if there’s exactly one of them, or an error (requiring an explicit out_sharding=... kwarg) otherwise." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "an8-Fq1uXehp" - }, - "source": [ - "This procedure is done on an axis-by-axis basis. When it’s done, we might end\n", - "up with an array sharding that mentions a mesh axis more than once, which is\n", - "illegal. In that case we raise a (trace-time) sharding error and ask for an\n", - "explicit out_sharding.\n", - "\n", - "Here are some example sharding rules:\n", - " * nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole\n", - " cloth so they don’t have input shardings to propagate. Their output is\n", - " unsharded by default unless overridden by the out_sharding kwarg.\n", - " * unary elementwise ops like `sin`, `exp`: The output is sharded the same as the\n", - " input.\n", - " * binary ops (`+`, `-`, `*` etc.): Axis shardings of “zipped” dimensions\n", - " must match (or be `None`). “Outer product” dimensions (dimensions that\n", - " appear in only one argument) are sharded as they are in the input. If the\n", - " result ends up mentioning a mesh axis more than once it's an error.\n", - " * `reshape.` Reshape is a particularly tricky op. An output axis can map to more\n", - " than one input axis (when reshape is used to merge axes) or just a part\n", - " of an input axis (when reshape is used to split axes). Our usual rules\n", - " don’t apply. Instead we treat reshape as follows. We strip away singleton\n", - " axes (these can’t be sharded anyway. 
Then\n", - " we decide whether the reshape is a “split” (splitting a single axis into\n", - " two or more adjacent axes), a “merge” (merging two or more adjacent axes\n", - " into a single one) or something else. If we have a split or merge case in\n", - " which the split/merged axes are sharded as None then we shard the\n", - " resulting split/merged axes as None and the other axes according to their\n", - " corresponding input axis shardings. In all other cases we throw an error\n", - " and require the user to provide an `out_shardings` argument." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jZMp6w48Xmd7" - }, - "source": [ - "## JAX transformations and higher-order functions\n", - "\n", - "The staged-out representation of JAX programs is explicitly typed. (We call\n", - "the types “avals” but that’s not important.) In explicit-sharding mode, the\n", - "sharding is part of that type. This means that shardings need to match\n", - "wherever types need to match. For example, the two sides of a `lax.cond` need to\n", - "have results with matching shardings. And the carry of `lax.scan` needs to have the\n", - "same sharding at the input and the output of the scan body. And when you\n", - "contruct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", - "provide shardings too. Certain JAX transformations perform type-level\n", - "operations. Automatic differentation constructs a tangent type for each primal\n", - "type in the original computation (e.g. `TangentOf(float) == float`,\n", - "`TangentOf(int) == float0`). With sharding in the types, this means that tangent\n", - "values are sharded in the same way as their primal values. Vmap and scan also\n", - "do type-level operations, they lift an array shape to a rank-augmented version\n", - "of that shape. That extra array axis needs a sharding. We can infer it from the\n", - "arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan\n", - "needs an explicit sharding argument just as it needs an explicit length\n", - "argument." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ERJx4p0tXoS3" - }, - "source": [ - "## Working around unimplemented sharding rules using `auto_sharding`\n", - "\n", - "The implementation of explicit sharding is still a work-in-progress and there\n", - "are plenty of ops that are missing sharding rules. For example, `scatter` and\n", - "`gather` (i.e. indexing ops).\n", - "\n", - "Normally we wouldn't suggest using a feature with so many unimplemented cases,\n", - "but in this instance there's a reasonable fallback you can use: `auto_axes`.\n", - "The idea is that you can temporarily drop into a context where the mesh axes\n", - "are \"auto\" rather than \"explicit\". You explicitly specify how you intend the\n", - "final result of the `auto_axes` to be sharded as it gets returned to the calling context.\n", - "\n", - "This works as a fallback for ops with unimplemented sharding rules. It also\n", - "works when you want to override the sharding-in-types type system. For\n", - "example, suppose we want to add a `f32[4@X, 4]` to a `f32[4, 4@X]`. Our\n", - "sharding rule for addition would throw an error: the result would need to be\n", - "`f32[4@X, 4@X]`, which tries uses a mesh axis twice, which is illegal. But say you\n", - "want to perform the operation anyway, and you want the result to be sharded along\n", - "the first axis only, like `f32[4@X, 4]`. 
You can do this as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "fpFEaMBcXsJG", - "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ERROR!\n", - "add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]\n", - "=== try again with auto_axes ===\n", - "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n", - "Result type: ShapedArray(int32[4@X,4])\n" - ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result type: ShapedArray(int32[4@X,4])\n" - ] - } - ], - "source": [ - "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", None))\n", - "some_y = reshard(np.arange(16).reshape(4, 4), P(None, \"X\"))\n", - "\n", - "try:\n", - " some_x + some_y\n", - "except Exception as e:\n", - " print(\"ERROR!\")\n", - " print(e)\n", - "\n", - "print(\"=== try again with auto_axes ===\")\n", - "\n", - "@auto_axes\n", - "def add_with_out_sharding_kwarg(x, y):\n", - " print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n", - " return x + y\n", - "\n", - "result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P(\"X\", None))\n", - "print(f\"Result type: {jax.typeof(result)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8-_zDr-AXvb6" - }, - "source": [ - "## Using a mixture of sharding modes\n", - "\n", - "JAX now has three styles of parallelism:\n", - "\n", - " * *Automatic sharding* is where you treat all the devices as a single logical\n", - " machine and write a \"global view\" array program for that machine. The\n", - " compiler decides how to partition the data and computation across the\n", - " available devices. You can give hints to the compiler using\n", - " `with_sharding_constraint`.\n", - " * *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n", - " you're writing a global-view program. The difference is that the sharding\n", - " of each array is part of the array's JAX-level type making it an explicit\n", - " part of the programming model. These shardings are propagated at the JAX\n", - " level and queryable at trace time. It's still the compiler's responsibility\n", - " to turn the whole-array program into per-device programs (turning `jnp.sum`\n", - " into `psum` for example) but the compiler is heavily constrained by the\n", - " user-supplied shardings.\n", - " * *Manual Sharding* (`shard_map`) is where you write a program from the\n", - " perspective of a single device. Communication between devices happens via\n", - " explicit collective operations like psum.\n", - "\n", - "A summary table:\n", - "\n", - "| Mode | Explicit sharding? | Explicit Collectives? |\n", - "|---|---|---|\n", - "| Auto | No | No |\n", - "| Explicit (new) | Yes | No |\n", - "| Manual | Yes | Yes |\n", - "\n", - "The current mesh tells us which sharding mode we're in. 
We can query it with\n", - "`get_abstract_mesh`:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "geptWrdYX0OM", - "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n" - ] - } - ], - "source": [ - "print(f\"Current mesh is: {get_abstract_mesh()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AQQjzUeGX4P6" - }, - "source": [ - "Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit\n", - "mode. Notice that the sharding mode is associated with a mesh _axis_, not the\n", - "mesh as a whole. We can actually mix sharding modes by having a different\n", - "sharding mode for each mesh axis. Shardings (on JAX-level types) can only\n", - "mention _explicit_ mesh axes and collective operations like `psum` can only\n", - "mention _manual_ mesh axes." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AQQjzUeGX4P6" - }, - "source": [ - "## Concrete array shardings can mention `Auto` mesh axis\n", - "\n", - "You can query the sharding of a concrete array `x` with `x.sharding`. You\n", - "might expect the result to be the same as the sharding associated with the\n", - "value's type, `jax.typeof(x).sharding`. It might not be! The concrete array sharding, `x.sharding`, describes the sharding along\n", - "both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler\n", - "eventually chose. Whereas the type-specificed sharding,\n", - "`jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh\n", - "axes. The `Auto` axes are deliberately hidden from the type because they're\n", - "the purview of the compiler. We can think of the concrete array sharding being consistent with, but more specific than,\n", - "the type-specified sharding. For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ivLl6bxmX7EZ", - "outputId": "6d7b7fce-68b6-47f1-b214-d62bda8d7b6e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit)) ===\n", - "Concrete value sharding: PartitionSpec('X',)\n", - "Type-specified sharding: PartitionSpec('X',)\n", - "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto)) ===\n", - "Concrete value sharding: PartitionSpec('X',)\n", - "Type-specified sharding: PartitionSpec(None,)\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n", - " -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)" + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hVi6mApuVw3r", + "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf" + }, + "outputs": [], + "source": [ + "import jax\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n", + "from jax.experimental.shard import reshard, auto_axes\n", + "\n", + "jax.config.update('jax_num_cpu_devices', 8)" ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" }, { - "data": { - "text/plain": [ - "Array([ 0. 
, 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n", - " -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)" + "cell_type": "markdown", + "metadata": { + "id": "oU5O6yOLWqbP" + }, + "source": [ + "## Setting up an explicit mesh\n", + "\n", + "The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n", + "the JAX-level _type_ of a value includes a description of how the value is sharded.\n", + "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n", + "scalar) using `jax.typeof`:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mzDIDvj7Vw0k", + "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX-level type of some_array: ShapedArray(int32[8])\n" + ] + } + ], + "source": [ + "some_array = np.arange(8)\n", + "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TZzp_1sXW061" + }, + "source": [ + "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n", + "is almost _defined_ as \"the information about a value we have access to while\n", + "under a jit)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IyPx_-IBVwxr", + "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX-level type of x during tracing: ShapedArray(int32[8])\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@jax.jit\n", + "def foo(x):\n", + " print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n", + " return x + x\n", + "\n", + "foo(some_array)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c3gNPzfZW45K" + }, + "source": [ + "These types show the shape and dtype of array but they don't appear to\n", + "show sharding. (Actually, they _did_ show sharding, but the shardings were\n", + "trivial. See \"Concrete array shardings\", below.) To start seeing some\n", + "interesting shardings we need to set up an explicit-sharding mesh. 
We use\n", + "`set_mesh` to set it as the current mesh for the remainder of this notebook.\n", + "(If you only want to set the mesh for some particular scope and return to the previous\n", + "mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NO2ulM_QW7a8", + "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n" + ] + } + ], + "source": [ + "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))\n", + "set_mesh(mesh)\n", + "\n", + "print(f\"Current mesh is: {get_abstract_mesh()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V7bVz6tzW_Eb" + }, + "source": [ + "Now we can create some sharded arrays using `reshard`:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1-TzmA0AXCAf", + "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "replicated_array type: ShapedArray(int32[4,2])\n", + "sharded_array type: ShapedArray(int32[4@X,2])\n" + ] + } + ], + "source": [ + "replicated_array = np.arange(8).reshape(4, 2)\n", + "sharded_array = reshard(replicated_array, P(\"X\", None))\n", + "\n", + "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n", + "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B0jBBXtgXBxr" + }, + "source": [ + "We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n", + "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", + "axes\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N8yMauHAXKtX" + }, + "source": [ + "These shardings associated with JAX-level types propagate through operations. 
For example:"
          ]
        },
        {
          "cell_type": "code",
          "execution_count": 6,
          "metadata": {
            "colab": {
              "base_uri": "https://localhost:8080/"
            },
            "id": "Gy7ABds3XND3",
            "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
          },
          "outputs": [
            {
              "name": "stdout",
              "output_type": "stream",
              "text": [
                "arg0 sharding: ShapedArray(int32[4@X,1])\n",
                "arg1 sharding: ShapedArray(int32[1,8@Y])\n",
                "result sharding: ShapedArray(int32[4@X,8@Y])\n"
              ]
            }
          ],
          "source": [
            "arg0 = reshard(np.arange(4).reshape(4, 1), P(\"X\", None))\n",
            "arg1 = reshard(np.arange(8).reshape(1, 8), P(None, \"Y\"))\n",
            "\n",
            "result = arg0 + arg1\n",
            "\n",
            "print(f\"arg0 sharding: {jax.typeof(arg0)}\")\n",
            "print(f\"arg1 sharding: {jax.typeof(arg1)}\")\n",
            "print(f\"result sharding: {jax.typeof(result)}\")"
          ]
        },
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "lwsygUmVXPCk"
          },
          "source": [
            "We can do the same type querying under a jit:"
          ]
        },
        {
          "cell_type": "code",
          "execution_count": 7,
          "metadata": {
            "colab": {
              "base_uri": "https://localhost:8080/"
            },
            "id": "grCcotr-XQjY",
            "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
          },
          "outputs": [
            {
              "name": "stdout",
              "output_type": "stream",
              "text": [
                "x sharding: ShapedArray(int32[4@X,1])\n",
                "y sharding: ShapedArray(int32[1,8@Y])\n",
                "ans sharding: ShapedArray(int32[4@X,8@Y])\n"
              ]
            },
            {
              "data": {
                "text/plain": [
                  "Array([[ 0,  1,  2,  3,  4,  5,  6,  7],\n",
                  "       [ 1,  2,  3,  4,  5,  6,  7,  8],\n",
                  "       [ 2,  3,  4,  5,  6,  7,  8,  9],\n",
                  "       [ 3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)"
                ]
              },
              "execution_count": 7,
              "metadata": {},
              "output_type": "execute_result"
            }
          ],
          "source": [
            "@jax.jit\n",
            "def add_arrays(x, y):\n",
            "  ans = x + y\n",
            "  print(f\"x sharding: {jax.typeof(x)}\")\n",
            "  print(f\"y sharding: {jax.typeof(y)}\")\n",
            "  print(f\"ans sharding: {jax.typeof(ans)}\")\n",
            "  return ans\n",
            "\n",
            "add_arrays(arg0, arg1)"
          ]
        },
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "lVd6a5ufXZoH"
          },
          "source": [
            "That's the gist of it. Shardings propagate deterministically at trace time and\n",
            "we can query them at trace time."
          ]
        },
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "ETtwK3LCXSkd"
          },
          "source": [
            "## Sharding rules and operations with ambiguous sharding\n",
            "\n",
            "Each op has a sharding rule which specifies its output sharding given its\n",
            "input shardings. A sharding rule may also throw a (trace-time) error. Each op\n",
            "is free to implement whatever sharding rule it likes, but the usual pattern is\n",
            "the following: For each output axis we identify zero or more corresponding\n",
            "input axes. The output axis is then\n",
            "sharded according to the “consensus” sharding of the corresponding input axes: it's\n",
            "`None` if the input shardings are all `None`, it's the common non-None input sharding\n",
            "if there’s exactly one of them, and it's an error (requiring an explicit `out_sharding=...` kwarg) otherwise."
          ]
        },
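        {
          "cell_type": "markdown",
          "metadata": {},
          "source": [
            "As a quick sketch of the consensus rule (our illustration, reusing the mesh\n",
            "and imports from above): when the zipped axes of a binary op carry the same\n",
            "sharding, that sharding is the consensus and no `out_sharding` is needed:\n",
            "\n",
            "```python\n",
            "a = reshard(np.zeros((4, 4)), P(\"X\", None))\n",
            "b = reshard(np.zeros((4, 4)), P(\"X\", None))\n",
            "print(jax.typeof(a * b))  # both inputs agree, so we expect f32[4@X,4]\n",
            "```"
          ]
        },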
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "an8-Fq1uXehp"
          },
          "source": [
            "This procedure is done on an axis-by-axis basis. When it’s done, we might end\n",
            "up with an array sharding that mentions a mesh axis more than once, which is\n",
            "illegal. In that case we raise a (trace-time) sharding error and ask for an\n",
            "explicit out_sharding.\n",
            "\n",
            "Here are some example sharding rules:\n",
            " * nullary ops like `jnp.zeros`, `jnp.arange`: These ops create arrays out of whole\n",
            "   cloth so they don’t have input shardings to propagate. Their output is\n",
            "   unsharded by default unless overridden by the out_sharding kwarg.\n",
            " * unary elementwise ops like `sin`, `exp`: The output is sharded the same as the\n",
            "   input.\n",
            " * binary ops (`+`, `-`, `*` etc.): Axis shardings of “zipped” dimensions\n",
            "   must match (or be `None`). “Outer product” dimensions (dimensions that\n",
            "   appear in only one argument) are sharded as they are in the input. If the\n",
            "   result ends up mentioning a mesh axis more than once it's an error.\n",
            " * `reshape`. Reshape is a particularly tricky op. An output axis can map to more\n",
            "   than one input axis (when reshape is used to merge axes) or just a part\n",
            "   of an input axis (when reshape is used to split axes). Our usual rules\n",
            "   don’t apply. Instead we treat reshape as follows. We strip away singleton\n",
            "   axes (these can’t be sharded anyway). Then\n",
            "   we decide whether the reshape is a “split” (splitting a single axis into\n",
            "   two or more adjacent axes), a “merge” (merging two or more adjacent axes\n",
            "   into a single one) or something else. If we have a split or merge case in\n",
            "   which the split/merged axes are sharded as `None` then we shard the\n",
            "   resulting split/merged axes as `None` and the other axes according to their\n",
            "   corresponding input axis shardings. In all other cases we throw an error\n",
            "   and require the user to provide an `out_sharding` argument."
          ]
        },
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "jZMp6w48Xmd7"
          },
          "source": [
            "## JAX transformations and higher-order functions\n",
            "\n",
            "The staged-out representation of JAX programs is explicitly typed. (We call\n",
            "the types “avals” but that’s not important.) In explicit-sharding mode, the\n",
            "sharding is part of that type. This means that shardings need to match\n",
            "wherever types need to match. For example, the two sides of a `lax.cond` need to\n",
            "have results with matching shardings. And the carry of `lax.scan` needs to have the\n",
            "same sharding at the input and the output of the scan body. And when you\n",
            "construct a jaxpr without concrete arguments using `make_jaxpr` you need to\n",
            "provide shardings too. Certain JAX transformations perform type-level\n",
            "operations. Automatic differentiation constructs a tangent type for each primal\n",
            "type in the original computation (e.g. `TangentOf(float) == float`,\n",
            "`TangentOf(int) == float0`). With sharding in the types, this means that tangent\n",
            "values are sharded in the same way as their primal values. Vmap and scan also\n",
            "do type-level operations: they lift an array shape to a rank-augmented version\n",
            "of that shape. That extra array axis needs a sharding. We can infer it from the\n",
            "arguments to the vmap/scan but they all need to agree. And a nullary vmap/scan\n",
            "needs an explicit sharding argument just as it needs an explicit length\n",
            "argument.\n",
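            "\n",
            "For instance, here's a small sketch of the carry rule (an illustration of\n",
            "ours with made-up shapes, assuming the explicit mesh set up earlier). A\n",
            "`lax.scan` whose carry keeps a stable sharding traces fine:\n",
            "\n",
            "```python\n",
            "xs = reshard(jnp.ones((10, 4)), P(None, \"X\"))  # f32[10,4@X]\n",
            "init = reshard(jnp.zeros(4), P(\"X\"))           # f32[4@X]\n",
            "\n",
            "def body(carry, x):\n",
            "  return carry + x, ()  # carry is f32[4@X] at both input and output\n",
            "\n",
            "final_carry, _ = jax.lax.scan(body, init, xs)\n",
            "print(jax.typeof(final_carry))  # we expect the 'X' sharding to be preserved\n",
            "```"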
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ERJx4p0tXoS3" + }, + "source": [ + "## Working around unimplemented sharding rules using `auto_sharding`\n", + "\n", + "The implementation of explicit sharding is still a work-in-progress and there\n", + "are plenty of ops that are missing sharding rules. For example, `scatter` and\n", + "`gather` (i.e. indexing ops).\n", + "\n", + "Normally we wouldn't suggest using a feature with so many unimplemented cases,\n", + "but in this instance there's a reasonable fallback you can use: `auto_axes`.\n", + "The idea is that you can temporarily drop into a context where the mesh axes\n", + "are \"auto\" rather than \"explicit\". You explicitly specify how you intend the\n", + "final result of the `auto_axes` to be sharded as it gets returned to the calling context.\n", + "\n", + "This works as a fallback for ops with unimplemented sharding rules. It also\n", + "works when you want to override the sharding-in-types type system. For\n", + "example, suppose we want to add a `f32[4@X, 4]` to a `f32[4, 4@X]`. Our\n", + "sharding rule for addition would throw an error: the result would need to be\n", + "`f32[4@X, 4@X]`, which tries uses a mesh axis twice, which is illegal. But say you\n", + "want to perform the operation anyway, and you want the result to be sharded along\n", + "the first axis only, like `f32[4@X, 4]`. You can do this as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fpFEaMBcXsJG", + "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ERROR!\n", + "add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]\n", + "=== try again with auto_axes ===\n", + "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n", + "Result type: ShapedArray(int32[4@X,4])\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result type: ShapedArray(int32[4@X,4])\n" + ] + } + ], + "source": [ + "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", None))\n", + "some_y = reshard(np.arange(16).reshape(4, 4), P(None, \"X\"))\n", + "\n", + "try:\n", + " some_x + some_y\n", + "except Exception as e:\n", + " print(\"ERROR!\")\n", + " print(e)\n", + "\n", + "print(\"=== try again with auto_axes ===\")\n", + "\n", + "@auto_axes\n", + "def add_with_out_sharding_kwarg(x, y):\n", + " print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n", + " return x + y\n", + "\n", + "result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P(\"X\", None))\n", + "print(f\"Result type: {jax.typeof(result)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8-_zDr-AXvb6" + }, + "source": [ + "## Using a mixture of sharding modes\n", + "\n", + "JAX now has three styles of parallelism:\n", + "\n", + " * *Automatic sharding* is where you treat all the devices as a single logical\n", + " machine and write a \"global view\" array program for that machine. The\n", + " compiler decides how to partition the data and computation across the\n", + " available devices. You can give hints to the compiler using\n", + " `with_sharding_constraint`.\n", + " * *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n", + " you're writing a global-view program. 
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "8-_zDr-AXvb6"
          },
          "source": [
            "## Using a mixture of sharding modes\n",
            "\n",
            "JAX now has three styles of parallelism:\n",
            "\n",
            " * *Automatic sharding* is where you treat all the devices as a single logical\n",
            "   machine and write a \"global view\" array program for that machine. The\n",
            "   compiler decides how to partition the data and computation across the\n",
            "   available devices. You can give hints to the compiler using\n",
            "   `with_sharding_constraint`.\n",
            " * *Explicit sharding* (\\*new\\*) is similar to automatic sharding in that\n",
            "   you're writing a global-view program. The difference is that the sharding\n",
            "   of each array is part of the array's JAX-level type, making it an explicit\n",
            "   part of the programming model. These shardings are propagated at the JAX\n",
            "   level and queryable at trace time. It's still the compiler's responsibility\n",
            "   to turn the whole-array program into per-device programs (turning `jnp.sum`\n",
            "   into `psum` for example) but the compiler is heavily constrained by the\n",
            "   user-supplied shardings.\n",
            " * *Manual sharding* (`shard_map`) is where you write a program from the\n",
            "   perspective of a single device. Communication between devices happens via\n",
            "   explicit collective operations like `psum`.\n",
            "\n",
            "A summary table:\n",
            "\n",
            "| Mode | Explicit sharding? | Explicit collectives? |\n",
            "|---|---|---|\n",
            "| Auto | No | No |\n",
            "| Explicit (new) | Yes | No |\n",
            "| Manual | Yes | Yes |\n",
            "\n",
            "The current mesh tells us which sharding mode we're in. We can query it with\n",
            "`get_abstract_mesh`:"
          ]
        },
        {
          "cell_type": "code",
          "execution_count": 10,
          "metadata": {
            "colab": {
              "base_uri": "https://localhost:8080/"
            },
            "id": "geptWrdYX0OM",
            "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
          },
          "outputs": [
            {
              "name": "stdout",
              "output_type": "stream",
              "text": [
                "Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n"
              ]
            }
          ],
          "source": [
            "print(f\"Current mesh is: {get_abstract_mesh()}\")"
          ]
        },
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "AQQjzUeGX4P6"
          },
          "source": [
            "Since `axis_types=(Explicit, Explicit)`, this means we're in fully-explicit\n",
            "mode. Notice that the sharding mode is associated with a mesh _axis_, not the\n",
            "mesh as a whole. We can actually mix sharding modes by having a different\n",
            "sharding mode for each mesh axis. Shardings (on JAX-level types) can only\n",
            "mention _explicit_ mesh axes and collective operations like `psum` can only\n",
            "mention _manual_ mesh axes."
          ]
        },
        {
          "cell_type": "markdown",
          "metadata": {
            "id": "AQQjzUeGX4P6"
          },
          "source": [
            "## Concrete array shardings can mention `Auto` mesh axes\n",
            "\n",
            "You can query the sharding of a concrete array `x` with `x.sharding`. You\n",
            "might expect the result to be the same as the sharding associated with the\n",
            "value's type, `jax.typeof(x).sharding`. It might not be! The concrete array sharding, `x.sharding`, describes the sharding along\n",
            "both `Explicit` and `Auto` mesh axes. It's the sharding that the compiler\n",
            "eventually chose, whereas the type-specified sharding,\n",
            "`jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh\n",
            "axes. The `Auto` axes are deliberately hidden from the type because they're\n",
            "the purview of the compiler. We can think of the concrete array sharding as being consistent with, but more specific than,\n",
            "the type-specified sharding. 
For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ivLl6bxmX7EZ", + "outputId": "6d7b7fce-68b6-47f1-b214-d62bda8d7b6e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit)) ===\n", + "Concrete value sharding: PartitionSpec('X',)\n", + "Type-specified sharding: PartitionSpec('X',)\n", + "=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto)) ===\n", + "Concrete value sharding: PartitionSpec('X',)\n", + "Type-specified sharding: PartitionSpec(None,)\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n", + " -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,\n", + " -0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def compare_shardings(x):\n", + " print(f\"=== with mesh: {get_abstract_mesh()} ===\")\n", + " print(f\"Concrete value sharding: {x.sharding.spec}\")\n", + " print(f\"Type-specified sharding: {jax.typeof(x).sharding.spec}\")\n", + "\n", + "my_array = jnp.sin(reshard(np.arange(8), P(\"X\")))\n", + "compare_shardings(my_array)\n", + "\n", + "@auto_axes\n", + "def check_in_auto_context(x):\n", + " compare_shardings(x)\n", + " return x\n", + "\n", + "check_in_auto_context(my_array, out_shardings=P(\"X\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MRFccsi5X8so" + }, + "source": [ + "Notice that at the top level, where we're currently in a fully `Explicit` mesh\n", + "context, the concrete array sharding and type-specified sharding agree. But\n", + "under the `auto_axes` decorator we're in a fully `Auto` mesh context and the\n", + "two shardings disagree: the type-specified sharding is `P(None)` whereas the\n", + "concrete array sharding is `P(\"X\")` (though it could be anything! It's up to\n", + "the compiler)." 
] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" } - ], - "source": [ - "def compare_shardings(x):\n", - " print(f\"=== with mesh: {get_abstract_mesh()} ===\")\n", - " print(f\"Concrete value sharding: {x.sharding.spec}\")\n", - " print(f\"Type-specified sharding: {jax.typeof(x).sharding.spec}\")\n", - "\n", - "my_array = jnp.sin(reshard(np.arange(8), P(\"X\")))\n", - "compare_shardings(my_array)\n", - "\n", - "@auto_axes\n", - "def check_in_auto_context(x):\n", - " compare_shardings(x)\n", - " return x\n", - "\n", - "check_in_auto_context(my_array, out_shardings=P(\"X\"))" - ] + ], + "metadata": { + "colab": { + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "MRFccsi5X8so" - }, - "source": [ - "Notice that at the top level, where we're currently in a fully `Explicit` mesh\n", - "context, the concrete array sharding and type-specified sharding agree. But\n", - "under the `auto_axes` decorator we're in a fully `Auto` mesh context and the\n", - "two shardings disagree: the type-specified sharding is `P(None)` whereas the\n", - "concrete array sharding is `P(\"X\")` (though it could be anything! It's up to\n", - "the compiler)." - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index cd1a580bd..b7368b5eb 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -59,9 +59,8 @@ outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf import jax import numpy as np import jax.numpy as jnp -from jax.sharding import PartitionSpec as P, AxisType, set_mesh +from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh from jax.experimental.shard import reshard, auto_axes -from jax._src.mesh import get_abstract_mesh jax.config.update('jax_num_cpu_devices', 8) ``` diff --git a/jax/sharding.py b/jax/sharding.py index 6ddc81584..55ff0f6ae 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -28,10 +28,11 @@ from jax._src.sharding_impls import ( from jax._src.partition_spec import ( PartitionSpec as PartitionSpec, ) -from jax._src.interpreters.pxla import Mesh as Mesh from jax._src.mesh import ( + Mesh as Mesh, AbstractMesh as AbstractMesh, AxisType as AxisType, + get_abstract_mesh as get_abstract_mesh, ) _deprecations = {