diff --git a/docs/pallas/async_note.md b/docs/pallas/design/async_note.md
similarity index 100%
rename from docs/pallas/async_note.md
rename to docs/pallas/design/async_note.md
diff --git a/docs/pallas/design.md b/docs/pallas/design/design.md
similarity index 99%
rename from docs/pallas/design.md
rename to docs/pallas/design/design.md
index f6fc8f592..17c7a6dbd 100644
--- a/docs/pallas/design.md
+++ b/docs/pallas/design/design.md
@@ -94,7 +94,7 @@ Pallas kernels via JAX transformations.
-
+
Visualization of Pallas lowering paths
@@ -413,10 +413,10 @@ verify the correctness of the Triton and Mosaic compilers.
One could also imagine perturbing the `scan` ordering to simulate the
parallel reads and writes that happen on GPU.
-### Examples
+### GPU Examples
-Note all the following examples are for GPU only. They will require some small
-changes to work on TPUs.
+Note that all the following examples are for GPU only. They will require tweaks to
+the block sizes to work on TPUs.
#### `add`
diff --git a/docs/pallas/design/index.rst b/docs/pallas/design/index.rst
new file mode 100644
index 000000000..d11a13d39
--- /dev/null
+++ b/docs/pallas/design/index.rst
@@ -0,0 +1,9 @@
+Pallas Design Notes
+===================
+
+.. toctree::
+ :caption: Design
+ :maxdepth: 2
+
+ design
+ async_note
diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md
index cde200528..c1b2c2b95 100644
--- a/docs/pallas/grid_blockspec.md
+++ b/docs/pallas/grid_blockspec.md
@@ -44,39 +44,7 @@ For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
You can also use {func}`jax.experimental.pallas.num_programs` to get the
grid size for a given axis.
-Here's an example kernel that uses a `grid` and `program_id`.
-
-```python
->>> import jax
->>> from jax.experimental import pallas as pl
-
->>> def iota_kernel(o_ref):
-... i = pl.program_id(0)
-... o_ref[i] = i
-
-```
-
-We now execute it using `pallas_call` with an additional `grid` argument.
-
-```python
->>> def iota(size: int):
-... return pl.pallas_call(iota_kernel,
-... out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
-... grid=(size,), interpret=True)()
->>> iota(8)
-Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
-
-```
-
-On GPUs, each program is executed in parallel on separate thread blocks.
-Thus, we need to think about race conditions on writes to HBM.
-A reasonable approach is to write our kernels in such a way that different
-programs write to disjoint places in HBM to avoid these parallel writes.
-
-On TPUs, programs are executed in a combination of parallel and sequential
-(depending on the architecture) so there are slightly different considerations.
-
-See {ref}`pallas_tpu_noteworthy_properties`.
+See {ref}`grids_by_example` for a simple kernel that uses a `grid` and `program_id`.
(pallas_blockspec)=
@@ -131,6 +99,8 @@ shape `x_shape` are computed as in the function `slice_for_invocation`
below:
```python
+>>> import jax
+>>> from jax.experimental import pallas as pl
>>> def slices_for_invocation(x_shape: tuple[int, ...],
... x_spec: pl.BlockSpec,
... grid: tuple[int, ...],
diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst
index 5969349c9..b2e2fca6c 100644
--- a/docs/pallas/index.rst
+++ b/docs/pallas/index.rst
@@ -22,7 +22,6 @@ See also the :class:`jax.experimental.pallas` module API documentation.
:maxdepth: 2
quickstart
- design
grid_blockspec
@@ -34,9 +33,9 @@ See also the :class:`jax.experimental.pallas` module API documentation.
.. toctree::
:caption: Design Notes
- :maxdepth: 1
+ :maxdepth: 2
- async_note
+ design/index
.. toctree::
:caption: Other
diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb
index af34d1674..11dd2108e 100644
--- a/docs/pallas/quickstart.ipynb
+++ b/docs/pallas/quickstart.ipynb
@@ -72,8 +72,9 @@
"\n",
"Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
"it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
- "Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
- "but we are given an `o_ref`, which corresponds to the desired output.\n",
+ "Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.\n",
+ "Note that we also don't have any outputs but we are given an `o_ref`, which corresponds\n",
+ "to the desired output.\n",
"\n",
"**Reading from `Ref`s**\n",
"\n",
@@ -150,7 +151,8 @@
"**What's actually happening here?**\n",
"\n",
"Thus far we've described how to think about Pallas kernels but what we've actually\n",
- "accomplished is we're writing a function that's executed very close to the compute units.\n",
+ "accomplished is we're writing a function that's executed very close to the compute units\n",
+ "since values are loaded into the innermost (fastest) portion of the memory hierarchy.\n",
"\n",
"On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n",
"we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n",
@@ -195,6 +197,8 @@
"live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
"that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
"\n",
+ "(grids_by_example)=\n",
+ "\n",
"### Grids by example\n",
"\n",
"To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
@@ -240,7 +244,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "We now execute it using `pallas_call` with an additional `grid` argument."
+ "We now execute it using `pallas_call` with an additional `grid` argument.\n",
+ "On GPUs, we can call the kernel directly like so:"
]
},
{
@@ -260,6 +265,7 @@
}
],
"source": [
+ "# GPU version\n",
"def iota(size: int):\n",
" return pl.pallas_call(iota_kernel,\n",
" out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
@@ -272,16 +278,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "On GPUs, each program is executed in parallel on separate threads.\n",
- "Thus, we need to think about race conditions on writes to HBM.\n",
- "A reasonable approach is to write our kernels in such a way that different\n",
- "programs write to disjoint places in HBM to avoid these parallel writes.\n",
- "On the other hand, parallelizing the computation is how we can execute\n",
- "operations like matrix multiplications really quickly.\n",
- "\n",
- "On TPUs, programs are executed in a combination of parallel and sequential\n",
- "(depending on the architecture) so there are slightly different considerations.\n",
- "\n",
+ "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
+ "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
+ "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n",
"To call the above kernel on TPU, run:"
]
},
@@ -292,6 +291,7 @@
"metadata": {},
"outputs": [],
"source": [
+ "# TPU version\n",
"from jax.experimental.pallas import tpu as pltpu\n",
"\n",
"def iota(size: int):\n",
@@ -307,11 +307,22 @@
"id": "68f97b4e",
"metadata": {},
"source": [
- "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
- "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
- "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
+ "### Grid semantics\n",
"\n",
- "You can read more details at {ref}`pallas_grid`."
+ "On GPUs, each program is executed in parallel on separate threads.\n",
+ "Thus, we need to think about race conditions on writes to HBM.\n",
+ "A reasonable approach is to write our kernels in such a way that different\n",
+ "programs write to disjoint locations in HBM to avoid these parallel writes.\n",
+ "On the other hand, parallelizing the computation is how we can execute\n",
+ "operations like matrix multiplications really quickly.\n",
+ "\n",
+ "In contrast, TPUs operate like a very wide SIMD machine.\n",
+ "Some TPU models contain multiple cores, but in many cases a TPU can be\n",
+ "treated as a single-threaded processor. The grid on a TPU can be\n",
+ "specified in a combination of parallel and sequential dimensions, where sequential\n",
+ "dimensions are guaranteed to run serially.\n",
+ "\n",
+ "You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`."
]
},
{
diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md
index e11868f5f..fff1dcb73 100644
--- a/docs/pallas/quickstart.md
+++ b/docs/pallas/quickstart.md
@@ -53,8 +53,9 @@ def add_vectors_kernel(x_ref, y_ref, o_ref):
Let's dissect this function a bit. Unlike most JAX functions you've probably written,
it does not take in `jax.Array`s as inputs and doesn't return any values.
-Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
-but we are given an `o_ref`, which corresponds to the desired output.
+Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.
+Note that there are also no explicit outputs; rather, we are given an `o_ref`, which
+corresponds to the desired output.
**Reading from `Ref`s**
@@ -101,7 +102,8 @@ thereof).
**What's actually happening here?**
Thus far we've described how to think about Pallas kernels but what we've actually
-accomplished is we're writing a function that's executed very close to the compute units.
+accomplished is writing a function that executes very close to the compute units,
+since values are loaded into the innermost (fastest) portion of the memory hierarchy.
On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when
we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)
@@ -134,6 +136,8 @@ Part of writing Pallas kernels is thinking about how to take big arrays that
live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations
that operate on "blocks" of those arrays that can fit in SRAM.
+(grids_by_example)=
+
### Grids by example
To automatically "carve" up the inputs and outputs, you provide a `grid` and
@@ -169,8 +173,10 @@ def iota_kernel(o_ref):
```
We now execute it using `pallas_call` with an additional `grid` argument.
+On GPUs, we can call the kernel directly like so:
```{code-cell} ipython3
+# GPU version
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
@@ -178,19 +184,13 @@ def iota(size: int):
iota(8)
```
-On GPUs, each program is executed in parallel on separate threads.
-Thus, we need to think about race conditions on writes to HBM.
-A reasonable approach is to write our kernels in such a way that different
-programs write to disjoint places in HBM to avoid these parallel writes.
-On the other hand, parallelizing the computation is how we can execute
-operations like matrix multiplications really quickly.
-
-On TPUs, programs are executed in a combination of parallel and sequential
-(depending on the architecture) so there are slightly different considerations.
-
+TPUs distinguish between vector and scalar memory spaces; in this case, the
+output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
+a scalar. For more details, see {ref}`tpu_and_its_memory_spaces`.
To call the above kernel on TPU, run:
```{code-cell} ipython3
+# TPU version
from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
@@ -201,11 +201,22 @@ def iota(size: int):
iota(8)
```
-TPUs distinguish between vector and scalar memory spaces and in this case the
-output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
-a scalar. For more details read {ref}`pallas_tpu_pipelining`.
+### Grid semantics
-You can read more details at {ref}`pallas_grid`.
+On GPUs, each program is executed in parallel on separate thread blocks.
+Thus, we need to think about race conditions on writes to HBM.
+A reasonable approach is to write our kernels in such a way that different
+programs write to disjoint locations in HBM, so that parallel writes never conflict.
+On the other hand, parallelizing the computation is how we can execute
+operations like matrix multiplications really quickly.
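+
+For example, a minimal sketch of this pattern, in the same style as the GPU
+`iota` kernel above, could give each program its own slice of the output so
+that no two programs ever write to the same locations (the 128-element block
+size and the `double_block_kernel`/`double` names below are only illustrative):
+
+```python
+def double_block_kernel(x_ref, o_ref):
+    i = pl.program_id(0)
+    sl = pl.ds(i * 128, 128)   # this program's disjoint slice of the arrays
+    o_ref[sl] = x_ref[sl] * 2  # writes never overlap across programs
+
+def double(x):
+    # assumes x is 1-D with a length that is a multiple of 128
+    return pl.pallas_call(double_block_kernel,
+                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+                          grid=(x.shape[0] // 128,))(x)
+```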
+
+In contrast, TPUs operate like a very wide SIMD machine.
+Some TPU models contain multiple cores, but in many cases a TPU can be
+treated as a single-threaded processor. The grid on a TPU can be
+specified as a combination of parallel and sequential dimensions, where sequential
+dimensions are guaranteed to run serially.
+
+You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`.
+++
diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb
index 9774e08dc..10de58710 100644
--- a/docs/pallas/tpu/pipelining.ipynb
+++ b/docs/pallas/tpu/pipelining.ipynb
@@ -48,12 +48,20 @@
},
{
"cell_type": "markdown",
+ "id": "0e212a5e",
"metadata": {
"id": "TWKESTKAlyjT"
},
"source": [
- "## TPU and its memory spaces\n",
+ "(tpu_and_its_memory_spaces)=\n",
"\n",
+ "## TPU and its memory spaces"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
"A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n",
"registers (which temporarily store scalar and array values) and compute units\n",
"(that do computation with values in registers).\n",
diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md
index 218654301..df570cf08 100644
--- a/docs/pallas/tpu/pipelining.md
+++ b/docs/pallas/tpu/pipelining.md
@@ -38,8 +38,12 @@ import numpy as np
+++ {"id": "TWKESTKAlyjT"}
+(tpu_and_its_memory_spaces)=
+
## TPU and its memory spaces
++++
+
A TPU and its TensorCore consist of memory spaces (where arrays can reside),
registers (which temporarily store scalar and array values) and compute units
(that do computation with values in registers).