diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 95abf803a..b52ec579f 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -45,7 +45,6 @@ "import jax\n", "from jax import lax\n", "from jax import numpy as jnp\n", - "from jax.experimental import mesh_utils\n", "from jax.experimental import pallas as pl\n", "from jax.experimental import shard_map\n", "from jax.experimental.pallas import tpu as pltpu\n", @@ -245,8 +244,7 @@ ], "source": [ "partition = P(None, 'x')\n", - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", - "mesh = jax.sharding.Mesh(devices, partition)\n", + "mesh = jax.make_mesh((num_devices,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, partition)\n", "\n", "# Create an input array that shards the last dimension across\n", @@ -263,7 +261,7 @@ " dst_ref=output_ref,\n", " send_sem=send_sem,\n", " recv_sem=recv_sem,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " remote_copy_op.start()\n", @@ -373,8 +371,7 @@ ], "source": [ "partition = P('x', None)\n", - "devices = mesh_utils.create_device_mesh((num_devices, 1))\n", - "mesh = jax.sharding.Mesh(devices, partition)\n", + "mesh = jax.make_mesh((num_devices,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, partition)\n", "\n", "# Create an input array that shards the first dimension across\n", @@ -413,7 +410,7 @@ " dst_ref=output_ref.at[copy_slot],\n", " send_sem=send_sem,\n", " recv_sem=recv_sems.at[outer_step],\n", - " device_id=(right_neighbor, 0),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " remote_copy_op.start()\n", @@ -683,8 +680,7 @@ ], "source": [ "partition = P(None, 'x')\n", - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", - "mesh = jax.sharding.Mesh(devices, partition)\n", + "mesh = jax.make_mesh((num_devices,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, partition)\n", "\n", "input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))\n", @@ -717,13 +713,13 @@ " pltpu.semaphore_signal(\n", " barrier_sem,\n", " inc=1,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " pltpu.semaphore_signal(\n", " barrier_sem,\n", " inc=1,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " pltpu.semaphore_wait(barrier_sem, 2)\n", @@ -736,7 +732,7 @@ " dst_ref=hbm_scratch.at[working_slot],\n", " send_sem=remote_send_sem,\n", " recv_sem=remote_recv_sem,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " initial_copy.start()\n", @@ -748,7 +744,7 @@ " pltpu.semaphore_signal(\n", " capacity_sem,\n", " inc=1,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -769,7 +765,7 @@ " dst_ref=hbm_scratch.at[receiving_slot],\n", " send_sem=remote_send_sem,\n", " recv_sem=remote_recv_sem,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " remote_copy.start()\n", @@ -913,8 +909,7 @@ "outputs": [], "source": [ "partition = P(None, 'x')\n", - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", - "mesh = jax.sharding.Mesh(devices, partition)\n", + "mesh = jax.make_mesh((num_devices,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, partition)\n", "\n", "# We need a block size of (16, 128) to ensure that a half-slice is at least\n", @@ -944,7 +939,7 @@ " pltpu.semaphore_signal(\n", " semaphore,\n", " inc=1,\n", - " device_id=(0, neighbor),\n", + " device_id=(neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -985,7 +980,7 @@ " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", " send_sem=left_send_sem,\n", " recv_sem=left_recv_sem,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -994,7 +989,7 @@ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", " send_sem=right_send_sem,\n", " recv_sem=right_recv_sem,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -1003,7 +998,7 @@ " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", " send_sem=left_send_sem,\n", " recv_sem=left_recv_sem,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " right_copy = pltpu.make_async_remote_copy(\n", @@ -1013,7 +1008,7 @@ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", " send_sem=right_send_sem,\n", " recv_sem=right_recv_sem,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -1026,13 +1021,13 @@ " pltpu.semaphore_signal(\n", " barrier_sem,\n", " inc=1,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " pltpu.semaphore_signal(\n", " barrier_sem,\n", " inc=1,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " pltpu.semaphore_wait(barrier_sem, 2)\n", @@ -1378,8 +1373,7 @@ "outputs": [], "source": [ "partition = P(None, 'x')\n", - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", - "mesh = jax.sharding.Mesh(devices, partition)\n", + "mesh = jax.make_mesh((num_devices,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, partition)\n", "\n", "# We pick a large outer kernel block size that we do not want to place\n", @@ -1445,7 +1439,7 @@ " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", " send_sem=left_send_sem,\n", " recv_sem=left_recv_sem,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -1454,7 +1448,7 @@ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", " send_sem=right_send_sem,\n", " recv_sem=right_recv_sem,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -1463,7 +1457,7 @@ " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", " send_sem=left_send_sem,\n", " recv_sem=left_recv_sem,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " right_copy = pltpu.make_async_remote_copy(\n", @@ -1471,7 +1465,7 @@ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", " send_sem=right_send_sem,\n", " recv_sem=right_recv_sem,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", "\n", @@ -1484,13 +1478,13 @@ " pltpu.semaphore_signal(\n", " barrier_sem,\n", " inc=1,\n", - " device_id=(0, left_neighbor),\n", + " device_id=(left_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " pltpu.semaphore_signal(\n", " barrier_sem,\n", " inc=1,\n", - " device_id=(0, right_neighbor),\n", + " device_id=(right_neighbor,),\n", " device_id_type=pltpu.DeviceIdType.MESH,\n", " )\n", " pltpu.semaphore_wait(barrier_sem, 2)\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index c71f75ec6..c1f216c61 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -39,7 +39,6 @@ outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480 import jax from jax import lax from jax import numpy as jnp -from jax.experimental import mesh_utils from jax.experimental import pallas as pl from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu @@ -207,8 +206,7 @@ id: YkyIKN2thZ-V outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0 --- partition = P(None, 'x') -devices = mesh_utils.create_device_mesh((1, num_devices)) -mesh = jax.sharding.Mesh(devices, partition) +mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # Create an input array that shards the last dimension across @@ -225,7 +223,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): dst_ref=output_ref, send_sem=send_sem, recv_sem=recv_sem, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) remote_copy_op.start() @@ -309,8 +307,7 @@ id: ojQEZB5mBRqM outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b --- partition = P('x', None) -devices = mesh_utils.create_device_mesh((num_devices, 1)) -mesh = jax.sharding.Mesh(devices, partition) +mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # Create an input array that shards the first dimension across @@ -349,7 +346,7 @@ def all_gather_kernel(input_ref, dst_ref=output_ref.at[copy_slot], send_sem=send_sem, recv_sem=recv_sems.at[outer_step], - device_id=(right_neighbor, 0), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) remote_copy_op.start() @@ -577,8 +574,7 @@ id: XrY5bMlvBroQ outputId: 77497000-4496-462e-cc3c-73fb640cc14c --- partition = P(None, 'x') -devices = mesh_utils.create_device_mesh((1, num_devices)) -mesh = jax.sharding.Mesh(devices, partition) +mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices)) @@ -611,13 +607,13 @@ def all_reduce_kernel( pltpu.semaphore_signal( barrier_sem, inc=1, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) pltpu.semaphore_signal( barrier_sem, inc=1, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) pltpu.semaphore_wait(barrier_sem, 2) @@ -630,7 +626,7 @@ def all_reduce_kernel( dst_ref=hbm_scratch.at[working_slot], send_sem=remote_send_sem, recv_sem=remote_recv_sem, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) initial_copy.start() @@ -642,7 +638,7 @@ def all_reduce_kernel( pltpu.semaphore_signal( capacity_sem, inc=1, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -663,7 +659,7 @@ def all_reduce_kernel( dst_ref=hbm_scratch.at[receiving_slot], send_sem=remote_send_sem, recv_sem=remote_recv_sem, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) remote_copy.start() @@ -786,8 +782,7 @@ executionInfo: id: nRauUAxNHg28 --- partition = P(None, 'x') -devices = mesh_utils.create_device_mesh((1, num_devices)) -mesh = jax.sharding.Mesh(devices, partition) +mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # We need a block size of (16, 128) to ensure that a half-slice is at least @@ -817,7 +812,7 @@ def signal(left_or_right, semaphore): pltpu.semaphore_signal( semaphore, inc=1, - device_id=(0, neighbor), + device_id=(neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -858,7 +853,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[working_slot, left_copy_slice], send_sem=left_send_sem, recv_sem=left_recv_sem, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -867,7 +862,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[working_slot, right_copy_slice], send_sem=right_send_sem, recv_sem=right_recv_sem, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -876,7 +871,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], send_sem=left_send_sem, recv_sem=left_recv_sem, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) right_copy = pltpu.make_async_remote_copy( @@ -886,7 +881,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[working_slot, right_copy_slice], send_sem=right_send_sem, recv_sem=right_recv_sem, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -899,13 +894,13 @@ def reduce_scatter_kernel( pltpu.semaphore_signal( barrier_sem, inc=1, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) pltpu.semaphore_signal( barrier_sem, inc=1, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) pltpu.semaphore_wait(barrier_sem, 2) @@ -1212,8 +1207,7 @@ executionInfo: id: 27jni-pSartL --- partition = P(None, 'x') -devices = mesh_utils.create_device_mesh((1, num_devices)) -mesh = jax.sharding.Mesh(devices, partition) +mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # We pick a large outer kernel block size that we do not want to place @@ -1279,7 +1273,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[working_slot, left_copy_slice], send_sem=left_send_sem, recv_sem=left_recv_sem, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -1288,7 +1282,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[working_slot, right_copy_slice], send_sem=right_send_sem, recv_sem=right_recv_sem, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -1297,7 +1291,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], send_sem=left_send_sem, recv_sem=left_recv_sem, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) right_copy = pltpu.make_async_remote_copy( @@ -1305,7 +1299,7 @@ def reduce_scatter_kernel( dst_ref=hbm_scratch.at[working_slot, right_copy_slice], send_sem=right_send_sem, recv_sem=right_recv_sem, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) @@ -1318,13 +1312,13 @@ def reduce_scatter_kernel( pltpu.semaphore_signal( barrier_sem, inc=1, - device_id=(0, left_neighbor), + device_id=(left_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) pltpu.semaphore_signal( barrier_sem, inc=1, - device_id=(0, right_neighbor), + device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, ) pltpu.semaphore_wait(barrier_sem, 2)