mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update sharded-computation doc to use make_mesh()
This commit is contained in:
parent
1289640f09
commit
7569dd5438
@ -188,12 +188,9 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Pardon the boilerplate; constructing a sharding will become easier in future!\n",
|
||||
"from jax.experimental import mesh_utils\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"\n",
|
||||
"P = jax.sharding.PartitionSpec\n",
|
||||
"devices = mesh_utils.create_device_mesh((2, 4))\n",
|
||||
"mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n",
|
||||
"mesh = jax.make_mesh((2, 4), ('x', 'y'))\n",
|
||||
"sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n",
|
||||
"print(sharding)"
|
||||
]
|
||||
@ -402,9 +399,7 @@
|
||||
"@jax.jit\n",
|
||||
"def f_contract_2(x):\n",
|
||||
" out = x.sum(axis=0)\n",
|
||||
" # mesh = jax.create_mesh((8,), 'x')\n",
|
||||
" devices = mesh_utils.create_device_mesh(8)\n",
|
||||
" mesh = jax.sharding.Mesh(devices, 'x')\n",
|
||||
" mesh = jax.make_mesh((8,), ('x',))\n",
|
||||
" sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
|
||||
" return jax.lax.with_sharding_constraint(out, sharding)\n",
|
||||
"\n",
|
||||
@ -457,8 +452,7 @@
|
||||
],
|
||||
"source": [
|
||||
"from jax.experimental.shard_map import shard_map\n",
|
||||
"P = jax.sharding.PartitionSpec\n",
|
||||
"mesh = jax.sharding.Mesh(jax.devices(), 'x')\n",
|
||||
"mesh = jax.make_mesh((8,), ('x',))\n",
|
||||
"\n",
|
||||
"f_elementwise_sharded = shard_map(\n",
|
||||
" f_elementwise,\n",
|
||||
@ -656,8 +650,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"P = jax.sharding.PartitionSpec\n",
|
||||
"mesh = jax.sharding.Mesh(jax.devices(), 'x')\n",
|
||||
"mesh = jax.make_mesh((8,), ('x',))\n",
|
||||
"sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
|
||||
"\n",
|
||||
"x_sharded = jax.device_put(x, sharding)\n",
|
||||
|
@ -72,12 +72,9 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens
|
||||
```{code-cell}
|
||||
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5
|
||||
|
||||
# Pardon the boilerplate; constructing a sharding will become easier in future!
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
P = jax.sharding.PartitionSpec
|
||||
devices = mesh_utils.create_device_mesh((2, 4))
|
||||
mesh = jax.sharding.Mesh(devices, ('x', 'y'))
|
||||
mesh = jax.make_mesh((2, 4), ('x', 'y'))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
print(sharding)
|
||||
```
|
||||
@ -146,9 +143,7 @@ For example, suppose that within `f_contract` above, you'd prefer the output not
|
||||
@jax.jit
|
||||
def f_contract_2(x):
|
||||
out = x.sum(axis=0)
|
||||
# mesh = jax.create_mesh((8,), 'x')
|
||||
devices = mesh_utils.create_device_mesh(8)
|
||||
mesh = jax.sharding.Mesh(devices, 'x')
|
||||
mesh = jax.make_mesh((8,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
return jax.lax.with_sharding_constraint(out, sharding)
|
||||
|
||||
@ -174,8 +169,7 @@ In the automatic parallelism methods explored above, you can write a function as
|
||||
:outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2
|
||||
|
||||
from jax.experimental.shard_map import shard_map
|
||||
P = jax.sharding.PartitionSpec
|
||||
mesh = jax.sharding.Mesh(jax.devices(), 'x')
|
||||
mesh = jax.make_mesh((8,), ('x',))
|
||||
|
||||
f_elementwise_sharded = shard_map(
|
||||
f_elementwise,
|
||||
@ -265,8 +259,7 @@ If you shard the leading axis of both `x` and `weights` in the same way, then th
|
||||
```{code-cell}
|
||||
:outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5
|
||||
|
||||
P = jax.sharding.PartitionSpec
|
||||
mesh = jax.sharding.Mesh(jax.devices(), 'x')
|
||||
mesh = jax.make_mesh((8,), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
||||
x_sharded = jax.device_put(x, sharding)
|
||||
|
Loading…
x
Reference in New Issue
Block a user