[Pallas] Update TPU documentation

Justin Fu 2024-12-04 16:38:42 -08:00
parent 1a3c9c44dc
commit 2b2d7cda98
4 changed files with 52 additions and 19 deletions

[Image diff suppressed: new file ``_static/pallas/vector_layout_example.svg`` (26 KiB)]


@@ -119,24 +119,44 @@ The output reference can then be used as an accumulator for partial results.
spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a
low-level compiler error complaining about running out of memory.
Dimension ordering is meaningful
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Array Layouts
^^^^^^^^^^^^^
Dimension ordering of arrays is meaningful in Pallas.
In JAX programs, the ordering of intermediate arrays inside ``jax.jit`` usually
has no impact on performance, as the compiler is free to rearrange them.
However, as Pallas is meant to expose lower-level capabilities, the dimension
order can have a significant impact on the quality of the generated code.
Recall that TPUs perform the bulk of the computation on 2D vector registers.
Pallas TPU will only ever consider mapping the last two dimensions of
intermediate arrays to those vector register dimensions (sublanes and lanes
respectively). An array of shape ``(n, 1, 1)`` is guaranteed to require at least
``n`` vector registers to represent. If ``n`` becomes too large, this can lead
to spills, and potential VMEM OOM errors due to an overly large memory footprint.
But it also might not --- the low-level compiler is free to rearrange the
instructions to lower the register pressure, and is in fact very good at it.
Still, it is a good rule of thumb to keep the last two dimensions large
(especially the last dimension), while keeping the leading dimensions small.
TPUs perform the bulk of their computation on 2D vector registers, which are typically of
size 8x128 for 32-bit values (as of TPU v6).
When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``),
the last two dimensions of the array will be tiled into the registers.
Pallas will only ever consider mapping the last two dimensions of
intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes
respectively).
Here is a graphical example of how a 12x320 array can be tiled using six 8x128
tiles:
.. image:: ../../_static/pallas/vector_layout_example.svg
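To make the figure concrete: the tile count is just the product of per-axis
ceiling divisions over the last two dimensions, with any leading dimensions
multiplying the total. The helper below is a plain-Python sketch for
illustration, not part of the Pallas API:

.. code-block:: python

    import math

    def num_tiles(shape, tile=(8, 128)):
        """Number of (sublane, lane) tiles covering an array of this shape."""
        *leading, m, n = shape
        per_tile = math.ceil(m / tile[0]) * math.ceil(n / tile[1])
        # Every leading dimension multiplies the tile count.
        return math.prod(leading) * per_tile

    assert num_tiles((12, 320)) == 6          # ceil(12/8) * ceil(320/128) = 2 * 3
    assert num_tiles((8, 128, 1, 1)) == 1024  # each trailing 1x1 pads to 8x128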
Tiled layouts have several important ramifications for kernel writers:
* The last two axes of an array are treated differently than other
axes. For example, reductions, reshapes, and transposes are generally
more expensive when they involve the last two axes. Some reshapes
involving the last two dimensions are not supported and will result in a compiler
error, whereas reshapes that only touch other dimensions are "free" and
performed at compile time.
* While sometimes unavoidable, it is generally wasteful to have singleton
  dimensions in the last two axes, since a singleton occupies only one element
  of the full tile dimension while the rest is padding. Using too many
  registers can also cause spills into VMEM, which degrades kernel
  performance.
* Related to the above point, all vector computation is padded up to the tile
  size. Adding two 1x1 arrays costs as much as adding two 8x128 arrays, and
  adding two 8x128x1x1 arrays is 1024 times as expensive as adding two
  8x128 arrays, since the 8x128x1x1 array is padded out to 8x128x8x128
  (see the layout sketch after this list).
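Putting these points together, a minimal sketch of the rule of thumb (shapes
and names here are illustrative): keep the large axis in the last dimension so
that it maps onto the 128-wide lanes.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_kernel(x_ref, y_ref, o_ref):
        # Element-wise add; the last two dims of each ref map to sublanes/lanes.
        o_ref[...] = x_ref[...] + y_ref[...]

    # (1, 1, 1024) keeps the large axis on the lane dimension (8 tiles of 8x128);
    # the same data laid out as (1024, 1, 1) would occupy 1024 mostly-padding tiles.
    x = jnp.ones((1, 1, 1024), jnp.float32)
    out = pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, x)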
Multicore TPU configurations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -196,18 +216,19 @@ for those arguments. But, the ``BlockSpec``\s for all subsequent arguments will
receive not only the grid indices, but also the SMEM references to the leading
operands.
.. note::
We are working on implementing examples for this feature. Stay tuned!
See :ref:`pallas_scalar_prefetch_guide` for examples of using this
feature.
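For orientation, here is a skeletal sketch of the wiring described above
(``gather_blocks`` and the block size are illustrative; see the linked guide
for complete, tested examples). The prefetched operand arrives first in the
kernel, and every index map receives the SMEM ref after the grid indices:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def copy_block_kernel(idx_ref, x_ref, o_ref):
        # idx_ref is the prefetched SMEM operand; the block of x selected by
        # the index map below is already in VMEM when the body runs.
        o_ref[...] = x_ref[...]

    def gather_blocks(x, idx, blk=8):
        n_out = idx.shape[0]
        grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,  # the first argument (idx) goes to SMEM
            grid=(n_out,),
            in_specs=[
                # Index maps receive the grid indices *and* the SMEM refs.
                pl.BlockSpec((blk, x.shape[1]), lambda i, idx_ref: (idx_ref[i], 0)),
            ],
            out_specs=pl.BlockSpec((blk, x.shape[1]), lambda i, idx_ref: (i, 0)),
        )
        return pl.pallas_call(
            copy_block_kernel,
            grid_spec=grid_spec,
            out_shape=jax.ShapeDtypeStruct((n_out * blk, x.shape[1]), x.dtype),
        )(idx, x)

    x = jnp.arange(32 * 128, dtype=jnp.float32).reshape(32, 128)
    idx = jnp.array([3, 0, 2], dtype=jnp.int32)  # row-block indices to gather
    out = gather_blocks(x, idx)  # shape (24, 128)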
Supported data types
^^^^^^^^^^^^^^^^^^^^
At the moment Pallas TPU only supports the following data types:
At the moment Pallas TPU supports the following data types:
* ``jnp.float32``
* ``jnp.bfloat16``
* ``jnp.int*`` (all precisions, except for ``jnp.int4``)
* ``jnp.uint*`` (all precisions)
* ``jnp.bool_``
Computation placement
^^^^^^^^^^^^^^^^^^^^^
@@ -306,14 +327,13 @@ Array constructors
^^^^^^^^^^^^^^^^^^
All constant array constructors are supported (``jnp.ones``, ``jnp.zeros``,
``jnp.full``). Notably, the ``jax.random`` module is **not** compatible with
Pallas as of today.
``jnp.full``).
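For example, a trivial sketch of a kernel that materializes a constant output:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def fill_kernel(o_ref):
        # Constant constructors work as usual inside a kernel body.
        o_ref[...] = jnp.full(o_ref.shape, 42.0, dtype=jnp.float32)

    out = pl.pallas_call(
        fill_kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )()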
Reductions
^^^^^^^^^^
Sum, maximum and minimum reductions are supported, but only on a single array
axis at a time.
``sum``, ``max``, and ``min`` reductions are supported for floating-point values, as well
as ``any`` and ``all`` for boolean values. Integer reductions are not supported.
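A minimal sketch of a supported reduction (the axis choice also matters for
performance, as discussed next):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def sum_kernel(x_ref, o_ref):
        # Floating-point sum over the second-to-last axis; a sum over an
        # integer-typed array would not be supported.
        o_ref[...] = jnp.sum(x_ref[...], axis=0, keepdims=True)

    x = jnp.ones((64, 128), jnp.float32)
    out = pl.pallas_call(
        sum_kernel,
        out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
    )(x)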
Reductions over the last array dimension are generally the slowest.
Reductions over the second last dimension are faster, but still slower than
@@ -338,6 +358,14 @@ of an array is when (1) some leading dimensions are flattened onto the second
to last dimension, or (2) it adds a dimension that was just removed by a
reduction.
Random Number Generation
^^^^^^^^^^^^^^^^^^^^^^^^
Pallas supports the most commonly used functions from the ``jax.random`` module,
such as ``uniform``, ``normal``, and ``bernoulli``. The key should be a ``threefry2x32`` key,
which is the default in JAX. Keys can be passed directly into a kernel,
or generated inside a kernel.
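A minimal sketch in which the key is derived from a scalar seed inside the
kernel; passing the seed as a plain array operand is an illustrative
simplification (a real kernel might place it in SMEM instead):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def random_kernel(seed_ref, o_ref):
        # Derive a key inside the kernel; JAX's default is threefry2x32.
        key = jax.random.key(seed_ref[0])
        o_ref[...] = jax.random.uniform(key, o_ref.shape, dtype=jnp.float32)

    seed = jnp.array([1234], dtype=jnp.int32)
    out = pl.pallas_call(
        random_kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(seed)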
Control flow
^^^^^^^^^^^^


@@ -6,6 +6,8 @@
"id": "ZHuzXqQ-9JUQ"
},
"source": [
"(pallas_scalar_prefetch_guide)=\n",
"\n",
"# Scalar Prefetch and Block-Sparse Computation\n",
"\n",
"In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory."


@@ -14,6 +14,8 @@ kernelspec:
+++ {"id": "ZHuzXqQ-9JUQ"}
(pallas_scalar_prefetch_guide)=
# Scalar Prefetch and Block-Sparse Computation
In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.