Add a skeleton for Pallas:Mosaic GPU documentation

Adam Paszke 2025-04-08 11:52:53 +00:00
parent 8ed59d8b5d
commit 511f78202f
4 changed files with 270 additions and 1 deletion

docs/_static/pallas/gpu/nvidia_sm.svg vendored Normal file

@@ -0,0 +1,99 @@
<svg width="600" height="600" xmlns="http://www.w3.org/2000/svg" font-family="sans-serif" font-size="10">
<style>
.sm-border { stroke: #a0aec0; stroke-width: 2; fill: #f7fafc; }
.subdivision-border { stroke: #cbd5e0; stroke-width: 1; fill: #edf2f7; }
.component-box { stroke: #e2e8f0; stroke-width: 1; rx: 4; ry: 4; }
.tensor-core { fill: #bee3f8; }
.alu { fill: #c6f6d5; }
.ld-st { fill: #faf089; }
.sfu { fill: #fed7d7; }
.scheduler { fill: #e9d8fd; stroke: #d6bcfa; stroke-width:1; rx: 4; ry: 4; }
.shared-memory { fill: #b2f5ea; stroke: #81e6d9; stroke-width:1; rx: 6; ry: 6; }
.label { text-anchor: middle; fill: #2d3748; font-weight: bold; }
.component-label { text-anchor: middle; fill: #4a5568; font-size: 9px; }
.title-label { text-anchor: middle; fill: #1a202c; font-size: 16px; font-weight: bold; }
.small-label tspan { font-size: 8.5px; }
</style>
<rect x="10" y="10" width="580" height="580" class="sm-border" rx="10" ry="10"/>
<text x="300" y="35" class="title-label">Streaming Multiprocessor</text>
<g id="subdivision-grid">
<g transform="translate(30, 60)">
<rect x="0" y="0" width="255" height="225" class="subdivision-border" rx="6" ry="6"/>
<rect x="10" y="10" width="235" height="20" class="scheduler"/>
<text x="127.5" y="24" class="label">Warp Scheduler</text>
<rect x="10" y="40" width="145" height="125" class="component-box tensor-core"/>
<text x="82.5" y="107.5" class="component-label">TensorCore</text>
<rect x="165" y="40" width="80" height="125" class="component-box alu"/>
<text x="205" y="100" class="component-label">ALU</text>
<text x="205" y="110" class="component-label">(Float/Int)</text>
<rect x="10" y="175" width="115" height="40" class="component-box ld-st"/>
<text x="67.5" y="198" class="component-label">Load/Store</text>
<rect x="135" y="175" width="110" height="40" class="component-box sfu"/>
<text x="190" y="190" class="component-label small-label">
<tspan x="190" dy="3">Special</tspan>
<tspan x="190" dy="10">Functions</tspan>
</text>
</g>
<g transform="translate(310, 60)">
<rect x="0" y="0" width="255" height="225" class="subdivision-border" rx="6" ry="6"/>
<rect x="10" y="10" width="235" height="20" class="scheduler"/>
<text x="127.5" y="24" class="label">Warp Scheduler</text>
<rect x="10" y="40" width="145" height="125" class="component-box tensor-core"/>
<text x="82.5" y="107.5" class="component-label">TensorCore</text>
<rect x="165" y="40" width="80" height="125" class="component-box alu"/>
<text x="205" y="100" class="component-label">ALU</text>
<text x="205" y="110" class="component-label">(Float/Int)</text>
<rect x="10" y="175" width="115" height="40" class="component-box ld-st"/>
<text x="67.5" y="198" class="component-label">Load/Store</text>
<rect x="135" y="175" width="110" height="40" class="component-box sfu"/>
<text x="190" y="190" class="component-label small-label">
<tspan x="190" dy="3">Special</tspan>
<tspan x="190" dy="10">Functions</tspan>
</text>
</g>
<g transform="translate(30, 305)">
<rect x="0" y="0" width="255" height="225" class="subdivision-border" rx="6" ry="6"/>
<rect x="10" y="10" width="235" height="20" class="scheduler"/>
<text x="127.5" y="24" class="label">Warp Scheduler</text>
<rect x="10" y="40" width="145" height="125" class="component-box tensor-core"/>
<text x="82.5" y="107.5" class="component-label">TensorCore</text>
<rect x="165" y="40" width="80" height="125" class="component-box alu"/>
<text x="205" y="100" class="component-label">ALU</text>
<text x="205" y="110" class="component-label">(Float/Int)</text>
<rect x="10" y="175" width="115" height="40" class="component-box ld-st"/>
<text x="67.5" y="198" class="component-label">Load/Store</text>
<rect x="135" y="175" width="110" height="40" class="component-box sfu"/>
<text x="190" y="190" class="component-label small-label">
<tspan x="190" dy="3">Special</tspan>
<tspan x="190" dy="10">Functions</tspan>
</text>
</g>
<g transform="translate(310, 305)">
<rect x="0" y="0" width="255" height="225" class="subdivision-border" rx="6" ry="6"/>
<rect x="10" y="10" width="235" height="20" class="scheduler"/>
<text x="127.5" y="24" class="label">Warp Scheduler</text>
<rect x="10" y="40" width="145" height="125" class="component-box tensor-core"/>
<text x="82.5" y="107.5" class="component-label">TensorCore</text>
<rect x="165" y="40" width="80" height="125" class="component-box alu"/>
<text x="205" y="100" class="component-label">ALU</text>
<text x="205" y="110" class="component-label">(Float/Int)</text>
<rect x="10" y="175" width="115" height="40" class="component-box ld-st"/>
<text x="67.5" y="198" class="component-label">Load/Store</text>
<rect x="135" y="175" width="110" height="40" class="component-box sfu"/>
<text x="190" y="190" class="component-label small-label">
<tspan x="190" dy="3">Special</tspan>
<tspan x="190" dy="10">Functions</tspan>
</text>
</g>
</g>
<rect x="30" y="545" width="540" height="35" class="shared-memory"/>
<text x="300" y="567" class="label">Shared Memory / L1 Cache</text>
</svg>


docs/pallas/gpu/index.rst Normal file

@@ -0,0 +1,14 @@
Pallas:Mosaic GPU
=================

Backend-specific documentation for the Mosaic GPU backend.

.. toctree::
   :caption: Reference documentation
   :maxdepth: 2

   reference

.. toctree::
   :caption: Guides
   :maxdepth: 2


@@ -0,0 +1,150 @@
# Writing Mosaic GPU kernels with Pallas
This page is a reference for the most important features of the Pallas:MGPU backend.
It is not a tutorial, and as such we do not expect everyone to read it top to bottom. Still,
it is worth going over just to familiarise yourself with some patterns that appear
in other tutorials.
In the following examples, we're going to assume the following imports are in scope:
```python
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
```
## What is a GPU?
Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into
_streaming multiprocessors_ (SMs). The way this manifests in the CUDA programming model
is that each _CUDA thread block_ (or CTA) is scheduled on exactly one SM, but multiple
blocks can be scheduled onto a single SM at a time.
Each SM contains a chunk of fast memory called _shared memory_ (SMEM) and 4 subdivisions,
each containing a _warp scheduler_ and compute units (ALU, TensorCore, ...).
This is also reflected in CUDA programs: each _warp_ (a group of 32 consecutive CUDA
threads in a block) is assigned to one of those subdivisions in a round-robin fashion.
Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates),
but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the
warp scheduler from each subdivision tries to select one of its resident warps to execute
the next instruction.
![A diagram of one SM](../../_static/pallas/gpu/nvidia_sm.svg)
Going further, recent CUDA versions also introduce the concept of a _warpgroup_: a group of
4 consecutive warps. Knowing what the hardware looks like, we can see where this is coming
from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions
that utilize the whole SM.
> A GPU can be viewed in many different ways and here we want to focus on a slightly
simplified model that is very TensorCore-centric. This should help you navigate the
complexities of writing kernels involving the TensorCore, but keep in mind that the
real picture is more complicated.
For our purposes, TensorCore operations have grown so big that it no longer makes much
sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores
(SMs) with one thread of Pallas:MGPU corresponding to a CUDA warpgroup. In this model, each
operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent
warps always run in lockstep (modulo the jitter from hardware scheduling) and never take
different paths through control flow (with the small exception of `core_map` that we will
discuss later). One notable addition here is that we still allow you to co-schedule multiple
of those Pallas-level threads on the same SM so that they can cooperate and communicate
through shared memory (we realize this by putting them in the same CUDA block).
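To make this model concrete, here is a minimal sketch of a kernel in which the whole body is
written as if for a single thread, with every array operation implicitly carried out by an
entire CUDA warpgroup. The `backend="mosaic_gpu"` argument is our assumption about how the
Mosaic GPU backend is selected; consult the Pallas API documentation for the exact invocation
your JAX version expects.
```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def add_one_kernel(x_ref, o_ref):
  # A single Pallas:MGPU thread (one CUDA warpgroup) runs this body; the
  # element-wise addition below is executed collectively by all of its warps.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    backend="mosaic_gpu",  # assumption: how the Mosaic GPU backend is chosen
)(x)
```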
> This is very similar to a programming model popularized by [Triton](https://triton-lang.org/),
but as you will see, there are a few differences. Mosaic GPU tends to be lower level,
which usually means you will have to put in more work, but it also gives you more control.
In our view both approaches have their merits and we encourage you to pick the backend that
suits your needs the best! Pallas supports and will continue to support Triton as an alternative
GPU backend.
### In-order execution & using multiple hardware units
Unlike more complicated CPU architectures, GPUs only support in-order execution. That, however,
does not mean that only a single instruction can be running at any given time! Each SM quarter
has multiple independent functional units: the TensorCore, the arithmetic logic unit (ALU),
the load/store unit (LSU), and the special function unit (SFU). If an instruction targets one of
these units and the following instruction does not depend on its result, the warp scheduler can
issue the second instruction before the first one completes. This is often referred to as
instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels:
TensorCore operations are so big and take so many cycles to complete that it is a waste not to
use the other units in the meantime.
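As a sketch of what this overlap can look like inside a kernel, the fragment below issues an
asynchronous GMEM-to-SMEM copy, performs independent ALU work while the copy is in flight, and
only waits once the copied data is actually needed. It is not a complete kernel: it assumes that
the `x_gmem`, `x_smem`, `acc_ref` and `barrier` references were set up by the surrounding code,
and that helpers named `plgpu.copy_gmem_to_smem` and `plgpu.barrier_wait` are available (see the
sections on `Barrier`s and asynchronous copies below).
```python
def overlapped_body(x_gmem, x_smem, acc_ref, barrier):
  # Issue the copy: this only starts the transfer, it does not wait for it.
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier)
  # Independent ALU work can be issued while the copy is in flight, because
  # it does not use the data being copied.
  acc_ref[...] = acc_ref[...] * 2.0 + 1.0
  # Block only once we actually need the copied data.
  plgpu.barrier_wait(barrier)
  acc_ref[...] += x_smem[...]
```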
To extend this even further, we can take advantage of this hardware-unit-level parallelism by
allowing multiple Pallas threads (warpgroups) to run concurrently. If one of the threads primarily
occupies the ALU while another one primarily issues TensorCore-related instructions, we can
take advantage of the efficient context switching built into the warp schedulers to keep both
units busy. This is one of the core ideas behind algorithms such as [FlashAttention 3](https://arxiv.org/abs/2407.08608)
or [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/).
For more information on how warp scheduling and instruction issue works, we recommend reading
[Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481).
## Array layouts and reference transforms
TODO
## MMA (TensorCore)
In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit.
NVIDIA continues to change the TensorCore programming interface significantly between
hardware generations, which is why the lowest-level Pallas:MGPU interfaces differ
between generations as well.
### Hopper (`wgmma`)
TODO
### Blackwell (`tcgen05`)
TODO
## Using `core_map`
TODO
## Synchronization structures and primitives
### `commit_smem`
TODO
### `Barrier`
This is essentially a thin wrapper around an array of PTX `mbarrier` types and is
passed in as a reference. All functions involving barriers expect a single barrier
argument, so if the reference contains multiple barriers, you have to extract one
of them explicitly using `barriers.at[index]`.
`Barrier`s are always allocated in SMEM and as such have relatively low overhead.
There are three primary use cases that require `Barrier`s (a sketch follows the list below):
1. Awaiting asynchronous GMEM-to-SMEM copies
TODO
2. Cross-warpgroup synchronization
TODO
3. Awaiting `tcgen05` TensorCore instructions
TODO
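As a sketch of use case 1, the fragment below loads two tiles through two different barriers.
It assumes the surrounding kernel allocated `barriers` as a `Barrier` reference holding two
barriers (for example via something like `plgpu.Barrier(num_arrivals=1, num_barriers=2)` in the
scratch shapes; the exact allocation API is an assumption) and that `plgpu.copy_gmem_to_smem`
and `plgpu.barrier_wait` are the relevant copy and wait helpers.
```python
def load_two_tiles(x_gmem, y_gmem, x_smem, y_smem, barriers):
  # `barriers` refers to two mbarriers. Functions that take a barrier expect
  # a single one, so we pick one out explicitly with `.at[index]`.
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, barriers.at[0])
  plgpu.copy_gmem_to_smem(y_gmem, y_smem, barriers.at[1])
  # Each copy is awaited on its own barrier, in whatever order the data is needed.
  plgpu.barrier_wait(barriers.at[0])
  plgpu.barrier_wait(barriers.at[1])
```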
### `ClusterBarrier`
TODO
### `Semaphore`
TODO
## Asynchronous copies
TODO
## Inline Mosaic GPU
TODO
## Compiler parameters
TODO


@@ -26,11 +26,17 @@ See also the :class:`jax.experimental.pallas` module API documentation.
 .. toctree::
-   :caption: Platform Features
+   :caption: TPU backend guide
    :maxdepth: 2

    tpu/index

+.. toctree::
+   :caption: Mosaic GPU backend guide
+   :maxdepth: 2
+
+   gpu/index
+
 .. toctree::
    :caption: Design Notes
    :maxdepth: 2