
Update sharded-computation doc to use make_mesh()

Jake VanderPlas 2024-09-03 16:04:23 -07:00
parent 1289640f09
commit 7569dd5438
2 changed files with 10 additions and 24 deletions
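Every hunk below makes the same substitution: the multi-step `mesh_utils.create_device_mesh` plus `jax.sharding.Mesh` boilerplate becomes a single `jax.make_mesh()` call. For orientation, here is a minimal before/after sketch; it is not part of the diff, and it assumes a runtime with at least 8 devices (forced below via the same `XLA_FLAGS` trick the doc itself uses):

```python
import os
# Force 8 virtual CPU devices so the example runs on any machine.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
from jax.sharding import PartitionSpec as P

# Old pattern (removed by this commit):
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((2, 4))
mesh_old = jax.sharding.Mesh(devices, ('x', 'y'))

# New pattern (added by this commit): one call, same resulting mesh.
mesh_new = jax.make_mesh((2, 4), ('x', 'y'))

sharding = jax.sharding.NamedSharding(mesh_new, P('x', 'y'))
print(sharding)  # e.g. NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))
```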

docs/sharded-computation.ipynb
@@ -188,12 +188,9 @@
     }
    ],
    "source": [
-    "# Pardon the boilerplate; constructing a sharding will become easier in future!\n",
-    "from jax.experimental import mesh_utils\n",
+    "from jax.sharding import PartitionSpec as P\n",
     "\n",
-    "P = jax.sharding.PartitionSpec\n",
-    "devices = mesh_utils.create_device_mesh((2, 4))\n",
-    "mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n",
+    "mesh = jax.make_mesh((2, 4), ('x', 'y'))\n",
     "sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n",
     "print(sharding)"
    ]
@@ -402,9 +399,7 @@
     "@jax.jit\n",
     "def f_contract_2(x):\n",
     "  out = x.sum(axis=0)\n",
-    "  # mesh = jax.create_mesh((8,), 'x')\n",
-    "  devices = mesh_utils.create_device_mesh(8)\n",
-    "  mesh = jax.sharding.Mesh(devices, 'x')\n",
+    "  mesh = jax.make_mesh((8,), ('x',))\n",
     "  sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
     "  return jax.lax.with_sharding_constraint(out, sharding)\n",
     "\n",
@@ -457,8 +452,7 @@
    ],
    "source": [
     "from jax.experimental.shard_map import shard_map\n",
-    "P = jax.sharding.PartitionSpec\n",
-    "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n",
+    "mesh = jax.make_mesh((8,), ('x',))\n",
     "\n",
     "f_elementwise_sharded = shard_map(\n",
     "    f_elementwise,\n",
@@ -656,8 +650,7 @@
     }
    ],
    "source": [
-    "P = jax.sharding.PartitionSpec\n",
-    "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n",
+    "mesh = jax.make_mesh((8,), ('x',))\n",
     "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
     "\n",
     "x_sharded = jax.device_put(x, sharding)\n",

docs/sharded-computation.md
@@ -72,12 +72,9 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens
 ```{code-cell}
 :outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5
 
-# Pardon the boilerplate; constructing a sharding will become easier in future!
-from jax.experimental import mesh_utils
+from jax.sharding import PartitionSpec as P
 
-P = jax.sharding.PartitionSpec
-devices = mesh_utils.create_device_mesh((2, 4))
-mesh = jax.sharding.Mesh(devices, ('x', 'y'))
+mesh = jax.make_mesh((2, 4), ('x', 'y'))
 sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
 print(sharding)
 ```
@@ -146,9 +143,7 @@ For example, suppose that within `f_contract` above, you'd prefer the output not
 @jax.jit
 def f_contract_2(x):
   out = x.sum(axis=0)
-  # mesh = jax.create_mesh((8,), 'x')
-  devices = mesh_utils.create_device_mesh(8)
-  mesh = jax.sharding.Mesh(devices, 'x')
+  mesh = jax.make_mesh((8,), ('x',))
   sharding = jax.sharding.NamedSharding(mesh, P('x'))
   return jax.lax.with_sharding_constraint(out, sharding)
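The hunk above sits inside the doc's `jax.lax.with_sharding_constraint` example, which pins the layout of a value produced inside a jitted function. A self-contained, hedged sketch of the updated cell follows; the input `x` is defined earlier in the doc, so the array setup here is an assumption for illustration:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # Constrain the output layout inside jit: without this, the compiler
  # may choose to replicate the result rather than keep it sharded.
  mesh = jax.make_mesh((8,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

x = jnp.arange(32.0).reshape(4, 8)  # assumed input, for illustration only
result = f_contract_2(x)
print(result.sharding)
```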
@@ -174,8 +169,7 @@ In the automatic parallelism methods explored above, you can write a function as
 :outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2
 
 from jax.experimental.shard_map import shard_map
-P = jax.sharding.PartitionSpec
-mesh = jax.sharding.Mesh(jax.devices(), 'x')
+mesh = jax.make_mesh((8,), ('x',))
 
 f_elementwise_sharded = shard_map(
     f_elementwise,
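This hunk cuts off mid-call, so for readers of the diff, here is a hedged sketch of the full `shard_map` invocation after the change. `f_elementwise` is defined earlier in the doc; the body used below is an assumed stand-in:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

def f_elementwise(x):
  # Assumed stand-in; the doc defines its own elementwise function.
  return 2 * jnp.sin(x) + 1

mesh = jax.make_mesh((8,), ('x',))

# Each device applies f_elementwise to its own 1/8th shard of the input.
f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

x = jnp.arange(32)
print(f_elementwise_sharded(x))
```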
@@ -265,8 +259,7 @@ If you shard the leading axis of both `x` and `weights` in the same way, then th
 ```{code-cell}
 :outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5
 
-P = jax.sharding.PartitionSpec
-mesh = jax.sharding.Mesh(jax.devices(), 'x')
+mesh = jax.make_mesh((8,), ('x',))
 sharding = jax.sharding.NamedSharding(mesh, P('x'))
 
 x_sharded = jax.device_put(x, sharding)
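Completing the final hunk: when both operands' leading axes are sharded identically, jitted code runs the contraction in parallel with no further annotation. A hedged sketch under that assumption; the doc's `layer` function and data live outside this hunk, so `layer`, `x`, and `weights` below are assumed stand-ins:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

@jax.jit
def layer(x, weights):
  # Assumed stand-in for the doc's layer function.
  return jnp.dot(x, weights)

mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x = jnp.arange(64.0).reshape(8, 8)   # assumed data, for illustration
weights = jnp.ones((8, 8))           # assumed data, for illustration

# Shard the leading axis of both operands the same way; the compiler
# then parallelizes the matmul across the mesh automatically.
x_sharded = jax.device_put(x, sharding)
w_sharded = jax.device_put(weights, sharding)

print(layer(x_sharded, w_sharded).sharding)
```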