From 51a666fb8c022bfa4a4995cb29ae570e67b459bd Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 6 Sep 2024 14:26:52 -0700 Subject: [PATCH] [Pallas] Update Pallas docs with new figures and TPUCompilerParams --- .../pallas/distributed/reduce_sum_1.svg | 2 +- .../pallas/distributed/reduce_sum_2.svg | 2 +- docs/pallas/tpu/details.rst | 6 ++---- docs/pallas/tpu/distributed.ipynb | 18 +++++++++--------- docs/pallas/tpu/distributed.md | 16 ++++++++-------- docs/pallas/tpu/matmul.ipynb | 16 ++++++++-------- docs/pallas/tpu/matmul.md | 16 ++++++++-------- docs/pallas/tpu/pipelining.ipynb | 3 ++- docs/pallas/tpu/pipelining.md | 3 ++- 9 files changed, 41 insertions(+), 41 deletions(-) diff --git a/docs/_static/pallas/distributed/reduce_sum_1.svg b/docs/_static/pallas/distributed/reduce_sum_1.svg index 6c397a87b..9a527aff6 100644 --- a/docs/_static/pallas/distributed/reduce_sum_1.svg +++ b/docs/_static/pallas/distributed/reduce_sum_1.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_sum_2.svg b/docs/_static/pallas/distributed/reduce_sum_2.svg index ef2a76330..61685cf41 100644 --- a/docs/_static/pallas/distributed/reduce_sum_2.svg +++ b/docs/_static/pallas/distributed/reduce_sum_2.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index ae9505c4e..93d7e5547 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -148,10 +148,8 @@ grid axes over cores. This is an opt-in procedure. To allow that, .. pallas_call( ..., - compiler_params=dict( - mosaic=dict( - dimension_semantics=["parallel", "parallel", "arbitrary"] - ) + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=["parallel", "parallel", "arbitrary"] ), ) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 5209f2ff8..8552e10d8 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -196,11 +196,7 @@ "- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.\n", "- If a device is sent a more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely.\n", "- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.\n", - "- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.\n", - "\n", - "### Megacore\n", - "\n", - "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core." + "- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states." 
] }, { @@ -576,7 +572,7 @@ "kernel = pl.pallas_call(\n", " example_kernel,\n", " ...,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", ")\n", "```" ] @@ -815,7 +811,7 @@ " all_reduce_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", ")\n", "\n", "pallas_result = jax.jit(\n", @@ -1169,7 +1165,7 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", @@ -1626,7 +1622,7 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", @@ -1705,6 +1701,10 @@ "source": [ "## Final Notes\n", "\n", + "### Megacore\n", + "\n", + "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n", + "\n", "### Interaction with XLA\n", "\n", "In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index b7c058b11..dbdb00e80 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -183,10 +183,6 @@ Some common causes of the above include: - If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states. - If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states. -### Megacore - -Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. 
To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. - +++ {"id": "vpGSN1Sui0Bu"} ### Example: Right Permute (`lax.ppermute`) @@ -498,7 +494,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa kernel = pl.pallas_call( example_kernel, ..., - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) ``` @@ -709,7 +705,7 @@ kernel = pl.pallas_call( all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) pallas_result = jax.jit( @@ -1042,7 +1038,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), )(input_arr)[0] @@ -1460,7 +1456,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), )(input_arr)[0] @@ -1518,6 +1514,10 @@ print( ## Final Notes +### Megacore + +Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. + ### Interaction with XLA In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations. 
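A minimal sketch of the Megacore gating pattern described in the distributed guide above (one "parallel" grid axis sized to `jax.devices()[0].num_cores`, `pl.program_id` to read the core index, and `@pl.when` for core-specific code), expressed with the new `pltpu.TPUCompilerParams` argument. This is not part of the patch: it assumes a TPU backend, and the `per_core_kernel`/`per_core` names, the 512x512 shape, and the block sizes are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

num_cores = jax.devices()[0].num_cores  # e.g. 2 on a Megacore TPU, otherwise 1

def per_core_kernel(x_ref, o_ref):
  core_index = pl.program_id(0)  # index along the "parallel" grid axis
  @pl.when(core_index == 0)
  def _():
    # Code specific to the block handled by core 0.
    o_ref[...] = x_ref[...] + 1.0
  @pl.when(core_index > 0)
  def _():
    # Code for the blocks handled by the remaining core(s).
    o_ref[...] = x_ref[...] + 2.0

def per_core(x):
  block = pl.BlockSpec((x.shape[0] // num_cores, x.shape[1]), lambda i: (i, 0))
  return pl.pallas_call(
      per_core_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[block],
      out_specs=block,
      grid=(num_cores,),  # one grid step per core
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel",)),
  )(x)

x = jnp.ones((512, 512), jnp.float32)
out = per_core(x)
print(out[0, 0], out[-1, -1])  # 2.0, then 2.0 or 3.0 depending on num_cores
```

On a single-core TPU the grid collapses to a single program, so only the core-0 branch runs and the whole output is still written.
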
diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb index 0bd16095c..51ce2ed68 100644 --- a/docs/pallas/tpu/matmul.ipynb +++ b/docs/pallas/tpu/matmul.ipynb @@ -210,8 +210,8 @@ " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " grid=(m // bm, n // bn, k // bk),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -466,8 +466,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -741,8 +741,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -929,8 +929,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index a00880eba..e542dedc7 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -167,8 +167,8 @@ def matmul( pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))], out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), grid=(m // bm, n // bn, k // bk), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -321,8 +321,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -489,8 +489,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -613,8 +613,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 275a72f38..2a3aa9d11 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -33,6 +33,7 @@ "\n", "import jax\n", "from 
jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", "import jax.numpy as jnp\n", "import numpy as np" ] @@ -696,7 +697,7 @@ " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n", + " compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n", " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index d753b404d..67c1900a0 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -29,6 +29,7 @@ pipelines in Pallas that overlap memory I/O with compute. import jax from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np ``` @@ -465,7 +466,7 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))) + compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",)) )(x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
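
A condensed, self-contained version of the blocked matmul that the `matmul.ipynb`/`matmul.md` hunks above update, showing the new `pltpu.TPUCompilerParams` argument in runnable form. Again a sketch rather than part of the patch: it assumes a TPU backend, and the 128-sized blocks are illustrative and must divide the operand shapes.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def matmul_kernel(x_ref, y_ref, o_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    # Zero the output tile on the first step of the contraction (k) axis.
    o_ref[...] = jnp.zeros_like(o_ref)
  # Accumulate one (bm, bk) x (bk, bn) tile product into the output tile.
  o_ref[...] += x_ref[...] @ y_ref[...]

def matmul(x, y, bm=128, bk=128, bn=128):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      in_specs=[pl.BlockSpec((bm, bk), lambda i, j, kk: (i, kk)),
                pl.BlockSpec((bk, bn), lambda i, j, kk: (kk, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, kk: (i, j)),
      grid=(m // bm, n // bn, k // bk),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)

x = jnp.ones((512, 512), jnp.float32)
y = jnp.ones((512, 512), jnp.float32)
print(jnp.allclose(matmul(x, y), x @ y))  # expected: True
```

Because the output `BlockSpec` index map ignores the `k` grid index, the same output tile stays resident across the "arbitrary" contraction steps, which is what makes the in-place accumulation valid.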