Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
[Pallas] Pallas documentation cleanup
This commit is contained in:
parent 1a3c9c44dc, commit e05afefc97
@@ -94,7 +94,7 @@ Pallas kernels via JAX transformations.
<center>

Visualization of Pallas lowering paths

</center>
@@ -413,10 +413,10 @@ verify the correctness of the Triton and Mosaic compilers.
One could also imagine perturbing the `scan` ordering to simulate the
parallel reads and writes that happen on GPU.

-### Examples
+### GPU Examples

-Note all the following examples are for GPU only. They will require some small
-changes to work on TPUs.
+Note all the following examples are for GPU only. They will require tweaks to
+the block sizes to work on TPUs.

#### `add`
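For readers following along, here is a minimal, self-contained sketch of the kind of GPU kernel the `add` example covers, run in interpreter mode as discussed above. The kernel body, array shapes, and names are illustrative rather than the exact example from the design document.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # Each Ref holds a block of the corresponding array in fast memory;
  # reading with [...] pulls the whole block out, and assigning writes it back.
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)

# interpret=True emulates the kernel with JAX primitives instead of lowering
# to Triton/Mosaic, which is useful for debugging or running on CPU.
out = pl.pallas_call(add_kernel,
                     out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                     interpret=True)(x, y)
```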
docs/pallas/design/index.rst (new file, 9 lines)
@@ -0,0 +1,9 @@
+Pallas Design Notes
+===================
+
+.. toctree::
+   :caption: Design
+   :maxdepth: 2
+
+   design
+   async_note
@@ -44,39 +44,7 @@ For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
You can also use {func}`jax.experimental.pallas.num_programs` to get the
grid size for a given axis.

-Here's an example kernel that uses a `grid` and `program_id`.
-
-```python
->>> import jax
->>> from jax.experimental import pallas as pl
-
->>> def iota_kernel(o_ref):
-...   i = pl.program_id(0)
-...   o_ref[i] = i
-
-```
-
-We now execute it using `pallas_call` with an additional `grid` argument.
-
-```python
->>> def iota(size: int):
-...   return pl.pallas_call(iota_kernel,
-...                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
-...                         grid=(size,), interpret=True)()
->>> iota(8)
-Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
-
-```
-
-On GPUs, each program is executed in parallel on separate thread blocks.
-Thus, we need to think about race conditions on writes to HBM.
-A reasonable approach is to write our kernels in such a way that different
-programs write to disjoint places in HBM to avoid these parallel writes.
-
-On TPUs, programs are executed in a combination of parallel and sequential
-(depending on the architecture) so there are slightly different considerations.
-
-See {ref}`pallas_tpu_noteworthy_properties`.
+See {ref}`grids_by_example` for a simple kernel that uses this API.

(pallas_blockspec)=
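As a companion to the `num_programs` mention above, the following small self-contained sketch uses both `program_id` and `num_programs`. It is modeled on the iota example removed in this hunk; the kernel name and the reversed-index computation are illustrative assumptions, not text from the docs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def reversed_iota_kernel(o_ref):
  # program_id(0) is this invocation's index along grid axis 0;
  # num_programs(0) is the total number of invocations along that axis.
  i = pl.program_id(0)
  n = pl.num_programs(0)
  o_ref[i] = n - 1 - i

def reversed_iota(size: int):
  return pl.pallas_call(reversed_iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,), interpret=True)()

reversed_iota(8)  # Array([7, 6, 5, 4, 3, 2, 1, 0], dtype=int32)
```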
@@ -131,6 +99,8 @@ shape `x_shape` are computed as in the function `slice_for_invocation`
below:

```python
+>>> import jax
+>>> from jax.experimental import pallas as pl
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
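For readers who want to run the helper shown above, here is one plausible completion. Everything past the parameters visible in the hunk (the `invocation_indices` argument, the bounds handling, and the `pl.BlockSpec` argument order in the usage line) is an assumption, not the verbatim continuation from the docs.

```python
import jax
from jax.experimental import pallas as pl

def slices_for_invocation(x_shape: tuple[int, ...],
                          x_spec: pl.BlockSpec,
                          grid: tuple[int, ...],
                          invocation_indices: tuple[int, ...]) -> tuple[slice, ...]:
  assert len(invocation_indices) == len(grid)
  # The BlockSpec's index_map turns grid indices into block indices.
  block_indices = x_spec.index_map(*invocation_indices)
  slices = []
  for size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
    start = block_idx * block_size
    slices.append(slice(start, min(start + block_size, size)))
  return tuple(slices)

# Illustrative use: a (256, 512) array carved into (128, 128) blocks on a (2, 4) grid.
spec = pl.BlockSpec((128, 128), lambda i, j: (i, j))
print(slices_for_invocation((256, 512), spec, (2, 4), (1, 2)))
# (slice(128, 256), slice(256, 384))
```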
@@ -22,7 +22,6 @@ See also the :class:`jax.experimental.pallas` module API documentation.
   :maxdepth: 2

   quickstart
-   design
   grid_blockspec

@@ -34,9 +33,9 @@ See also the :class:`jax.experimental.pallas` module API documentation.

.. toctree::
   :caption: Design Notes
-   :maxdepth: 1
+   :maxdepth: 2

-   async_note
+   design/index

.. toctree::
   :caption: Other
@@ -72,8 +72,9 @@
"\n",
"Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
"it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
-"Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
-"but we are given an `o_ref`, which corresponds to the desired output.\n",
+"Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.\n",
+"Note that we also don't have any outputs but we are given an `o_ref`, which corresponds\n",
+"to the desired output.\n",
"\n",
"**Reading from `Ref`s**\n",
"\n",
@@ -150,7 +151,8 @@
"**What's actually happening here?**\n",
"\n",
"Thus far we've described how to think about Pallas kernels but what we've actually\n",
-"accomplished is we're writing a function that's executed very close to the compute units.\n",
+"accomplished is we're writing a function that's executed very close to the compute units\n",
+"since values are loaded into the innermost (fastest) portion of the memory hierarchy.\n",
"\n",
"On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n",
"we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n",
@@ -195,6 +197,8 @@
"live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
"that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
"\n",
+"(grids_by_example)=\n",
+"\n",
"### Grids by example\n",
"\n",
"To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
@@ -240,7 +244,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-"We now execute it using `pallas_call` with an additional `grid` argument."
+"We now execute it using `pallas_call` with an additional `grid` argument.\n",
+"On GPUs, we can call the kernel directly like so:"
]
},
{
@@ -260,6 +265,7 @@
}
],
"source": [
+"# GPU version\n",
"def iota(size: int):\n",
" return pl.pallas_call(iota_kernel,\n",
" out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"On GPUs, each program is executed in parallel on separate threads.\n",
|
||||
"Thus, we need to think about race conditions on writes to HBM.\n",
|
||||
"A reasonable approach is to write our kernels in such a way that different\n",
|
||||
"programs write to disjoint places in HBM to avoid these parallel writes.\n",
|
||||
"On the other hand, parallelizing the computation is how we can execute\n",
|
||||
"operations like matrix multiplications really quickly.\n",
|
||||
"\n",
|
||||
"On TPUs, programs are executed in a combination of parallel and sequential\n",
|
||||
"(depending on the architecture) so there are slightly different considerations.\n",
|
||||
"\n",
|
||||
"TPUs distinguish between vector and scalar memory spaces and in this case the\n",
|
||||
"output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
|
||||
"a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n",
|
||||
"To call the above kernel on TPU, run:"
|
||||
]
|
||||
},
|
||||
@@ -292,6 +291,7 @@
"metadata": {},
"outputs": [],
"source": [
+"# TPU version\n",
"from jax.experimental.pallas import tpu as pltpu\n",
"\n",
"def iota(size: int):\n",
@@ -307,11 +307,22 @@
"id": "68f97b4e",
"metadata": {},
"source": [
-"TPUs distinguish between vector and scalar memory spaces and in this case the\n",
-"output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
-"a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
+"### Grid semantics\n",
"\n",
-"You can read more details at {ref}`pallas_grid`."
+"On GPUs, each program is executed in parallel on separate threads.\n",
+"Thus, we need to think about race conditions on writes to HBM.\n",
+"A reasonable approach is to write our kernels in such a way that different\n",
+"programs write to disjoint locations in HBM to avoid these parallel writes.\n",
+"On the other hand, parallelizing the computation is how we can execute\n",
+"operations like matrix multiplications really quickly.\n",
+"\n",
+"In contrast, TPUs operate like a very wide SIMD machine.\n",
+"Some TPU models contain multiple cores, but in many cases a TPU can be\n",
+"treated as a single-threaded processor. The grid on a TPU can be\n",
+"specified in a combination of parallel and sequential dimensions, where sequential\n",
+"dimensions are guaranteed to run serially.\n",
+"\n",
+"You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`."
]
},
{
@@ -53,8 +53,9 @@ def add_vectors_kernel(x_ref, y_ref, o_ref):

Let's dissect this function a bit. Unlike most JAX functions you've probably written,
it does not take in `jax.Array`s as inputs and doesn't return any values.
-Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
-but we are given an `o_ref`, which corresponds to the desired output.
+Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.
+Note that we also don't have any outputs but we are given an `o_ref`, which corresponds
+to the desired output.

**Reading from `Ref`s**

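To make the `Ref` reading and writing concrete, here is the kernel from this hunk's context together with the way the quickstart invokes it. Treat it as a sketch assembled from the surrounding document rather than the full notebook cell.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Reading with [...] copies the whole block out of each input Ref,
  # and assigning to o_ref[...] writes the result into the output Ref.
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

add_vectors(jnp.arange(8), jnp.arange(8))
```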
@@ -101,7 +102,8 @@ thereof).
**What's actually happening here?**

Thus far we've described how to think about Pallas kernels but what we've actually
-accomplished is we're writing a function that's executed very close to the compute units.
+accomplished is we're writing a function that's executed very close to the compute units
+since values are loaded into the innermost (fastest) portion of the memory hierarchy.

On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when
we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)
@@ -134,6 +136,8 @@ Part of writing Pallas kernels is thinking about how to take big arrays that
live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations
that operate on "blocks" of those arrays that can fit in SRAM.

+(grids_by_example)=
+
### Grids by example

To automatically "carve" up the inputs and outputs, you provide a `grid` and
@@ -169,8 +173,10 @@ def iota_kernel(o_ref):
```

We now execute it using `pallas_call` with an additional `grid` argument.
+On GPUs, we can call the kernel directly like so:

```{code-cell} ipython3
+# GPU version
def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
@@ -178,19 +184,13 @@ def iota(size: int):
iota(8)
```

-On GPUs, each program is executed in parallel on separate threads.
-Thus, we need to think about race conditions on writes to HBM.
-A reasonable approach is to write our kernels in such a way that different
-programs write to disjoint places in HBM to avoid these parallel writes.
-On the other hand, parallelizing the computation is how we can execute
-operations like matrix multiplications really quickly.
-
-On TPUs, programs are executed in a combination of parallel and sequential
-(depending on the architecture) so there are slightly different considerations.
-
+TPUs distinguish between vector and scalar memory spaces and in this case the
+output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
+a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.
To call the above kernel on TPU, run:

```{code-cell} ipython3
+# TPU version
from jax.experimental.pallas import tpu as pltpu

def iota(size: int):
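For completeness, here is a self-contained sketch of what the TPU variant above typically looks like once the call is filled in. The `out_specs` spelling and the use of `pltpu.TPUMemorySpace.SMEM` are assumptions about the `jax.experimental.pallas.tpu` API rather than lines quoted from the diff.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i

# TPU version: the output lives in scalar memory (SMEM) because each
# program writes a single scalar index.
def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()

iota(8)
```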
@@ -201,11 +201,22 @@ def iota(size: int):
iota(8)
```

-TPUs distinguish between vector and scalar memory spaces and in this case the
-output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
-a scalar. For more details read {ref}`pallas_tpu_pipelining`.
+### Grid semantics

-You can read more details at {ref}`pallas_grid`.
+On GPUs, each program is executed in parallel on separate threads.
+Thus, we need to think about race conditions on writes to HBM.
+A reasonable approach is to write our kernels in such a way that different
+programs write to disjoint locations in HBM to avoid these parallel writes.
+On the other hand, parallelizing the computation is how we can execute
+operations like matrix multiplications really quickly.
+
+In contrast, TPUs operate like a very wide SIMD machine.
+Some TPU models contain multiple cores, but in many cases a TPU can be
+treated as a single-threaded processor. The grid on a TPU can be
+specified in a combination of parallel and sequential dimensions, where sequential
+dimensions are guaranteed to run serially.
+
+You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`.

+++
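The "disjoint writes" guidance above can be illustrated with a short sketch in which each program owns exactly one block of the output, so parallel programs never touch the same HBM locations. The block size, grid, and `pl.BlockSpec` argument order here are illustrative assumptions rather than code from the document.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
  # Each program sees only its own block of x and of the output,
  # so writes from different programs land in disjoint regions of HBM.
  o_ref[...] = x_ref[...] * 2

def double(x: jax.Array, block_size: int = 128) -> jax.Array:
  grid = (x.shape[0] // block_size,)
  return pl.pallas_call(
      double_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=grid,
      in_specs=[pl.BlockSpec((block_size,), lambda i: (i,))],
      out_specs=pl.BlockSpec((block_size,), lambda i: (i,)),
  )(x)

double(jnp.arange(512, dtype=jnp.float32))
```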
@ -48,12 +48,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e212a5e",
|
||||
"metadata": {
|
||||
"id": "TWKESTKAlyjT"
|
||||
},
|
||||
"source": [
|
||||
"## TPU and its memory spaces\n",
|
||||
"(tpu_and_its_memory_spaces)=\n",
|
||||
"\n",
|
||||
"## TPU and its memory spaces"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n",
|
||||
"registers (which temporarily store scalar and array values) and compute units\n",
|
||||
"(that do computation with values in registers).\n",
|
||||
|
@ -38,8 +38,12 @@ import numpy as np
|
||||
|
||||
+++ {"id": "TWKESTKAlyjT"}
|
||||
|
||||
(tpu_and_its_memory_spaces)=
|
||||
|
||||
## TPU and its memory spaces
|
||||
|
||||
+++
|
||||
|
||||
A TPU and its TensorCore consist of memory spaces (where arrays can reside),
|
||||
registers (which temporarily store scalar and array values) and compute units
|
||||
(that do computation with values in registers).
|
||||
|