mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
mixing modes
This commit is contained in:
parent
7db59cdcca
commit
3c0027af3b
@ -49,13 +49,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "hVi6mApuVw3r",
|
||||
"outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
|
||||
"id": "hVi6mApuVw3r"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -84,13 +80,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "mzDIDvj7Vw0k",
|
||||
"outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
|
||||
"outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -119,13 +115,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "IyPx_-IBVwxr",
|
||||
"outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
|
||||
"outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -141,7 +137,7 @@
|
||||
"Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -172,13 +168,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "NO2ulM_QW7a8",
|
||||
"outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
|
||||
"outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -208,13 +204,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "1-TzmA0AXCAf",
|
||||
"outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
|
||||
"outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -256,13 +252,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Gy7ABds3XND3",
|
||||
"outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
|
||||
"outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -297,13 +293,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "grCcotr-XQjY",
|
||||
"outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
|
||||
"outputId": "c2db656c-809f-49a6-c948-629d6420360c"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -324,7 +320,7 @@
|
||||
" [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -460,13 +456,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "fpFEaMBcXsJG",
|
||||
"outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
|
||||
"outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -479,13 +475,6 @@
|
||||
"We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
|
||||
"Result type: ShapedArray(int32[4@X,4])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Result type: ShapedArray(int32[4@X,4])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@ -550,13 +539,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "geptWrdYX0OM",
|
||||
"outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
|
||||
"outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -588,7 +577,88 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AQQjzUeGX4P6"
|
||||
"id": "LZWjgiMZ7uSS"
|
||||
},
|
||||
"source": [
|
||||
"You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over others. For example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "IVzPSkp77uCF",
|
||||
"outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
|
||||
"x.sharding: ShapedArray(float32[4@X,4@Y])\n",
|
||||
"\n",
|
||||
"mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
|
||||
"y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
|
||||
"\n",
|
||||
"z.sharding: ShapedArray(float32[4@X,4@Y])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n",
|
||||
" [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n",
|
||||
" [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n",
|
||||
" [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import functools\n",
|
||||
"\n",
|
||||
"@functools.partial(auto_axes, axes='X')\n",
|
||||
"def g(y):\n",
|
||||
" print(f'mesh inside g: {get_abstract_mesh()}')\n",
|
||||
" print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
|
||||
" return y * 2\n",
|
||||
"\n",
|
||||
"@jax.jit\n",
|
||||
"def f(arr1):\n",
|
||||
" print(f'mesh inside f: {get_abstract_mesh()}')\n",
|
||||
" x = jnp.sin(arr1)\n",
|
||||
" print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
|
||||
"\n",
|
||||
" z = g(x, out_shardings=P(\"X\", \"Y\"))\n",
|
||||
"\n",
|
||||
" print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
|
||||
" return z + 1\n",
|
||||
"\n",
|
||||
"some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
|
||||
"f(some_x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "_3sfJjRq8w9f"
|
||||
},
|
||||
"source": [
|
||||
"As you can see, inside `g`, the type of `y` is `ShapedArray(float32[4,4@Y])`, which indicates that it is `Explicit` over the `Y` mesh axis while `Auto` over `X`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "sJcWbfAh7UcO"
|
||||
},
|
||||
"source": [
|
||||
"## Concrete array shardings can mention `Auto` mesh axis\n",
|
||||
@ -606,7 +676,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
@ -708,5 +778,5 @@
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
@ -50,12 +50,8 @@ expect there to be bugs and unimplemented cases. Please let us know when you
|
||||
find something that doesn't work!
|
||||
|
||||
```{code-cell} ipython3
|
||||
---
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: hVi6mApuVw3r
|
||||
outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
|
||||
---
|
||||
:id: hVi6mApuVw3r
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
@ -79,7 +75,7 @@ scalar) using `jax.typeof`:
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: mzDIDvj7Vw0k
|
||||
outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434
|
||||
outputId: 09ef049b-461f-47db-bf58-dc10b42fe40a
|
||||
---
|
||||
some_array = np.arange(8)
|
||||
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
|
||||
@ -96,7 +92,7 @@ under a jit).
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: IyPx_-IBVwxr
|
||||
outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499
|
||||
outputId: 0cd3122f-e579-45d7-868d-e42bb0eacddb
|
||||
---
|
||||
@jax.jit
|
||||
def foo(x):
|
||||
@ -121,7 +117,7 @@ mesh afterwards then you can use the context manager `jax.sharding.use_mesh` ins
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: NO2ulM_QW7a8
|
||||
outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb
|
||||
outputId: d888371b-080e-4bff-be5d-ea56beda3aac
|
||||
---
|
||||
mesh = jax.make_mesh((2, 4), ("X", "Y"),
|
||||
axis_types=(AxisType.Explicit, AxisType.Explicit))
|
||||
@ -139,7 +135,7 @@ Now we can create some sharded arrays using `reshard`:
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: 1-TzmA0AXCAf
|
||||
outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71
|
||||
outputId: 1c7cc3ac-4b0e-42b7-facc-c706af10d7d2
|
||||
---
|
||||
replicated_array = np.arange(8).reshape(4, 2)
|
||||
sharded_array = reshard(replicated_array, P("X", None))
|
||||
@ -163,7 +159,7 @@ These shardings associated with JAX-level types propagate through operations. Fo
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: Gy7ABds3XND3
|
||||
outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b
|
||||
outputId: 0d72dad2-381a-4e96-f771-40d705da1376
|
||||
---
|
||||
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
|
||||
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
|
||||
@ -184,7 +180,7 @@ We can do the same type querying under a jit:
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: grCcotr-XQjY
|
||||
outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a
|
||||
outputId: c2db656c-809f-49a6-c948-629d6420360c
|
||||
---
|
||||
@jax.jit
|
||||
def add_arrays(x, y):
|
||||
@ -294,7 +290,7 @@ the first axis only, like `f32[4@X, 4]`. You can do this as follows:
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: fpFEaMBcXsJG
|
||||
outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660
|
||||
outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef
|
||||
---
|
||||
some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
|
||||
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
|
||||
@ -355,7 +351,7 @@ The current mesh tells us which sharding mode we're in. We can query it with
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: geptWrdYX0OM
|
||||
outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f
|
||||
outputId: b8c3813f-60bb-4ccf-9da7-73462c57963f
|
||||
---
|
||||
print(f"Current mesh is: {get_abstract_mesh()}")
|
||||
```
|
||||
@ -369,7 +365,45 @@ sharding mode for each mesh axis. Shardings (on JAX-level types) can only
|
||||
mention _explicit_ mesh axes and collective operations like `psum` can only
|
||||
mention _manual_ mesh axes.
|
||||
|
||||
+++ {"id": "AQQjzUeGX4P6"}
|
||||
+++ {"id": "LZWjgiMZ7uSS"}
|
||||
|
||||
You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over others. For example:
|
||||
|
||||
```{code-cell} ipython3
|
||||
---
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: IVzPSkp77uCF
|
||||
outputId: db80a604-98ac-4343-8677-23729adf7ffc
|
||||
---
|
||||
import functools
|
||||
|
||||
@functools.partial(auto_axes, axes='X')
|
||||
def g(y):
|
||||
print(f'mesh inside g: {get_abstract_mesh()}')
|
||||
print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
|
||||
return y * 2
|
||||
|
||||
@jax.jit
|
||||
def f(arr1):
|
||||
print(f'mesh inside f: {get_abstract_mesh()}')
|
||||
x = jnp.sin(arr1)
|
||||
print(f'x.sharding: {jax.typeof(x)}', end='\n\n')
|
||||
|
||||
z = g(x, out_shardings=P("X", "Y"))
|
||||
|
||||
print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
|
||||
return z + 1
|
||||
|
||||
some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
|
||||
f(some_x)
|
||||
```
|
||||
|
||||
+++ {"id": "_3sfJjRq8w9f"}
|
||||
|
||||
As you can see, inside `g`, the type of `y` is `ShapedArray(float32[4,4@Y])`, which indicates that it is `Explicit` over the `Y` mesh axis while `Auto` over `X`.
|
||||
|
||||
+++ {"id": "sJcWbfAh7UcO"}
|
||||
|
||||
## Concrete array shardings can mention `Auto` mesh axis
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user