Update Pallas distributed tutorials with jax.make_mesh

commit 0b46a236c1 (parent 16fca386a3)
Author: Justin Fu
Date: 2024-10-21 12:35:07 -07:00
2 changed files with 50 additions and 62 deletions
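
For context, the mesh-construction change applied throughout both tutorial files is the same: the hand-built dummy 2-D device mesh is replaced by a direct 1-D jax.make_mesh call. A minimal before/after sketch, assuming num_devices = jax.device_count() as in the tutorials; the 2-D mesh's axis names are spelled out as ('dummy', 'x') here for readability, whereas the old tutorial code passed the PartitionSpec itself:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec as P

num_devices = jax.device_count()
partition = P(None, 'x')

# Old pattern (removed): build an explicit (1, num_devices) device array and
# wrap it in a 2-D Mesh whose leading axis is a size-1 dummy.
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh_2d = jax.sharding.Mesh(devices, ('dummy', 'x'))

# New pattern (added): jax.make_mesh builds the 1-D mesh directly from the
# axis sizes and names; no manual device array is needed.
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = NamedSharding(mesh, partition)
```

Because the new mesh has a single axis, every logical device_id handed to the Pallas remote-copy and semaphore primitives drops its leading coordinate, which is what the rest of the diff reflects.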


@@ -45,7 +45,6 @@
"import jax\n",
"from jax import lax\n",
"from jax import numpy as jnp\n",
"from jax.experimental import mesh_utils\n",
"from jax.experimental import pallas as pl\n",
"from jax.experimental import shard_map\n",
"from jax.experimental.pallas import tpu as pltpu\n",
@@ -245,8 +244,7 @@
],
"source": [
"partition = P(None, 'x')\n",
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
"mesh = jax.sharding.Mesh(devices, partition)\n",
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
"\n",
"# Create an input array that shards the last dimension across\n",
@@ -263,7 +261,7 @@
" dst_ref=output_ref,\n",
" send_sem=send_sem,\n",
" recv_sem=recv_sem,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" remote_copy_op.start()\n",
@@ -373,8 +371,7 @@
],
"source": [
"partition = P('x', None)\n",
"devices = mesh_utils.create_device_mesh((num_devices, 1))\n",
"mesh = jax.sharding.Mesh(devices, partition)\n",
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
"\n",
"# Create an input array that shards the first dimension across\n",
@@ -413,7 +410,7 @@
" dst_ref=output_ref.at[copy_slot],\n",
" send_sem=send_sem,\n",
" recv_sem=recv_sems.at[outer_step],\n",
" device_id=(right_neighbor, 0),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" remote_copy_op.start()\n",
@@ -683,8 +680,7 @@
],
"source": [
"partition = P(None, 'x')\n",
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
"mesh = jax.sharding.Mesh(devices, partition)\n",
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
"\n",
"input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))\n",
@@ -717,13 +713,13 @@
" pltpu.semaphore_signal(\n",
" barrier_sem,\n",
" inc=1,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" pltpu.semaphore_signal(\n",
" barrier_sem,\n",
" inc=1,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" pltpu.semaphore_wait(barrier_sem, 2)\n",
@@ -736,7 +732,7 @@
" dst_ref=hbm_scratch.at[working_slot],\n",
" send_sem=remote_send_sem,\n",
" recv_sem=remote_recv_sem,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" initial_copy.start()\n",
@@ -748,7 +744,7 @@
" pltpu.semaphore_signal(\n",
" capacity_sem,\n",
" inc=1,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -769,7 +765,7 @@
" dst_ref=hbm_scratch.at[receiving_slot],\n",
" send_sem=remote_send_sem,\n",
" recv_sem=remote_recv_sem,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" remote_copy.start()\n",
@@ -913,8 +909,7 @@
"outputs": [],
"source": [
"partition = P(None, 'x')\n",
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
"mesh = jax.sharding.Mesh(devices, partition)\n",
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
"\n",
"# We need a block size of (16, 128) to ensure that a half-slice is at least\n",
@@ -944,7 +939,7 @@
" pltpu.semaphore_signal(\n",
" semaphore,\n",
" inc=1,\n",
" device_id=(0, neighbor),\n",
" device_id=(neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -985,7 +980,7 @@
" dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
" send_sem=left_send_sem,\n",
" recv_sem=left_recv_sem,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -994,7 +989,7 @@
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
" send_sem=right_send_sem,\n",
" recv_sem=right_recv_sem,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -1003,7 +998,7 @@
" dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n",
" send_sem=left_send_sem,\n",
" recv_sem=left_recv_sem,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" right_copy = pltpu.make_async_remote_copy(\n",
@@ -1013,7 +1008,7 @@
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
" send_sem=right_send_sem,\n",
" recv_sem=right_recv_sem,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -1026,13 +1021,13 @@
" pltpu.semaphore_signal(\n",
" barrier_sem,\n",
" inc=1,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" pltpu.semaphore_signal(\n",
" barrier_sem,\n",
" inc=1,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" pltpu.semaphore_wait(barrier_sem, 2)\n",
@@ -1378,8 +1373,7 @@
"outputs": [],
"source": [
"partition = P(None, 'x')\n",
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
"mesh = jax.sharding.Mesh(devices, partition)\n",
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
"\n",
"# We pick a large outer kernel block size that we do not want to place\n",
@@ -1445,7 +1439,7 @@
" dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
" send_sem=left_send_sem,\n",
" recv_sem=left_recv_sem,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -1454,7 +1448,7 @@
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
" send_sem=right_send_sem,\n",
" recv_sem=right_recv_sem,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -1463,7 +1457,7 @@
" dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n",
" send_sem=left_send_sem,\n",
" recv_sem=left_recv_sem,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" right_copy = pltpu.make_async_remote_copy(\n",
@@ -1471,7 +1465,7 @@
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
" send_sem=right_send_sem,\n",
" recv_sem=right_recv_sem,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
"\n",
@@ -1484,13 +1478,13 @@
" pltpu.semaphore_signal(\n",
" barrier_sem,\n",
" inc=1,\n",
" device_id=(0, left_neighbor),\n",
" device_id=(left_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" pltpu.semaphore_signal(\n",
" barrier_sem,\n",
" inc=1,\n",
" device_id=(0, right_neighbor),\n",
" device_id=(right_neighbor,),\n",
" device_id_type=pltpu.DeviceIdType.MESH,\n",
" )\n",
" pltpu.semaphore_wait(barrier_sem, 2)\n",


@@ -39,7 +39,6 @@ outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
@@ -207,8 +206,7 @@ id: YkyIKN2thZ-V
outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0
---
partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the last dimension across
@@ -225,7 +223,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
dst_ref=output_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
@@ -309,8 +307,7 @@ id: ojQEZB5mBRqM
outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b
---
partition = P('x', None)
devices = mesh_utils.create_device_mesh((num_devices, 1))
mesh = jax.sharding.Mesh(devices, partition)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the first dimension across
@@ -349,7 +346,7 @@ def all_gather_kernel(input_ref,
dst_ref=output_ref.at[copy_slot],
send_sem=send_sem,
recv_sem=recv_sems.at[outer_step],
device_id=(right_neighbor, 0),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
@@ -577,8 +574,7 @@ id: XrY5bMlvBroQ
outputId: 77497000-4496-462e-cc3c-73fb640cc14c
---
partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
@@ -611,13 +607,13 @@ def all_reduce_kernel(
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
@@ -630,7 +626,7 @@ def all_reduce_kernel(
dst_ref=hbm_scratch.at[working_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_copy.start()
@@ -642,7 +638,7 @@ def all_reduce_kernel(
pltpu.semaphore_signal(
capacity_sem,
inc=1,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -663,7 +659,7 @@ def all_reduce_kernel(
dst_ref=hbm_scratch.at[receiving_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy.start()
@@ -786,8 +782,7 @@ executionInfo:
id: nRauUAxNHg28
---
partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# We need a block size of (16, 128) to ensure that a half-slice is at least
@@ -817,7 +812,7 @@ def signal(left_or_right, semaphore):
pltpu.semaphore_signal(
semaphore,
inc=1,
device_id=(0, neighbor),
device_id=(neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -858,7 +853,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -867,7 +862,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -876,7 +871,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
@@ -886,7 +881,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -899,13 +894,13 @@ def reduce_scatter_kernel(
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
@@ -1212,8 +1207,7 @@ executionInfo:
id: 27jni-pSartL
---
partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# We pick a large outer kernel block size that we do not want to place
@@ -1279,7 +1273,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -1288,7 +1282,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -1297,7 +1291,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
@@ -1305,7 +1299,7 @@ def reduce_scatter_kernel(
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
@@ -1318,13 +1312,13 @@ def reduce_scatter_kernel(
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(0, left_neighbor),
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(0, right_neighbor),
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
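
The rest of the diff follows mechanically from the 1-D mesh: a device's mesh coordinate now has a single entry, so every device_id shrinks from (0, neighbor) to (neighbor,) while device_id_type stays pltpu.DeviceIdType.MESH. A small runnable sketch of why; printed sizes depend on how many devices are attached, and the array shape mirrors the tutorials' (8, 128 * num_devices) input:

```python
import jax
from jax import numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

num_devices = jax.device_count()

# One named axis 'x' means one coordinate per device, so the logical device
# ids passed to the Pallas remote DMA/semaphore ops are 1-tuples (neighbor,).
mesh = jax.make_mesh((num_devices,), ('x',))
print(mesh.devices.shape)  # (num_devices,)

# P(None, 'x') still shards only the trailing dimension, exactly as it did
# with the old (1, num_devices) mesh.
sharding = NamedSharding(mesh, P(None, 'x'))
x = jax.device_put(jnp.zeros((8, 128 * num_devices)), sharding)
print(x.sharding)  # NamedSharding over the single 'x' mesh axis
```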