[Pallas] Pallas documentation cleanup

Justin Fu 2024-12-04 14:27:11 -08:00
parent 1a3c9c44dc
commit e05afefc97
9 changed files with 88 additions and 76 deletions

View File

@ -94,7 +94,7 @@ Pallas kernels via JAX transformations.
<center>
![Pallas lowering path](../_static/pallas/pallas_flow.png)
![Pallas lowering path](../../_static/pallas/pallas_flow.png)
Visualization of Pallas lowering paths
</center>
@ -413,10 +413,10 @@ verify the correctness of the Triton and Mosaic compilers.
One could also imagine perturbing the `scan` ordering to simulate the
parallel reads and writes that happen on GPU.
### Examples
### GPU Examples
Note all the following examples are for GPU only. They will require some small
changes to work on TPUs.
Note that all the following examples are for GPU only. They will require tweaks to
the block sizes to work on TPUs.
#### `add`

View File

@ -0,0 +1,9 @@
Pallas Design Notes
===================
.. toctree::
   :caption: Design
   :maxdepth: 2

   design
   async_note

View File

@ -44,39 +44,7 @@ For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
You can also use {func}`jax.experimental.pallas.num_programs` to get the
grid size for a given axis.
Here's an example kernel that uses a `grid` and `program_id`.
```python
>>> import jax
>>> from jax.experimental import pallas as pl
>>> def iota_kernel(o_ref):
... i = pl.program_id(0)
... o_ref[i] = i
```
We now execute it using `pallas_call` with an additional `grid` argument.
```python
>>> def iota(size: int):
... return pl.pallas_call(iota_kernel,
... out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
... grid=(size,), interpret=True)()
>>> iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
```
On GPUs, each program is executed in parallel on separate thread blocks.
Thus, we need to think about race conditions on writes to HBM.
A reasonable approach is to write our kernels in such a way that different
programs write to disjoint places in HBM to avoid these parallel writes.
On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.
See {ref}`pallas_tpu_noteworthy_properties`.
See {ref}`grids_by_example` for a simple kernel that uses this API.
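As a quick illustration (a hypothetical `countdown_kernel`, shown only as a sketch),
both helpers can be used together inside a kernel body:

```python
>>> from jax.experimental import pallas as pl
>>> def countdown_kernel(o_ref):
...   i = pl.program_id(0)    # index of this instance along grid axis 0
...   n = pl.num_programs(0)  # total number of instances along grid axis 0
...   o_ref[i] = n - i
```

With `grid=(n,)`, instance `i` would write `n - i` into position `i` of the output.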
(pallas_blockspec)=
@ -131,6 +99,8 @@ shape `x_shape` are computed as in the function `slices_for_invocation`
below:
```python
>>> import jax
>>> from jax.experimental import pallas as pl
>>> def slices_for_invocation(x_shape: tuple[int, ...],
... x_spec: pl.BlockSpec,
... grid: tuple[int, ...],

View File

@ -22,7 +22,6 @@ See also the :class:`jax.experimental.pallas` module API documentation.
   :maxdepth: 2

   quickstart
   design
   grid_blockspec
@ -34,9 +33,9 @@ See also the :class:`jax.experimental.pallas` module API documentation.
.. toctree::
   :caption: Design Notes
   :maxdepth: 1
   :maxdepth: 2

   async_note
   design/index

.. toctree::
   :caption: Other

View File

@ -72,8 +72,9 @@
"\n",
"Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
"it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
"Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
"but we are given an `o_ref`, which corresponds to the desired output.\n",
"Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.\n",
"Note that we also don't have any outputs but we are given an `o_ref`, which corresponds\n",
"to the desired output.\n",
"\n",
"**Reading from `Ref`s**\n",
"\n",
@ -150,7 +151,8 @@
"**What's actually happening here?**\n",
"\n",
"Thus far we've described how to think about Pallas kernels but what we've actually\n",
"accomplished is we're writing a function that's executed very close to the compute units.\n",
"accomplished is we're writing a function that's executed very close to the compute units\n",
"since values are loaded into the innermost (fastest) portion of the memory hierarchy.\n",
"\n",
"On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n",
"we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n",
@ -195,6 +197,8 @@
"live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
"that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
"\n",
"(grids_by_example)=\n",
"\n",
"### Grids by example\n",
"\n",
"To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
@ -240,7 +244,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We now execute it using `pallas_call` with an additional `grid` argument."
"We now execute it using `pallas_call` with an additional `grid` argument.\n",
"On GPUs, we can call the kernel directly like so:"
]
},
{
@ -260,6 +265,7 @@
}
],
"source": [
"# GPU version\n",
"def iota(size: int):\n",
" return pl.pallas_call(iota_kernel,\n",
" out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
@ -272,16 +278,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"On GPUs, each program is executed in parallel on separate threads.\n",
"Thus, we need to think about race conditions on writes to HBM.\n",
"A reasonable approach is to write our kernels in such a way that different\n",
"programs write to disjoint places in HBM to avoid these parallel writes.\n",
"On the other hand, parallelizing the computation is how we can execute\n",
"operations like matrix multiplications really quickly.\n",
"\n",
"On TPUs, programs are executed in a combination of parallel and sequential\n",
"(depending on the architecture) so there are slightly different considerations.\n",
"\n",
"TPUs distinguish between vector and scalar memory spaces and in this case the\n",
"output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
"a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n",
"To call the above kernel on TPU, run:"
]
},
@ -292,6 +291,7 @@
"metadata": {},
"outputs": [],
"source": [
"# TPU version\n",
"from jax.experimental.pallas import tpu as pltpu\n",
"\n",
"def iota(size: int):\n",
@ -307,11 +307,22 @@
"id": "68f97b4e",
"metadata": {},
"source": [
"TPUs distinguish between vector and scalar memory spaces and in this case the\n",
"output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
"a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
"### Grid semantics\n",
"\n",
"You can read more details at {ref}`pallas_grid`."
"On GPUs, each program is executed in parallel on separate threads.\n",
"Thus, we need to think about race conditions on writes to HBM.\n",
"A reasonable approach is to write our kernels in such a way that different\n",
"programs write to disjoint locations in HBM to avoid these parallel writes.\n",
"On the other hand, parallelizing the computation is how we can execute\n",
"operations like matrix multiplications really quickly.\n",
"\n",
"In contrast, TPUs operate like a very wide SIMD machine.\n",
"Some TPU models contain multiple cores, but in many cases a TPU can be\n",
"treated as a single-threaded processor. The grid on a TPU can be\n",
"specified in a combination of parallel and sequential dimensions, where sequential\n",
"dimensions are guaranteed to run serially.\n",
"\n",
"You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`."
]
},
{

View File

@ -53,8 +53,9 @@ def add_vectors_kernel(x_ref, y_ref, o_ref):
Let's dissect this function a bit. Unlike most JAX functions you've probably written,
it does not take in `jax.Array`s as inputs and doesn't return any values.
Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
but we are given an `o_ref`, which corresponds to the desired output.
Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.
Note that we also don't have any outputs but we are given an `o_ref`, which corresponds
to the desired output.
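
To make this concrete, here is the kernel from above once more, with the read and
the write annotated (comments added here for illustration):

```python
def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]  # read the full contents of the input Refs
  o_ref[...] = x + y             # write the result into the output Ref
```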
**Reading from `Ref`s**
@ -101,7 +102,8 @@ thereof).
**What's actually happening here?**
Thus far we've described how to think about Pallas kernels but what we've actually
accomplished is we're writing a function that's executed very close to the compute units.
accomplished is writing a function that's executed very close to the compute units,
since values are loaded into the innermost (fastest) portion of the memory hierarchy.
On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when
we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)
@ -134,6 +136,8 @@ Part of writing Pallas kernels is thinking about how to take big arrays that
live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations
that operate on "blocks" of those arrays that can fit in SRAM.
(grids_by_example)=
### Grids by example
To automatically "carve" up the inputs and outputs, you provide a `grid` and
@ -169,8 +173,10 @@ def iota_kernel(o_ref):
```
We now execute it using `pallas_call` with an additional `grid` argument.
On GPUs, we can call the kernel directly like so:
```{code-cell} ipython3
# GPU version
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
@ -178,19 +184,13 @@ def iota(size: int):
iota(8)
```
On GPUs, each program is executed in parallel on separate threads.
Thus, we need to think about race conditions on writes to HBM.
A reasonable approach is to write our kernels in such a way that different
programs write to disjoint places in HBM to avoid these parallel writes.
On the other hand, parallelizing the computation is how we can execute
operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.
TPUs distinguish between vector and scalar memory spaces and in this case the
output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.
To call the above kernel on TPU, run:
```{code-cell} ipython3
# TPU version
from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
@ -201,11 +201,22 @@ def iota(size: int):
iota(8)
```
TPUs distinguish between vector and scalar memory spaces and in this case the
output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
a scalar. For more details read {ref}`pallas_tpu_pipelining`.
### Grid semantics
You can read more details at {ref}`pallas_grid`.
On GPUs, each program is executed in parallel on separate threads.
Thus, we need to think about race conditions on writes to HBM.
A reasonable approach is to write our kernels in such a way that different
programs write to disjoint locations in HBM to avoid these parallel writes.
On the other hand, parallelizing the computation is how we can execute
operations like matrix multiplications really quickly.
In contrast, TPUs operate like a very wide SIMD machine.
Some TPU models contain multiple cores, but in many cases a TPU can be
treated as a single-threaded processor. The grid on a TPU can be
specified as a combination of parallel and sequential dimensions, where sequential
dimensions are guaranteed to run serially.
You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`.
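
As a rough sketch of how those dimension semantics can be expressed (assuming the
`TPUCompilerParams` helper in `jax.experimental.pallas.tpu`; the exact spelling may
differ across JAX versions, and `copy_kernel` is just a placeholder), the grid
dimensions are annotated when constructing the `pallas_call`:

```python
# Sketch only: assumes pltpu.TPUCompilerParams exposes `dimension_semantics`.
# "parallel" dimensions may be partitioned across TPU cores, while
# "arbitrary" dimensions are guaranteed to execute sequentially.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

copy = pl.pallas_call(
    copy_kernel,
    grid=(4, 4),  # 4 x 4 blocks of shape (128, 128)
    in_specs=[pl.BlockSpec((128, 128), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32),
    compiler_params=pltpu.TPUCompilerParams(
        dimension_semantics=("parallel", "arbitrary"),
    ),
)
```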
+++

View File

@ -48,12 +48,20 @@
},
{
"cell_type": "markdown",
"id": "0e212a5e",
"metadata": {
"id": "TWKESTKAlyjT"
},
"source": [
"## TPU and its memory spaces\n",
"(tpu_and_its_memory_spaces)=\n",
"\n",
"## TPU and its memory spaces"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n",
"registers (which temporarily store scalar and array values) and compute units\n",
"(that do computation with values in registers).\n",

View File

@ -38,8 +38,12 @@ import numpy as np
+++ {"id": "TWKESTKAlyjT"}
(tpu_and_its_memory_spaces)=
## TPU and its memory spaces
+++
A TPU and its TensorCore consist of memory spaces (where arrays can reside),
registers (which temporarily store scalar and array values) and compute units
(that do computation with values in registers).
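
As a preview of how these memory spaces surface in Pallas (a sketch only, assuming
the `TPUMemorySpace` enum described later in this guide), a kernel operand can be
pinned to a particular space through its `BlockSpec`:

```python
# Sketch: ANY lets the compiler choose the placement, while VMEM and SMEM pin
# the Ref to vector memory or scalar memory respectively.
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

vector_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)  # array data
scalar_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)  # scalar data
any_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)      # compiler's choice
```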