mirror of https://github.com/ROCm/jax.git
[Pallas] Update TPU documentation
This commit is contained in: parent 1a3c9c44dc, commit 2b2d7cda98
docs/_static/pallas/vector_layout_example.svg (vendored, new file, 26 KiB; diff suppressed because one or more lines are too long)
@ -119,24 +119,44 @@ The output reference can be then used as an accumulator for partial results.
spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a
low-level compiler error message complaining about an out-of-memory error.

Dimension ordering is meaningful
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Array Layouts
^^^^^^^^^^^^^

Dimension ordering of arrays is meaningful in Pallas.
In JAX programs, the ordering of intermediate arrays inside ``jax.jit`` usually
has no impact on performance, as the compiler is free to rearrange them.
However, as Pallas is meant to expose lower-level capabilities, the dimension
order can have a great impact on the quality of generated code.

Recall that TPUs perform the bulk of the computation on 2D vector registers.
Pallas TPU will only ever consider mapping the last two dimensions of
intermediate arrays to those vector register dimensions (sublanes and lanes
respectively). An array of shape ``(n, 1, 1)`` is guaranteed to require at least
``n`` vector registers to represent. If ``n`` becomes too large, this can lead
to spills, and potential VMEM OOM errors due to an overly large memory footprint.
But it also might not --- the low-level compiler is free to rearrange the
instructions to lower the register pressure, and is in fact very good at it.
Still, it is a good rule of thumb to keep the last two dimensions large
(especially the last dimension), while keeping the leading dimensions small.

TPUs perform the bulk of the computation on 2D vector registers, which are typically of
size 8x128 for 32-bit values (as of TPU v6).
When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``),
the last two dimensions of the array will be tiled into the registers.
Pallas will only ever consider mapping the last two dimensions of
intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes
respectively).
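
As a rough illustration (a minimal sketch; the shape and kernel are made up for
this example), the trailing ``(8, 128)`` dimensions of the block below map onto
register sublanes and lanes, while the small leading dimension just selects a
register:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def scale_kernel(x_ref, o_ref):
        x = x_ref[...]        # shape (4, 8, 128): the last two dims fill 8x128 registers
        o_ref[...] = x * 2.0  # the leading dim of 4 means 4 such registers in total

    x = jnp.arange(4 * 8 * 128, dtype=jnp.float32).reshape(4, 8, 128)
    y = pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)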

Here is a graphical example of how a 12x320 array can be tiled using 6 8x128
tiles:

.. image:: ../../_static/pallas/vector_layout_example.svg

Tiled layouts have several important ramifications for kernel writers:

* The last two axes of an array are treated differently than other
  axes. For example, reductions, reshapes, and transposes are generally
  more expensive when they involve the last two axes. Some reshapes
  involving the last two dimensions are not supported and will result in a compiler
  error, whereas reshapes involving only the other dimensions are "free" and
  performed at compile time.
* While sometimes unavoidable, it is generally wasteful to have singleton
  dimensions in the last two axes, since they will occupy 1 element out of
  the entire tile dimension. Consuming too many registers can
  also potentially cause register spills into VMEM, which degrades kernel
  performance.
* Related to the above point, all vector computation is padded up to the tile
  size. Adding two 1x1 arrays costs as much as adding two 8x128 arrays, and
  adding two 8x128x1x1 arrays will be 1024 times as expensive as adding two
  8x128 arrays, since the 8x128x1x1 array will be padded to 8x128x8x128 (see
  the sketch below).
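
To make the padding cost concrete, here is a back-of-the-envelope sketch (the
shapes are chosen purely for illustration; actual costs depend on what the
low-level compiler does):

.. code-block:: python

    import jax.numpy as jnp

    # Inside a Pallas TPU kernel, the last two dimensions are padded to 8x128 tiles.
    a = jnp.ones((1, 1), jnp.float32)          # padded to one full 8x128 tile
    b = jnp.ones((8, 128), jnp.float32)        # exactly one 8x128 tile
    c = jnp.ones((8, 128, 1, 1), jnp.float32)  # padded to 8x128x8x128, i.e. 1024 tiles

    # a + a therefore costs roughly as much as b + b, while c + c is roughly
    # 1024 times as expensive, even though b and c hold the same 1024 elements.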

Multicore TPU configurations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@ -196,18 +216,19 @@ for those arguments. But, the ``BlockSpec``\s for all subsequent arguments will
receive not only the grid indices, but also the SMEM references to the leading
operands.

.. note::
  We are working on implementing examples for this feature. Stay tuned!
  See :ref:`pallas_scalar_prefetch_guide` for examples on using this
  feature.
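
As a rough sketch of the shape of such an ``index_map`` (the names here are
hypothetical; the linked guide has complete, working kernels):

.. code-block:: python

    # With one scalar-prefetch operand, the index_map of each later operand
    # receives the grid indices followed by an SMEM reference to that operand.
    def index_map(i, block_idx_ref):
        # Dynamically choose which block of the input to fetch at grid step i.
        return block_idx_ref[i], 0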

Supported data types
^^^^^^^^^^^^^^^^^^^^

At the moment Pallas TPU only supports the following data types:
At the moment Pallas TPU supports the following data types:

* ``jnp.float32``
* ``jnp.bfloat16``
* ``jnp.int*`` (all precisions, except for ``jnp.int4``)
* ``jnp.uint*`` (all precisions)
* ``jnp.bool_``

Computation placement
^^^^^^^^^^^^^^^^^^^^^

@ -306,14 +327,13 @@ Array constructors
^^^^^^^^^^^^^^^^^^

All constant array constructors are supported (``jnp.ones``, ``jnp.zeros``,
``jnp.full``). Notably, the ``jax.random`` module is **not** compatible with
Pallas as of today.
``jnp.full``).
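
For example (a minimal sketch; the fill value and output shape are arbitrary), a
kernel body can materialize a constant block directly:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def fill_kernel(o_ref):
        # Constant constructors trace inside the kernel like any other jnp call.
        o_ref[...] = jnp.full(o_ref.shape, 42.0, dtype=jnp.float32)

    out = pl.pallas_call(
        fill_kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )()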
Reductions
^^^^^^^^^^

Sum, maximum and minimum reductions are supported, but only on a single array
axis at a time.
``sum``, ``max``, ``min`` (for floating point values) reductions are supported, as well
as ``any`` and ``all`` for boolean values. Integer reductions are not supported.
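
A minimal sketch of a single-axis reduction (the shapes are illustrative):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def col_sum_kernel(x_ref, o_ref):
        x = x_ref[...]                                  # shape (8, 128)
        # Reduce one axis at a time; reducing the second-to-last axis is
        # generally cheaper than reducing the last (lane) axis.
        o_ref[...] = jnp.sum(x, axis=0, keepdims=True)  # shape (1, 128)

    x = jnp.ones((8, 128), jnp.float32)
    col_sums = pl.pallas_call(
        col_sum_kernel,
        out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
    )(x)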
Reductions over the last array dimension are generally the slowest.
Reductions over the second last dimension are faster, but still slower than
@ -338,6 +358,14 @@ of an array is when (1) some leading dimensions are flattened onto the second
to last dimension, or (2) it adds a dimension that was just removed by a
reduction.

Random Number Generation
^^^^^^^^^^^^^^^^^^^^^^^^

Pallas supports the most commonly used functions from the ``jax.random`` module,
such as ``uniform``, ``normal``, and ``bernoulli``. The key should be a ``threefry2x32`` key,
which is the default setting in JAX. Keys can be directly passed into a kernel,
or generated inside a kernel.
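
A minimal, illustrative sketch of generating random values with a key created
inside the kernel (the seed and output shape are arbitrary; the default
``threefry2x32`` key implementation is assumed):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def random_kernel(o_ref):
        key = jax.random.key(0)  # JAX's default threefry2x32 key, made inside the kernel
        o_ref[...] = jax.random.uniform(key, shape=o_ref.shape, dtype=jnp.float32)

    out = pl.pallas_call(
        random_kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )()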
Control flow
^^^^^^^^^^^^

@ -6,6 +6,8 @@
"id": "ZHuzXqQ-9JUQ"
},
"source": [
"(pallas_scalar_prefetch_guide)=\n",
"\n",
"# Scalar Prefetch and Block-Sparse Computation\n",
"\n",
"In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory."
@ -14,6 +14,8 @@ kernelspec:

+++ {"id": "ZHuzXqQ-9JUQ"}

(pallas_scalar_prefetch_guide)=

# Scalar Prefetch and Block-Sparse Computation

In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.