mixing modes

This commit is contained in:
Yash Katariya 2025-03-14 18:04:05 -07:00
parent 7db59cdcca
commit 3c0027af3b
2 changed files with 153 additions and 49 deletions

View File

@ -49,13 +49,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hVi6mApuVw3r",
"outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
"id": "hVi6mApuVw3r"
},
"outputs": [],
"source": [
@ -84,13 +80,13 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mzDIDvj7Vw0k",
"outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
"outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
},
"outputs": [
{
@ -119,13 +115,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IyPx_-IBVwxr",
"outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
"outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
},
"outputs": [
{
@ -141,7 +137,7 @@
"Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)"
]
},
"execution_count": 3,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -172,13 +168,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NO2ulM_QW7a8",
"outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
"outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
},
"outputs": [
{
@ -208,13 +204,13 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1-TzmA0AXCAf",
"outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
"outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
},
"outputs": [
{
@ -256,13 +252,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Gy7ABds3XND3",
"outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
"outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
},
"outputs": [
{
@ -297,13 +293,13 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "grCcotr-XQjY",
"outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
"outputId": "c2db656c-809f-49a6-c948-629d6420360c"
},
"outputs": [
{
@ -324,7 +320,7 @@
" [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)"
]
},
"execution_count": 7,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -460,13 +456,13 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fpFEaMBcXsJG",
"outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
"outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
},
"outputs": [
{
@ -479,13 +475,6 @@
"We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
"Result type: ShapedArray(int32[4@X,4])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result type: ShapedArray(int32[4@X,4])\n"
]
}
],
"source": [
@ -550,13 +539,13 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "geptWrdYX0OM",
"outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
"outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
},
"outputs": [
{
@ -588,7 +577,88 @@
{
"cell_type": "markdown",
"metadata": {
"id": "AQQjzUeGX4P6"
"id": "LZWjgiMZ7uSS"
},
"source": [
"You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over others. For example:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IVzPSkp77uCF",
"outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
"x.sharding: ShapedArray(float32[4@X,4@Y])\n",
"\n",
"mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
"y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
"\n",
"z.sharding: ShapedArray(float32[4@X,4@Y])\n",
"\n"
]
},
{
"data": {
"text/plain": [
"Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n",
" [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n",
" [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n",
" [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import functools\n",
"\n",
"@functools.partial(auto_axes, axes='X')\n",
"def g(y):\n",
" print(f'mesh inside g: {get_abstract_mesh()}')\n",
" print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
" return y * 2\n",
"\n",
"@jax.jit\n",
"def f(arr1):\n",
" print(f'mesh inside f: {get_abstract_mesh()}')\n",
" x = jnp.sin(arr1)\n",
" print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
"\n",
" z = g(x, out_shardings=P(\"X\", \"Y\"))\n",
"\n",
" print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
" return z + 1\n",
"\n",
"some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
"f(some_x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3sfJjRq8w9f"
},
"source": [
"As you can see, inside `g`, the type of `y` is `ShapedArray(float32[4,4@Y])`, which indicates it's `Explicit` over the `Y` mesh axis while `Auto` over `X`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJcWbfAh7UcO"
},
"source": [
"## Concrete array shardings can mention `Auto` mesh axis\n",
@ -606,7 +676,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -708,5 +778,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 0
}

View File

@ -50,12 +50,8 @@ expect there to be bugs and unimplemented cases. Please let us know when you
find something that doesn't work!
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: hVi6mApuVw3r
outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
---
:id: hVi6mApuVw3r
import jax
import numpy as np
import jax.numpy as jnp
@ -79,7 +75,7 @@ scalar) using `jax.typeof`:
colab:
base_uri: https://localhost:8080/
id: mzDIDvj7Vw0k
outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434
outputId: 09ef049b-461f-47db-bf58-dc10b42fe40a
---
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
@ -96,7 +92,7 @@ under a jit).
colab:
base_uri: https://localhost:8080/
id: IyPx_-IBVwxr
outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499
outputId: 0cd3122f-e579-45d7-868d-e42bb0eacddb
---
@jax.jit
def foo(x):
@ -121,7 +117,7 @@ mesh afterwards then you can use the context manager `jax.sharding.use_mesh` ins
colab:
base_uri: https://localhost:8080/
id: NO2ulM_QW7a8
outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb
outputId: d888371b-080e-4bff-be5d-ea56beda3aac
---
mesh = jax.make_mesh((2, 4), ("X", "Y"),
axis_types=(AxisType.Explicit, AxisType.Explicit))
@ -139,7 +135,7 @@ Now we can create some sharded arrays using `reshard`:
colab:
base_uri: https://localhost:8080/
id: 1-TzmA0AXCAf
outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71
outputId: 1c7cc3ac-4b0e-42b7-facc-c706af10d7d2
---
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))
@ -163,7 +159,7 @@ These shardings associated with JAX-level types propagate through operations. Fo
colab:
base_uri: https://localhost:8080/
id: Gy7ABds3XND3
outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b
outputId: 0d72dad2-381a-4e96-f771-40d705da1376
---
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
@ -184,7 +180,7 @@ We can do the same type querying under a jit:
colab:
base_uri: https://localhost:8080/
id: grCcotr-XQjY
outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a
outputId: c2db656c-809f-49a6-c948-629d6420360c
---
@jax.jit
def add_arrays(x, y):
@ -294,7 +290,7 @@ the first axis only, like `f32[4@X, 4]`. You can do this as follows:
colab:
base_uri: https://localhost:8080/
id: fpFEaMBcXsJG
outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660
outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef
---
some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
@ -355,7 +351,7 @@ The current mesh tells us which sharding mode we're in. We can query it with
colab:
base_uri: https://localhost:8080/
id: geptWrdYX0OM
outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f
outputId: b8c3813f-60bb-4ccf-9da7-73462c57963f
---
print(f"Current mesh is: {get_abstract_mesh()}")
```
@ -369,7 +365,45 @@ sharding mode for each mesh axis. Shardings (on JAX-level types) can only
mention _explicit_ mesh axes and collective operations like `psum` can only
mention _manual_ mesh axes.
+++ {"id": "AQQjzUeGX4P6"}
+++ {"id": "LZWjgiMZ7uSS"}
You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over others. For example:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: IVzPSkp77uCF
outputId: db80a604-98ac-4343-8677-23729adf7ffc
---
import functools
# `g` is wrapped with `auto_axes(..., axes='X')`. Per the recorded output of this
# cell, the mesh seen inside `g` has axis types (Auto, Explicit) and `y` has type
# float32[4,4@Y], i.e. only the `Y` axis is still mentioned in the type.
@functools.partial(auto_axes, axes='X')
def g(y):
  """Print the abstract mesh and the JAX-level type of `y`, then return `y * 2`."""
  print(f'mesh inside g: {get_abstract_mesh()}')
  # The `{expr = }` form echoes the expression text as well, so this line prints
  # "jax.typeof(y) = ShapedArray(...)".
  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
  return y * 2
# Jitted entry point of the example. Per the recorded output of this cell, the
# mesh seen inside `f` has axis types (Explicit, Explicit).
@jax.jit
def f(arr1):
  """Apply `jnp.sin` to `arr1`, pass the result through `g`, and return that plus 1.

  Along the way, prints the abstract mesh and the JAX-level types of the
  intermediates `x` and `z`.
  """
  print(f'mesh inside f: {get_abstract_mesh()}')
  x = jnp.sin(arr1)
  print(f'x.sharding: {jax.typeof(x)}', end='\n\n')

  # `out_shardings` is supplied at the call site of the `auto_axes`-wrapped `g`;
  # the recorded output shows `z` typed as float32[4@X,4@Y], matching P("X", "Y").
  z = g(x, out_shardings=P("X", "Y"))

  print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
  return z + 1
# Build a 4x4 input resharded with P("X", "Y") and run the example on it.
# NOTE(review): `some_x` is also bound in an earlier cell with P("X", None);
# this rebinds it -- confirm no later cell relies on the earlier value.
some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
```
+++ {"id": "_3sfJjRq8w9f"}
As you can see, inside `g`, the type of `y` is `ShapedArray(float32[4,4@Y])`, which indicates it's `Explicit` over the `Y` mesh axis while `Auto` over `X`.
+++ {"id": "sJcWbfAh7UcO"}
## Concrete array shardings can mention `Auto` mesh axis