Mirror of https://github.com/ROCm/jax.git

[Pallas] Update Pallas docs with new figures and TPUCompilerParams

commit 51a666fb8c (parent 57c0d59d04)
(Two figure files updated; their text diffs are suppressed because one or more lines are too long. Sizes before/after: 18 KiB → 26 KiB and 94 KiB → 70 KiB.)
@@ -148,10 +148,8 @@ grid axes over cores. This is an opt-in procedure. To allow that,
 ..
 
   pallas_call(
     ...,
-    compiler_params=dict(
-      mosaic=dict(
-        dimension_semantics=["parallel", "parallel", "arbitrary"]
-      )
+    compiler_params=pltpu.TPUCompilerParams(
+      dimension_semantics=["parallel", "parallel", "arbitrary"]
     ),
   )
 
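Before the remaining hunks, a self-contained sketch of the call style this commit migrates the docs to may help readers who only skim the fragments below. Everything here other than the `compiler_params=pltpu.TPUCompilerParams(...)` line — the kernel, shapes, block sizes, and grid — is an illustrative assumption, not taken from the files being changed.

```python
# Hedged sketch: the new typed compiler_params object replacing
# dict(mosaic=dict(...)). Kernel, shapes, and block sizes are made up.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_kernel(x_ref, y_ref, o_ref):
  # Element-wise add on the current (256, 256) block.
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  block = pl.BlockSpec((256, 256), lambda i, j: (i, j))
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[block, block],
      out_specs=block,
      grid=(x.shape[0] // 256, x.shape[1] // 256),
      # Both grid axes are embarrassingly parallel, so mark them as such.
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel")),
  )(x, y)

x = jnp.ones((512, 512), jnp.float32)
y = jnp.ones((512, 512), jnp.float32)
print(add(x, y).shape)  # (512, 512)
```

The removed spelling, visible throughout this diff, was the equivalent `compiler_params=dict(mosaic=dict(dimension_semantics=("parallel", "parallel")))`.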
@@ -196,11 +196,7 @@
 "- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.\n",
 "- If a device is sent more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely.\n",
 "- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.\n",
-"- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.\n",
-"\n",
-"### Megacore\n",
-"\n",
-"Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core."
+"- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states."
 ]
 },
 {
@@ -576,7 +572,7 @@
 "kernel = pl.pallas_call(\n",
 " example_kernel,\n",
 " ...,\n",
-" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
+" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
 ")\n",
 "```"
 ]
@@ -815,7 +811,7 @@
 " all_reduce_kernel,\n",
 " out_shape=out_shape,\n",
 " grid_spec=grid_spec,\n",
-" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
+" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
 ")\n",
 "\n",
 "pallas_result = jax.jit(\n",
@@ -1169,7 +1165,7 @@
 " reduce_scatter_kernel,\n",
 " out_shape=out_shape,\n",
 " grid_spec=grid_spec,\n",
-" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
+" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
 " )(input_arr)[0]\n",
 "\n",
 "\n",
@@ -1626,7 +1622,7 @@
 " reduce_scatter_kernel,\n",
 " out_shape=out_shape,\n",
 " grid_spec=grid_spec,\n",
-" compiler_params=dict(mosaic=dict(collective_id=0)),\n",
+" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
 " )(input_arr)[0]\n",
 "\n",
 "\n",
@@ -1705,6 +1701,10 @@
 "source": [
 "## Final Notes\n",
 "\n",
+"### Megacore\n",
+"\n",
+"Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
+"\n",
 "### Interaction with XLA\n",
 "\n",
 "In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.\n",
@@ -183,10 +183,6 @@ Some common causes of the above include:
 - If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.
 - If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.
 
-### Megacore
-
-Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
-
 +++ {"id": "vpGSN1Sui0Bu"}
 
 ### Example: Right Permute (`lax.ppermute`)
@@ -498,7 +494,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa
 kernel = pl.pallas_call(
     example_kernel,
     ...,
-    compiler_params=dict(mosaic=dict(collective_id=0)),
+    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
 )
 ```
 
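The hunk above accompanies the tutorial's note that barrier semaphores require the `collective_id` compiler parameter. As a reminder of where that parameter lands in the new API, here is a deliberately degenerate, single-device sketch. The kernel body is an illustrative assumption (a real barrier signals remote `device_id`s, as the surrounding tutorial does), and it leans on `pltpu.get_barrier_semaphore`, `pltpu.semaphore_signal`, and `pltpu.semaphore_wait` as used elsewhere in that tutorial.

```python
# Hedged, single-device sketch: shows only where collective_id goes in the new
# TPUCompilerParams API. The "barrier" signals and waits on the local device,
# so it does not synchronize anything across devices.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def barrier_kernel(x_ref, o_ref):
  # The barrier semaphore is only available when collective_id is set on the
  # enclosing pallas_call.
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, 1)  # signal ourselves once...
  pltpu.semaphore_wait(barrier_sem, 1)    # ...and wait for that signal
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    barrier_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)(x)
```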
@@ -709,7 +705,7 @@ kernel = pl.pallas_call(
     all_reduce_kernel,
     out_shape=out_shape,
     grid_spec=grid_spec,
-    compiler_params=dict(mosaic=dict(collective_id=0)),
+    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
 )
 
 pallas_result = jax.jit(
@@ -1042,7 +1038,7 @@ def pallas_reduce_scatter(input_arr):
       reduce_scatter_kernel,
       out_shape=out_shape,
       grid_spec=grid_spec,
-      compiler_params=dict(mosaic=dict(collective_id=0)),
+      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
   )(input_arr)[0]
 
 
@@ -1460,7 +1456,7 @@ def pallas_reduce_scatter(input_arr):
       reduce_scatter_kernel,
       out_shape=out_shape,
       grid_spec=grid_spec,
-      compiler_params=dict(mosaic=dict(collective_id=0)),
+      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
   )(input_arr)[0]
 
 
@@ -1518,6 +1514,10 @@ print(
 
 ## Final Notes
 
+### Megacore
+
+Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
+
 ### Interaction with XLA
 
 In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.
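Since the Megacore note above is moved in this commit rather than shown with runnable code, a minimal sketch of the mechanism it describes may help: one grid axis sized to the number of cores, marked "parallel", with `pl.program_id` and `pl.when` selecting per-core work. The kernel body, shapes, block split, and the assumption of a two-core part are illustrative only, not taken from the docs being modified.

```python
# Hedged sketch of the Megacore recommendation: add a leading grid axis of size
# num_cores, mark it "parallel", and gate per-core work with pl.when.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

num_cores = jax.devices()[0].num_cores  # 2 on a Megacore TPU, 1 otherwise

def per_core_kernel(x_ref, o_ref):
  core_index = pl.program_id(0)  # index along the core-sized grid axis

  @pl.when(core_index == 0)
  def _():
    o_ref[...] = x_ref[...] * 2.0  # work assigned to core 0

  @pl.when(core_index == 1)
  def _():
    o_ref[...] = x_ref[...] + 1.0  # work assigned to core 1 (if present)

def run(x):
  # Each core gets a contiguous slab of rows.
  spec = pl.BlockSpec((x.shape[0] // num_cores, x.shape[1]), lambda i: (i, 0))
  return pl.pallas_call(
      per_core_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[spec],
      out_specs=spec,
      grid=(num_cores,),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel",)),
  )(x)

out = run(jnp.ones((256, 256), jnp.float32))
```

This mirrors the `add_matrices_pipelined_megacore` change later in this diff, which likewise uses `grid=(2,)` with `dimension_semantics=("parallel",)`.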
@@ -210,8 +210,8 @@
 " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n",
 " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n",
 " grid=(m // bm, n // bn, k // bk),\n",
-" compiler_params=dict(mosaic=dict(\n",
-" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+" compiler_params=pltpu.TPUCompilerParams(\n",
+" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]
 },
@@ -466,8 +466,8 @@
 " grid=(m // bm, n // bn, k // bk),\n",
 " ),\n",
 " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
-" compiler_params=dict(mosaic=dict(\n",
-" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+" compiler_params=pltpu.TPUCompilerParams(\n",
+" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]
 },
@@ -741,8 +741,8 @@
 " grid=(m // bm, n // bn, k // bk),\n",
 " ),\n",
 " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
-" compiler_params=dict(mosaic=dict(\n",
-" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+" compiler_params=pltpu.TPUCompilerParams(\n",
+" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]
 },
@@ -929,8 +929,8 @@
 " grid=(m // bm, n // bn, k // bk),\n",
 " ),\n",
 " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
-" compiler_params=dict(mosaic=dict(\n",
-" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+" compiler_params=pltpu.TPUCompilerParams(\n",
+" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]
 },
@@ -167,8 +167,8 @@ def matmul(
           pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
       out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
       grid=(m // bm, n // bn, k // bk),
-      compiler_params=dict(mosaic=dict(
-          dimension_semantics=("parallel", "parallel", "arbitrary"))),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "parallel", "arbitrary")),
   )(x, y)
 ```
 
@@ -321,8 +321,8 @@ def matmul(
           grid=(m // bm, n // bn, k // bk),
       ),
       out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
-      compiler_params=dict(mosaic=dict(
-          dimension_semantics=("parallel", "parallel", "arbitrary"))),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "parallel", "arbitrary")),
   )(x, y)
 ```
 
@@ -489,8 +489,8 @@ def matmul(
           grid=(m // bm, n // bn, k // bk),
       ),
       out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
-      compiler_params=dict(mosaic=dict(
-          dimension_semantics=("parallel", "parallel", "arbitrary"))),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "parallel", "arbitrary")),
   )(x, y)
 ```
 
@@ -613,8 +613,8 @@ def matmul(
           grid=(m // bm, n // bn, k // bk),
       ),
       out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
-      compiler_params=dict(mosaic=dict(
-          dimension_semantics=("parallel", "parallel", "arbitrary"))),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "parallel", "arbitrary")),
   )(x, y)
 ```
 
@@ -33,6 +33,7 @@
 "\n",
 "import jax\n",
 "from jax.experimental import pallas as pl\n",
+"from jax.experimental.pallas import tpu as pltpu\n",
 "import jax.numpy as jnp\n",
 "import numpy as np"
 ]
@@ -696,7 +697,7 @@
 " in_specs=[block_spec, block_spec],\n",
 " out_specs=block_spec,\n",
 " grid=(2,),\n",
-" compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n",
+" compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n",
 " )(x, y)\n",
 "\n",
 "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
@@ -29,6 +29,7 @@ pipelines in Pallas that overlap memory I/O with compute.
 
 import jax
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np
 ```
@@ -465,7 +466,7 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
       in_specs=[block_spec, block_spec],
       out_specs=block_spec,
       grid=(2,),
-      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",)))
+      compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
   )(x, y)
 
 x, y = jnp.ones((512, 512)), jnp.ones((512, 512))