From 2b2d7cda985f358d7b7e89ca9983ca0c0339bdc7 Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Wed, 4 Dec 2024 16:38:42 -0800
Subject: [PATCH] [Pallas] Update TPU documentation

---
 docs/_static/pallas/vector_layout_example.svg |  1 +
 docs/pallas/tpu/details.rst                   | 66 +++++++++++++------
 docs/pallas/tpu/sparse.ipynb                  |  2 +
 docs/pallas/tpu/sparse.md                     |  2 +
 4 files changed, 52 insertions(+), 19 deletions(-)
 create mode 100644 docs/_static/pallas/vector_layout_example.svg

diff --git a/docs/_static/pallas/vector_layout_example.svg b/docs/_static/pallas/vector_layout_example.svg
new file mode 100644
index 000000000..f1c940357
--- /dev/null
+++ b/docs/_static/pallas/vector_layout_example.svg
@@ -0,0 +1 @@
+
\ No newline at end of file

diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst
index b7ce10d56..0575806e6 100644
--- a/docs/pallas/tpu/details.rst
+++ b/docs/pallas/tpu/details.rst
@@ -119,24 +119,44 @@ The output reference can be then used as an accumulator for partial results.
 spilled vector registers) exceeds the size of VMEM. In this case, you will
 likely see a low-level compiler error message complaining about an
 out-of-memory error.
 
-Dimension ordering is meaningful
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Array layouts
+^^^^^^^^^^^^^
+Dimension ordering of arrays is meaningful in Pallas.
 In JAX programs, the ordering of intermediate arrays inside ``jax.jit`` usually
 has no impact on performance, as the compiler is free to rearrange them.
 However, as Pallas is meant to expose lower-level capabilities, the dimension
 order can have great impact on the quality of generated code.
 
-Recall that the TPUs perform bulk of the computation on 2D vector registers.
-Pallas TPU will only ever consider mapping the last two dimensions of
-intermediate arrays to those vector register dimensions (sublanes and lanes
-respectively). An array of shape ``(n, 1, 1)`` is guaranteed to require at least
-``n`` vector registers to represent. If ``n`` becomes too large, this can lead
-to spills, and potential VMEM OOM errors due to an overly large memory footprint.
-But it also might not --- the low-level compiler is free to rearrange the
-instructions to lower the register pressure, and is in fact very good at it.
-Still, it is a good rule of thumb to keep the last two dimensions large
-(especially the last dimension), while keeping the leading dimensions small.
+TPUs perform the bulk of the computation on 2D vector registers, which are
+typically of size 8x128 for 32-bit values (as of TPU v6).
+When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``),
+the last two dimensions of the array will be tiled into the registers.
+Pallas will only ever consider mapping the last two dimensions of
+intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes
+respectively).
+
+Here is a graphical example of how a 12x320 array can be tiled using 6 8x128
+tiles:
+
+.. image:: ../../_static/pallas/vector_layout_example.svg
+
+Tiled layouts have several important ramifications for kernel writers (see the
+sketch after this list):
+
+* The last two axes of an array are treated differently than other
+  axes. For example, reductions, reshapes, and transposes are generally
+  more expensive when involving the last two axes. Some reshapes
+  involving the last two dimensions are not supported and will result in a
+  compiler error, whereas reshapes over the other dimensions are "free" and
+  performed at compile time.
+* While sometimes unavoidable, it is generally wasteful to have singleton
+  dimensions in the last two axes, since they occupy only 1 element out of
+  the entire tile dimension. Consuming too many registers can
+  also potentially cause register spills into VMEM, which degrades kernel
+  performance.
+* Related to the above point, all vector computation is padded up to the tile
+  size. Adding two 1x1 arrays costs as much as adding two 8x128 arrays, and
+  adding two 8x128x1x1 arrays will be 1024 times as expensive as adding two
+  8x128 arrays, since the 8x128x1x1 array will be padded to 8x128x8x128.
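+
+As a rough illustration of the last two points (the kernel names and shapes
+below are chosen arbitrarily, and the exact register counts and costs are
+ultimately up to the compiler and the hardware generation), the sketch below
+performs the same elementwise addition twice: once on arrays with a wasteful
+trailing singleton dimension, and once on arrays shaped to match the 8x128
+tile:
+
+.. code-block:: python
+
+   import jax
+   import jax.numpy as jnp
+   from jax.experimental import pallas as pl
+
+   def add_flat_kernel(x_ref, y_ref, o_ref):
+     # x and y have shape (1024, 1): the last two dims tile into roughly 128
+     # vector registers, each using only an 8x1 slice of its 8x128 tile.
+     o_ref[...] = x_ref[...] + y_ref[...]
+
+   def add_tiled_kernel(x_ref, y_ref, o_ref):
+     # x and y have shape (8, 128): the same 1024 values fit in a single tile.
+     o_ref[...] = x_ref[...] + y_ref[...]
+
+   x = jnp.ones((1024, 1), jnp.float32)
+   y = jnp.ones((1024, 1), jnp.float32)
+   out_flat = pl.pallas_call(
+       add_flat_kernel,
+       out_shape=jax.ShapeDtypeStruct((1024, 1), jnp.float32))(x, y)
+   out_tiled = pl.pallas_call(
+       add_tiled_kernel,
+       out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32))(
+           x.reshape(8, 128), y.reshape(8, 128))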
 
 Multicore TPU configurations
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -196,18 +216,19 @@ for those arguments. But, the ``BlockSpec``\s for all
 subsequent arguments will
 receive not only the grid indices, but also the SMEM references to the leading
 operands.
 
-.. note::
-  We are working on implementing examples for this feature. Stay tuned!
+See :ref:`pallas_scalar_prefetch_guide` for examples of how to use this
+feature.
 
 Supported data types
 ^^^^^^^^^^^^^^^^^^^^
 
-At the moment Pallas TPU only supports the following data types:
+At the moment Pallas TPU supports the following data types:
 
 * ``jnp.float32``
 * ``jnp.bfloat16``
 * ``jnp.int*`` (all precisions, except for ``jnp.int4``)
 * ``jnp.uint*`` (all precisions)
+* ``jnp.bool_``
 
 Computation placement
 ^^^^^^^^^^^^^^^^^^^^^
@@ -306,14 +327,13 @@ Array constructors
 ^^^^^^^^^^^^^^^^^^
 
 All constant array constructors are supported (``jnp.ones``, ``jnp.zeros``,
-``jnp.full``). Notably, the ``jax.random`` module is **not** compatible with
-Pallas as of today.
+``jnp.full``).
 
 Reductions
 ^^^^^^^^^^
 
-Sum, maximum and minimum reductions are supported, but only on a single array
-axis at a time.
+``sum``, ``max``, and ``min`` reductions are supported for floating-point values,
+as well as ``any`` and ``all`` for boolean values. Integer reductions are not
+supported.
 
 Reductions over the last array dimension are generally the slowest.
 Reductions over the second last dimension are faster, but still slower than
@@ -338,6 +358,14 @@ of an array is when (1) some leading dimensions are flattened onto the second
 to last dimension, or (2) it adds a dimension that was just removed by a
 reduction.
 
+Random number generation
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Pallas supports the most commonly used functions from the ``jax.random`` module,
+such as ``uniform``, ``normal``, and ``bernoulli``. The key should be a
+``threefry2x32`` key, which is the default in JAX. Keys can be passed directly
+into a kernel, or generated inside a kernel.
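+
+For illustration, a minimal sketch of the "key passed into the kernel" case
+could look as follows. Treat this as a sketch rather than a definitive recipe,
+as the exact calling convention for key arrays may differ between JAX versions:
+
+.. code-block:: python
+
+   import jax
+   import jax.numpy as jnp
+   from jax.experimental import pallas as pl
+
+   def random_fill_kernel(key_ref, o_ref):
+     # Read the threefry key that was passed into the kernel and use it to
+     # draw one 8x128 tile of uniform random values.
+     key = key_ref[...]
+     o_ref[...] = jax.random.uniform(key, o_ref.shape, dtype=jnp.float32)
+
+   key = jax.random.key(0)  # threefry2x32 is the default implementation
+   out = pl.pallas_call(
+       random_fill_kernel,
+       out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+   )(key)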
+
 Control flow
 ^^^^^^^^^^^^
 
diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb
index a80ba4ebe..5b37e7b05 100644
--- a/docs/pallas/tpu/sparse.ipynb
+++ b/docs/pallas/tpu/sparse.ipynb
@@ -6,6 +6,8 @@
    "id": "ZHuzXqQ-9JUQ"
   },
   "source": [
+    "(pallas_scalar_prefetch_guide)=\n",
+    "\n",
     "# Scalar Prefetch and Block-Sparse Computation\n",
     "\n",
     "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory."
diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md
index 2ac25edb5..36a6e07e9 100644
--- a/docs/pallas/tpu/sparse.md
+++ b/docs/pallas/tpu/sparse.md
@@ -14,6 +14,8 @@ kernelspec:
 
 +++ {"id": "ZHuzXqQ-9JUQ"}
 
+(pallas_scalar_prefetch_guide)=
+
 # Scalar Prefetch and Block-Sparse Computation
 
 In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.
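+
+As a rough preview of where we are headed, a kernel that uses scalar prefetch
+to gather dynamically chosen blocks of an array might look something like the
+sketch below. The helper names and shapes here are made up for illustration,
+and the rest of this tutorial develops the real machinery step by step:
+
+```python
+import jax
+import jax.numpy as jnp
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+def gather_blocks_kernel(block_indices, x_ref, o_ref):
+  # block_indices was prefetched into SMEM and has already been consumed by
+  # the index_map below to choose which block of x to load; just copy it out.
+  del block_indices
+  o_ref[...] = x_ref[...]
+
+def gather_blocks(x, block_indices, blk=256):
+  num_out_blocks = block_indices.shape[0]
+  grid_spec = pltpu.PrefetchScalarGridSpec(
+      num_scalar_prefetch=1,  # block_indices is prefetched into SMEM
+      grid=(num_out_blocks,),
+      in_specs=[
+          # The index_map receives the grid index *and* the prefetched SMEM
+          # ref, so which block of x is loaded can depend on runtime data.
+          pl.BlockSpec((blk, x.shape[1]), lambda i, idx: (idx[i], 0)),
+      ],
+      out_specs=pl.BlockSpec((blk, x.shape[1]), lambda i, idx: (i, 0)),
+  )
+  return pl.pallas_call(
+      gather_blocks_kernel,
+      grid_spec=grid_spec,
+      out_shape=jax.ShapeDtypeStruct(
+          (num_out_blocks * blk, x.shape[1]), x.dtype),
+  )(block_indices, x)
+
+x = jnp.arange(8 * 256 * 128, dtype=jnp.float32).reshape(8 * 256, 128)
+out = gather_blocks(x, jnp.array([3, 0, 5], dtype=jnp.int32))
+```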