[Pallas] Update Pallas docs with new figures and TPUCompilerParams

Justin Fu 2024-09-06 14:26:52 -07:00
parent 57c0d59d04
commit 51a666fb8c
9 changed files with 41 additions and 41 deletions

Figure updated (diff suppressed; file too large to render): size 18 KiB → 26 KiB

Figure updated (diff suppressed; file too large to render): size 94 KiB → 70 KiB


@@ -148,10 +148,8 @@ grid axes over cores. This is an opt-in procedure. To allow that,
..
pallas_call(
...,
compiler_params=dict(
mosaic=dict(
dimension_semantics=["parallel", "parallel", "arbitrary"]
)
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=["parallel", "parallel", "arbitrary"]
),
)
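
For reference, below is a minimal, self-contained sketch of the new-style call shown in this hunk. The kernel, shapes, and block sizes are illustrative and not taken from the changed file; only `pltpu.TPUCompilerParams` and its `dimension_semantics` field come from the diff. In the hunks below, each old `compiler_params=dict(mosaic=dict(...))` entry maps directly onto a field of `pltpu.TPUCompilerParams`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_kernel(x_ref, y_ref, o_ref):
  # Illustrative kernel: element-wise add of one (256, 256) block per grid step.
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  block = pl.BlockSpec((256, 256), lambda i, j: (i, j))
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[block, block],
      out_specs=block,
      grid=(x.shape[0] // 256, x.shape[1] // 256),
      # New-style compiler params; replaces compiler_params=dict(mosaic=dict(...)).
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel")),
  )(x, y)

x = jnp.ones((512, 512), jnp.float32)
y = jnp.ones((512, 512), jnp.float32)
print(add(x, y).shape)  # (512, 512)
```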


@@ -196,11 +196,7 @@
"- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.\n",
"- If a device is sent a more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely.\n",
"- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.\n",
"- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.\n",
"\n",
"### Megacore\n",
"\n",
"Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core."
"- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states."
]
},
{
@@ -576,7 +572,7 @@
"kernel = pl.pallas_call(\n",
" example_kernel,\n",
" ...,\n",
" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
")\n",
"```"
]
@@ -815,7 +811,7 @@
" all_reduce_kernel,\n",
" out_shape=out_shape,\n",
" grid_spec=grid_spec,\n",
" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
")\n",
"\n",
"pallas_result = jax.jit(\n",
@@ -1169,7 +1165,7 @@
" reduce_scatter_kernel,\n",
" out_shape=out_shape,\n",
" grid_spec=grid_spec,\n",
" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
" )(input_arr)[0]\n",
"\n",
"\n",
@@ -1626,7 +1622,7 @@
" reduce_scatter_kernel,\n",
" out_shape=out_shape,\n",
" grid_spec=grid_spec,\n",
" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
" )(input_arr)[0]\n",
"\n",
"\n",
@@ -1705,6 +1701,10 @@
"source": [
"## Final Notes\n",
"\n",
"### Megacore\n",
"\n",
"Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
"\n",
"### Interaction with XLA\n",
"\n",
"In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.\n",


@@ -183,10 +183,6 @@ Some common causes of the above include:
- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.
- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.
### Megacore
Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
+++ {"id": "vpGSN1Sui0Bu"}
### Example: Right Permute (`lax.ppermute`)
@@ -498,7 +494,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa
kernel = pl.pallas_call(
example_kernel,
...,
compiler_params=dict(mosaic=dict(collective_id=0)),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
```
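
For context, the barrier semaphore selected by `collective_id` is used inside the kernel itself. A non-authoritative sketch is below: the ring handshake, the axis name `'x'`, and the kernel signature are assumptions, while `get_barrier_semaphore`, `semaphore_signal`, and `semaphore_wait` are the `pltpu` calls the guide relies on.

```python
import jax
from jax import lax
from jax.experimental.pallas import tpu as pltpu

num_devices = jax.device_count()  # closed over as a Python int

def example_kernel(input_ref, output_ref, send_sem, recv_sem):
  # Assumes the pallas_call above runs under shard_map over a 1-D mesh
  # axis named 'x'.
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)

  # Ring handshake on the barrier semaphore associated with collective_id=0:
  # signal the right neighbor, then wait for the left neighbor's signal.
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(
      barrier_sem,
      device_id=right_neighbor,
      device_id_type=pltpu.DeviceIdType.LOGICAL,
  )
  pltpu.semaphore_wait(barrier_sem, 1)
  # ... remote DMAs (e.g. pltpu.make_async_remote_copy) would follow here.
```
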
@@ -709,7 +705,7 @@ kernel = pl.pallas_call(
all_reduce_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=dict(mosaic=dict(collective_id=0)),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
pallas_result = jax.jit(
@@ -1042,7 +1038,7 @@ def pallas_reduce_scatter(input_arr):
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=dict(mosaic=dict(collective_id=0)),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)(input_arr)[0]
@@ -1460,7 +1456,7 @@ def pallas_reduce_scatter(input_arr):
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=dict(mosaic=dict(collective_id=0)),
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)(input_arr)[0]
@@ -1518,6 +1514,10 @@ print(
## Final Notes
### Megacore
Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
### Interaction with XLA
In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.
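
Since the note above recommends profiling, here is a minimal sketch of capturing a trace around such a program; `mixed_program` and `inputs` are stand-in names.

```python
import jax

# Capture a profiler trace to check whether communication and compute overlap
# as expected around the Pallas kernels; inspect it in TensorBoard/Perfetto.
with jax.profiler.trace("/tmp/jax-trace"):
  out = mixed_program(inputs)   # stand-ins for the program under test
  jax.block_until_ready(out)
```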


@@ -210,8 +210,8 @@
" pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n",
" out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n",
" grid=(m // bm, n // bn, k // bk),\n",
" compiler_params=dict(mosaic=dict(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
" compiler_params=pltpu.TPUCompilerParams(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},
@@ -466,8 +466,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
" compiler_params=dict(mosaic=dict(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
" compiler_params=pltpu.TPUCompilerParams(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},
@@ -741,8 +741,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
" compiler_params=dict(mosaic=dict(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
" compiler_params=pltpu.TPUCompilerParams(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},
@@ -929,8 +929,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
" compiler_params=dict(mosaic=dict(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
" compiler_params=pltpu.TPUCompilerParams(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},


@@ -167,8 +167,8 @@ def matmul(
pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
grid=(m // bm, n // bn, k // bk),
compiler_params=dict(mosaic=dict(
dimension_semantics=("parallel", "parallel", "arbitrary"))),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```
@@ -321,8 +321,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=dict(mosaic=dict(
dimension_semantics=("parallel", "parallel", "arbitrary"))),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```
@@ -489,8 +489,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=dict(mosaic=dict(
dimension_semantics=("parallel", "parallel", "arbitrary"))),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```
@@ -613,8 +613,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=dict(mosaic=dict(
dimension_semantics=("parallel", "parallel", "arbitrary"))),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```


@@ -33,6 +33,7 @@
"\n",
"import jax\n",
"from jax.experimental import pallas as pl\n",
"from jax.experimental.pallas import tpu as pltpu\n",
"import jax.numpy as jnp\n",
"import numpy as np"
]
@@ -696,7 +697,7 @@
" in_specs=[block_spec, block_spec],\n",
" out_specs=block_spec,\n",
" grid=(2,),\n",
" compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n",
" compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n",
" )(x, y)\n",
"\n",
"x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",


@@ -29,6 +29,7 @@ pipelines in Pallas that overlap memory I/O with compute.
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
```
@@ -465,7 +466,7 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",)))
compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
)(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))