diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index 850de2541..d656e12d4 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -49,13 +49,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "hVi6mApuVw3r", - "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf" + "id": "hVi6mApuVw3r" }, "outputs": [], "source": [ @@ -84,13 +80,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mzDIDvj7Vw0k", - "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434" + "outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a" }, "outputs": [ { @@ -119,13 +115,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IyPx_-IBVwxr", - "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499" + "outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb" }, "outputs": [ { @@ -141,7 +137,7 @@ "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)" ] }, - "execution_count": 3, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -172,13 +168,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NO2ulM_QW7a8", - "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb" + "outputId": "d888371b-080e-4bff-be5d-ea56beda3aac" }, "outputs": [ { @@ -208,13 +204,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1-TzmA0AXCAf", - "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71" + "outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2" }, "outputs": [ { @@ -256,13 +252,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Gy7ABds3XND3", - "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b" + "outputId": "0d72dad2-381a-4e96-f771-40d705da1376" }, "outputs": [ { @@ -297,13 +293,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "grCcotr-XQjY", - "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a" + "outputId": "c2db656c-809f-49a6-c948-629d6420360c" }, "outputs": [ { @@ -324,7 +320,7 @@ " [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)" ] }, - "execution_count": 7, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -460,13 +456,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fpFEaMBcXsJG", - "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660" + "outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef" }, "outputs": [ { @@ -479,13 +475,6 @@ "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n", "Result type: ShapedArray(int32[4@X,4])\n" ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result type: ShapedArray(int32[4@X,4])\n" - ] } ], "source": [ @@ -550,13 +539,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "geptWrdYX0OM", - "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f" + "outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f" }, "outputs": [ { @@ -588,7 +577,88 @@ { "cell_type": "markdown", "metadata": { - "id": "AQQjzUeGX4P6" + "id": "LZWjgiMZ7uSS" + }, + "source": [ + "You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IVzPSkp77uCF", + "outputId": "db80a604-98ac-4343-8677-23729adf7ffc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n", + "x.sharding: ShapedArray(float32[4@X,4@Y])\n", + "\n", + "mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n", + "y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n", + "\n", + "z.sharding: ShapedArray(float32[4@X,4@Y])\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n", + " [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n", + " [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n", + " [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import functools\n", + "\n", + "@functools.partial(auto_axes, axes='X')\n", + "def g(y):\n", + " print(f'mesh inside g: {get_abstract_mesh()}')\n", + " print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n", + " return y * 2\n", + "\n", + "@jax.jit\n", + "def f(arr1):\n", + " print(f'mesh inside f: {get_abstract_mesh()}')\n", + " x = jnp.sin(arr1)\n", + " print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n", + "\n", + " z = g(x, out_shardings=P(\"X\", \"Y\"))\n", + "\n", + " print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n", + " return z + 1\n", + "\n", + "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", + "f(some_x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_3sfJjRq8w9f" + }, + "source": [ + "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sJcWbfAh7UcO" }, "source": [ "## Concrete array shardings can mention `Auto` mesh axis\n", @@ -606,7 +676,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -708,5 +778,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 0 } diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index b7368b5eb..7c59a675d 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -50,12 +50,8 @@ expect there to be bugs and unimplemented cases. Please let us know when you find something that doesn't work! ```{code-cell} ipython3 ---- -colab: - base_uri: https://localhost:8080/ -id: hVi6mApuVw3r -outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf ---- +:id: hVi6mApuVw3r + import jax import numpy as np import jax.numpy as jnp @@ -79,7 +75,7 @@ scalar) using `jax.typeof`: colab: base_uri: https://localhost:8080/ id: mzDIDvj7Vw0k -outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434 +outputId: 09ef049b-461f-47db-bf58-dc10b42fe40a --- some_array = np.arange(8) print(f"JAX-level type of some_array: {jax.typeof(some_array)}") @@ -96,7 +92,7 @@ under a jit). colab: base_uri: https://localhost:8080/ id: IyPx_-IBVwxr -outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499 +outputId: 0cd3122f-e579-45d7-868d-e42bb0eacddb --- @jax.jit def foo(x): @@ -121,7 +117,7 @@ mesh afterwards then you can use the context manager `jax.sharding.use_mesh` ins colab: base_uri: https://localhost:8080/ id: NO2ulM_QW7a8 -outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb +outputId: d888371b-080e-4bff-be5d-ea56beda3aac --- mesh = jax.make_mesh((2, 4), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) @@ -139,7 +135,7 @@ Now we can create some sharded arrays using `reshard`: colab: base_uri: https://localhost:8080/ id: 1-TzmA0AXCAf -outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71 +outputId: 1c7cc3ac-4b0e-42b7-facc-c706af10d7d2 --- replicated_array = np.arange(8).reshape(4, 2) sharded_array = reshard(replicated_array, P("X", None)) @@ -163,7 +159,7 @@ These shardings associated with JAX-level types propagate through operations. Fo colab: base_uri: https://localhost:8080/ id: Gy7ABds3XND3 -outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b +outputId: 0d72dad2-381a-4e96-f771-40d705da1376 --- arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None)) arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y")) @@ -184,7 +180,7 @@ We can do the same type querying under a jit: colab: base_uri: https://localhost:8080/ id: grCcotr-XQjY -outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a +outputId: c2db656c-809f-49a6-c948-629d6420360c --- @jax.jit def add_arrays(x, y): @@ -294,7 +290,7 @@ the first axis only, like `f32[4@X, 4]`. You can do this as follows: colab: base_uri: https://localhost:8080/ id: fpFEaMBcXsJG -outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660 +outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef --- some_x = reshard(np.arange(16).reshape(4, 4), P("X", None)) some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X")) @@ -355,7 +351,7 @@ The current mesh tells us which sharding mode we're in. We can query it with colab: base_uri: https://localhost:8080/ id: geptWrdYX0OM -outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f +outputId: b8c3813f-60bb-4ccf-9da7-73462c57963f --- print(f"Current mesh is: {get_abstract_mesh()}") ``` @@ -369,7 +365,45 @@ sharding mode for each mesh axis. Shardings (on JAX-level types) can only mention _explicit_ mesh axes and collective operations like `psum` can only mention _manual_ mesh axes. -+++ {"id": "AQQjzUeGX4P6"} ++++ {"id": "LZWjgiMZ7uSS"} + +You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: IVzPSkp77uCF +outputId: db80a604-98ac-4343-8677-23729adf7ffc +--- +import functools + +@functools.partial(auto_axes, axes='X') +def g(y): + print(f'mesh inside g: {get_abstract_mesh()}') + print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n') + return y * 2 + +@jax.jit +def f(arr1): + print(f'mesh inside f: {get_abstract_mesh()}') + x = jnp.sin(arr1) + print(f'x.sharding: {jax.typeof(x)}', end='\n\n') + + z = g(x, out_shardings=P("X", "Y")) + + print(f'z.sharding: {jax.typeof(z)}', end="\n\n") + return z + 1 + +some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y")) +f(some_x) +``` + ++++ {"id": "_3sfJjRq8w9f"} + +As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`. + ++++ {"id": "sJcWbfAh7UcO"} ## Concrete array shardings can mention `Auto` mesh axis