{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "teoJ_fUwlu0l"
|
|
},
|
|
"source": [
|
|
"# Pipelining and `BlockSpec`s"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "gAJDZh1gBh-h"
|
|
},
|
|
"source": [
"In this guide we'll cover how memory spaces on TPUs work and how to write pipelines in Pallas that overlap memory I/O with compute."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "ejAVO6ikUUuF"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"#@title Imports\n",
|
|
"\n",
|
|
"import jax\n",
|
|
"from jax.experimental import pallas as pl\n",
|
|
"import jax.numpy as jnp\n",
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "TWKESTKAlyjT"
|
|
},
|
|
"source": [
|
|
"## TPU and its memory spaces\n",
|
|
"\n",
|
|
"A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). Below is a diagram of a TPU in which `x` and `y` are arrays that live in high-bandwidth memory (HBM):\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"Let's talk about the components of this diagram in more detail:\n",
|
|
"\n",
|
|
"* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we often think of as \"device memory\". There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values.\n",
"* **Registers**: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into registers from their respective caches (VMEM for VREGs and SMEM for SREGs).\n",
|
|
"* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well.\n",
|
|
"\n",
|
|
"In order to do a vectorized computation on our values `x` and `y` that live in HBM, we need to:\n",
|
|
"\n",
|
|
"1. Copy the values `x` and `y` into VMEM.\n",
|
|
"2. Load the values from VMEM into VREGs.\n",
|
|
"3. Execute the computation using the VPU or MXU, storing the output in VREGs.\n",
|
|
"4. Store the values in the output VREGs into VMEM.\n",
|
|
"5. Copy the output values in VMEM back to HBM."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "TzctMbNsn3vc"
|
|
},
|
|
"source": [
|
|
"Let's implement a Pallas function that does just that!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "2IXQxNWrKJyb",
|
|
"outputId": "d62eb493-5f92-4496-f113-d3cd24cb0b9f"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([[2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" ...,\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n",
|
|
" # Load x and y from VMEM into VREGs\n",
|
|
" x_vregs = x_vmem_ref[:, :]\n",
|
|
" y_vregs = y_vmem_ref[:, :]\n",
|
|
" # Execute a vectorized add\n",
|
|
" z_vregs = x_vregs + y_vregs\n",
|
|
" # Store the output values in VREGs back into VMEM\n",
|
|
" z_vmem_ref[:, :] = z_vregs\n",
|
|
"\n",
|
|
"\n",
|
|
"def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n",
|
|
" # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.\n",
|
|
" # It will then copy `x` and `y` from HBM into VMEM.\n",
|
|
" z = pl.pallas_call(\n",
|
|
" add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
|
|
" )(x, y)\n",
|
|
" # pallas_call will also copy the output from VMEM back into HBM.\n",
|
|
" return z\n",
|
|
"\n",
|
|
"\n",
|
|
"x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
|
|
"add_matrices(x, y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "HMENNLy8okCL"
|
|
},
|
|
"source": [
|
|
"We've written two functions: `add_matrices_kernel` and `add_matrices`.\n",
|
|
"\n",
"`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`.\n",
|
|
"\n",
|
|
"The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output VMEM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "5kWr-1tKpYro"
|
|
},
|
|
"source": [
|
|
"## Constraints of using VMEM/SMEM\n",
|
|
"\n",
"Pallas exposes access to lower-level memory spaces like VMEM and SMEM, but writing kernels that utilize them comes with some considerations.\n",
|
|
"\n",
|
|
"1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won't even be able to fit them into VMEM at all. For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n",
|
|
"\n",
|
|
"2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself.\n",
|
|
"\n",
|
|
"With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our TPUs."
|
|
]
|
|
},
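{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the capacity constraint concrete, here is the back-of-the-envelope arithmetic behind the `f32[2048, 2048]` figure above (plain Python, not a Pallas API):\n",
"\n",
"```python\n",
"# A float32 element is 4 bytes, so a (2048, 2048) array occupies\n",
"# 2048 * 2048 * 4 bytes = 16 MiB, i.e. all of a v4 TPU's VMEM.\n",
"print(2048 * 2048 * 4 / 2**20)  # 16.0 (MiB)\n",
"```"
]
},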
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "_NTqvlbetB3P"
|
|
},
|
|
"source": [
|
|
"## Primer: Pipelining\n",
|
|
"\n",
|
|
"Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining?\n",
|
|
"\n",
"The goal is to copy to/from HBM and VMEM *in parallel* with using our compute units. Naively this is difficult because in our program above we copy *all* of `x` and `y` before we start doing any compute with them, creating a dependence between the copy and the compute.\n",
|
|
"\n",
|
|
"However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of \"blocks\" of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let's walk through a simple example:\n",
|
|
"\n",
"Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for example, split along the leading axis, resulting in two `(256, 512)` arrays for each input). We can now execute the following pipelined computation:\n",
|
|
"\n",
"1. Copy `x1` and `y1` into VMEM.\n",
"2. Start copying `x2` and `y2` into VMEM.\n",
"3. Load `x1, y1` from VMEM into VREGs.\n",
"4. Execute `z1 = x1 + y1` using the compute units.\n",
"5. Store `z1` into VMEM.\n",
"6. Start copying `z1` from VMEM back into HBM.\n",
"7. Wait until `x2, y2` have been copied into VMEM.\n",
"8. Load `x2, y2` from VMEM into VREGs.\n",
"9. Execute `z2 = x2 + y2` using the compute units.\n",
"10. Store `z2` into VMEM.\n",
"11. Wait until `z1` is copied into HBM.\n",
"12. Start copying `z2` from VMEM back into HBM.\n",
"13. Wait until `z2` is copied into HBM.\n",
|
|
"\n",
"Any time we are doing compute here, we are also asynchronously copying something, so some of the time spent copying is hidden behind compute rather than wasted.\n",
|
|
"\n",
"The two most important numbers for determining how efficient a pipelined computation will be are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy in order to execute that computation. The ratio of the two (FLOPs/bytes copied) is called the *arithmetic intensity* of an operation and determines whether our pipeline will be compute-bound or memory-bound."
|
|
]
|
|
},
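{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough, back-of-the-envelope illustration (not a measured number), consider the `add_matrices` example above:\n",
"\n",
"```python\n",
"# Adding two f32[512, 512] arrays:\n",
"flops = 512 * 512                 # one add per output element\n",
"bytes_moved = 3 * 512 * 512 * 4   # read x, read y, write z; 4 bytes per f32\n",
"print(flops / bytes_moved)        # ~0.083 FLOPs/byte -> heavily memory bound\n",
"```"
]
},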
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "gutx7y8uvZKH"
|
|
},
|
|
"source": [
|
|
"## Pipelining in Pallas"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "U-dPTjlBverB"
|
|
},
|
|
"source": [
"How do we implement a pipeline like the one above in Pallas? It seems like a complex sequence of asynchronous copies and kernel executions that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through `grid`s and `BlockSpec`s."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "x-LQKu8HwED7"
|
|
},
|
|
"source": [
|
|
"### `grid`, a.k.a. kernels in a loop\n",
|
|
"\n",
|
|
"See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. The generalized version of this is a loop in which the same kernel is executed multiple times. `pallas_call` provides an option to do exactly that.\n",
|
|
"\n",
|
|
"The number of iterations in the loop is specified via the `grid` argument to `pallas_call`. Conceptually:\n",
|
|
"```python\n",
|
|
"pl.pallas_call(some_kernel, grid=n)(...)\n",
|
|
"```\n",
|
|
"maps to\n",
|
|
"```python\n",
|
|
"for i in range(n):\n",
|
|
" # do HBM -> VMEM copies\n",
|
|
" some_kernel(...)\n",
|
|
" # do VMEM -> HBM copies\n",
|
|
"```\n",
|
|
"Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example,\n",
|
|
"\n",
|
|
"```python\n",
|
|
"pl.pallas_call(some_kernel, grid=(n, m))(...)\n",
|
|
"```\n",
|
|
"is equivalent to\n",
|
|
"```python\n",
|
|
"for i in range(n):\n",
|
|
" for j in range(m):\n",
|
|
" # do HBM -> VMEM copies\n",
|
|
" some_kernel(...)\n",
|
|
" # do VMEM -> HBM copies\n",
|
|
"```\n",
|
|
"This generalizes to any tuple of integers (a length `d` grid will correspond to `d` nested loops)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "hRLr5JeyyEwM"
|
|
},
|
|
"source": [
|
|
"### `BlockSpec`, a.k.a. how to chunk up inputs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "miWgPkytyIIa"
|
|
},
|
|
"source": [
"The next piece of information we need to provide Pallas in order to automatically pipeline our computation is how to chunk it up. Specifically, we need to provide a mapping from *the iteration of the loop* to *the block of our inputs and outputs to operate on*. A `BlockSpec` is exactly these two pieces of information.\n",
|
|
"\n",
"First we pick a `block_shape` for our inputs. In the pipelining example above, we had `(512, 512)`-shaped arrays and split them along the leading dimension into two `(256, 512)`-shaped arrays. In this pipeline, our `block_shape` would be `(256, 512)`.\n",
|
|
"\n",
"We then provide an `index_map` function that maps the iteration space to the blocks. The `index_map` returns *block* indices, i.e. indices in units of `block_shape`, not element offsets. Specifically, in the aforementioned pipeline, on the first iteration we'd like to select `x1` and on the second iteration we'd like to use `x2`. This can be expressed with the following `index_map`:\n",
|
|
"\n",
|
|
"```python\n",
|
|
"def x_index_map(i):\n",
|
|
" return (i, 0)\n",
|
|
"```\n",
|
|
"\n",
|
|
"We'd then construct the `BlockSpec`:\n",
|
|
"```python\n",
|
|
"block_spec = pl.BlockSpec(x_index_map, (256, 512))\n",
|
|
"```\n",
|
|
"\n",
|
|
"The `BlockSpec`s for `y` and `z` will be the same as the one for `x`."
|
|
]
|
|
},
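{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the mapping concrete, here is a small, purely illustrative helper (the `block_slice` function below is not part of the Pallas API) showing which elements each block index selects for a `(256, 512)` block shape:\n",
"\n",
"```python\n",
"def x_index_map(i):\n",
"  return (i, 0)\n",
"\n",
"def block_slice(block_index, block_shape):\n",
"  # Block indices are in units of the block shape, not element offsets.\n",
"  return tuple(slice(b * s, (b + 1) * s) for b, s in zip(block_index, block_shape))\n",
"\n",
"print(block_slice(x_index_map(0), (256, 512)))  # (slice(0, 256, None), slice(0, 512, None))\n",
"print(block_slice(x_index_map(1), (256, 512)))  # (slice(256, 512, None), slice(0, 512, None))\n",
"```"
]
},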
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "noybOKghzjwG"
|
|
},
|
|
"source": [
|
|
"### Putting it together\n",
|
|
"\n",
|
|
"We provide these arguments to `pallas_call` via `grid`, `in_specs` and `out_specs` (`in_specs` corresponds to the tuple of positional arguments, and `out_specs` corresponds to the output)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "ehKAYAwIojfv",
|
|
"outputId": "504bab29-83f3-4e1f-8664-1860ad15b6de"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([[2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" ...,\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n",
|
|
" block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n",
|
|
" return pl.pallas_call(\n",
|
|
" add_matrices_kernel,\n",
|
|
" out_shape=x,\n",
|
|
" in_specs=[block_spec, block_spec],\n",
|
|
" out_specs=block_spec,\n",
|
|
" grid=(2,))(x, y)\n",
|
|
"\n",
|
|
"add_matrices_pipelined(x, y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "rkytgIZYzz4t"
|
|
},
|
|
"source": [
"We've only added a little bit of code to our original function to add automatic pipelining, but the `BlockSpec`s and `grid` do a lot of heavy lifting!\n",
|
|
"\n",
|
|
"How does it work? Well, the `BlockSpec`s provide enough information to start *prefetching* blocks of our input from HBM into VMEM. For example, if we are starting iteration `i` of our `grid`, we can pass `i + 1` into the `index_map` functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration's outputs."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "7Xtz9oMs0ZRL"
|
|
},
|
|
"source": [
|
|
"### Parameterizing a pipeline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "esY4GcIB0bqQ"
|
|
},
|
|
"source": [
|
|
"It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do).\n",
|
|
"\n",
|
|
"Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let's write a more general kernel that handles both of these features."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "VartelFd0YfY"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def add_matrices_pipelined_2d(\n",
|
|
" x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n",
|
|
") -> jax.Array:\n",
|
|
" m, n = x.shape\n",
|
|
" block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))\n",
|
|
"\n",
|
|
" return pl.pallas_call(\n",
|
|
" add_matrices_kernel,\n",
|
|
" out_shape=x,\n",
|
|
" in_specs=[block_spec, block_spec],\n",
|
|
" out_specs=block_spec,\n",
|
|
" grid=(m // bm, n // bn),\n",
|
|
" )(x, y)\n",
|
|
"\n",
|
|
"\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n",
|
|
")\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y\n",
|
|
")\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "KrfeYwaW1QA-"
|
|
},
|
|
"source": [
|
|
"## Handling reductions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "P3SqEKDe3Mar"
|
|
},
|
|
"source": [
|
|
"How would you implement something like `jnp.sum` using `pallas_call`? Specifically, we'd like to pipeline across the reduction dimension.\n",
|
|
"\n",
|
|
"Take the example of reducing a `(8, 512, 512)`-shaped array to a `(512, 512)`-shaped one."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "JoT-ZKEk1R7l",
|
|
"outputId": "fd842223-98a5-4e5c-87fc-5dadc94da4fa"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([[8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" ...,\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = jnp.ones((8, 512, 512))\n",
|
|
"jnp.sum(x, axis=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "5O3ByvuT3iyC"
|
|
},
|
|
"source": [
|
|
"To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into VMEM. Then we could add `x[i]` to an output VMEM buffer. Let's implement this naively first."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "hqvv_WRQ3bvP",
|
|
"outputId": "200648d2-3f4d-4d1a-b95a-d2c1352cd7b8"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([[9., 9., 9., ..., 9., 9., 9.],\n",
|
|
" [9., 9., 9., ..., 9., 9., 9.],\n",
|
|
" [9., 9., 9., ..., 9., 9., 9.],\n",
|
|
" ...,\n",
|
|
" [9., 9., 9., ..., 9., 9., 9.],\n",
|
|
" [9., 9., 9., ..., 9., 9., 9.],\n",
|
|
" [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Warning: this implementation is incorrect!\n",
|
|
"\n",
|
|
"def naive_sum_kernel(x_ref, o_ref):\n",
|
|
" o_ref[...] += x_ref[...]\n",
|
|
"\n",
|
|
"def naive_sum(x: jax.Array) -> jax.Array:\n",
|
|
" grid, *out_shape = x.shape\n",
|
|
" return pl.pallas_call(\n",
|
|
" naive_sum_kernel,\n",
|
|
" grid=grid,\n",
|
|
" # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
|
|
" in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n",
|
|
" out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n",
|
|
" out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n",
|
|
" )(x)\n",
|
|
"naive_sum(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Kv9qJYJY4jbK"
|
|
},
|
|
"source": [
"Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimensions into VMEM (no pipelining there) but selecting the `i`-th slice along the leading dimension of `x` on each iteration via the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like to squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n",
|
|
"\n",
"`out_specs` uses `lambda i: (0, 0)` as its `index_map`, indicating that `o_ref` is unchanged over the course of the pipeline. This means that we can update its value on each iteration by reading from and writing to it. Or can we? Actually, there is one catch: *`o_ref` is initially garbage*, meaning we'll be accumulating into garbage. This will result in the overall function outputting the incorrect value!\n",
|
|
"\n",
"Therefore, **whenever we do a reduction in a kernel, we need to make sure to initialize the `Ref` that is storing the reduced value**. We can accomplish this by conditionally writing a value to `o_ref` when we're on iteration 0. We can do this with the helper function `pl.when`, a convenience wrapper around `jax.lax.cond`, and `pl.program_id`, which queries the current iteration along a grid axis."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "JXN2RthX5cSw",
|
|
"outputId": "195df19b-a889-479b-95b6-1fb7281f1518"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([[8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" ...,\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.],\n",
|
|
" [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def sum_kernel(x_ref, o_ref):\n",
|
|
" @pl.when(pl.program_id(axis=0) == 0)\n",
|
|
" def _():\n",
|
|
" o_ref[...] = jnp.zeros_like(o_ref)\n",
|
|
"\n",
|
|
" o_ref[...] += x_ref[...]\n",
|
|
"\n",
|
|
"def sum(x: jax.Array) -> jax.Array:\n",
|
|
" grid, *out_shape = x.shape\n",
|
|
" return pl.pallas_call(\n",
|
|
" sum_kernel,\n",
|
|
" grid=grid,\n",
|
|
" # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
|
|
" in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n",
|
|
" out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n",
|
|
" out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n",
|
|
" )(x)\n",
|
|
"sum(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "2828qXBI5ksZ"
|
|
},
|
|
"source": [
|
|
"This `sum` function now outputs the correct values!\n",
|
|
"\n",
"One last thing to note about reductions in Pallas is that **they must be done over the minormost (rightmost) dimensions of our grid** (our grid is 1-dimensional in the above example, so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates from the `BlockSpec`s, `grid` and kernel function *does not read outputs back from HBM*. Once an output value has been written back to HBM, it cannot be revisited. Therefore, a reduction cannot be carried out across a grid dimension that would revisit output blocks, which is why all reductions need to happen over the rightmost grid dimensions."
|
|
]
|
|
},
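{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how this plays out with a multi-dimensional grid, here is an untested sketch (assuming the same `(8, 512, 512)` input as above) of a `sum` that also pipelines over the last dimension. The reduction grid axis is placed rightmost so that it forms the innermost loop and each output block is fully accumulated before it is written back to HBM:\n",
"\n",
"```python\n",
"def sum_kernel_2d(x_ref, o_ref):\n",
"  # Re-initialize the accumulator whenever a new reduction starts.\n",
"  @pl.when(pl.program_id(axis=1) == 0)\n",
"  def _():\n",
"    o_ref[...] = jnp.zeros_like(o_ref)\n",
"\n",
"  o_ref[...] += x_ref[...]\n",
"\n",
"def sum_pipelined(x: jax.Array, *, bn: int = 128) -> jax.Array:\n",
"  r, m, n = x.shape\n",
"  # Grid axis 0 walks over column blocks; grid axis 1 (rightmost, innermost)\n",
"  # walks over the reduction dimension.\n",
"  return pl.pallas_call(\n",
"      sum_kernel_2d,\n",
"      grid=(n // bn, r),\n",
"      in_specs=[pl.BlockSpec(lambda j, i: (i, 0, j), (None, m, bn))],\n",
"      out_specs=pl.BlockSpec(lambda j, i: (0, j), (m, bn)),\n",
"      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
"  )(x)\n",
"```"
]
},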
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "KvPFez9N8cKJ"
|
|
},
|
|
"source": [
|
|
"## TPUs in Megacore configuration"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "0f4HAVzQ8n71"
|
|
},
|
|
"source": [
|
|
"Some TPU chips have two TensorCores but appear as one device to JAX users. This is called \"megacore\". The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but *share HBM*.\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?\n",
|
|
"\n",
"The basic idea is that if we have embarrassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "nQNa8RaQ-TR1",
|
|
"outputId": "385ed87c-d95c-466c-af77-df3845c979f2"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([[2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" ...,\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.],\n",
|
|
" [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n",
|
|
" block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n",
|
|
" return pl.pallas_call(\n",
|
|
" add_matrices_kernel,\n",
|
|
" out_shape=x,\n",
|
|
" in_specs=[block_spec, block_spec],\n",
|
|
" out_specs=block_spec,\n",
|
|
" grid=(2,),\n",
|
|
" mosaic_params=dict(dimension_semantics=(\"parallel\",)))(x, y)\n",
|
|
"\n",
|
|
"x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
|
|
"add_matrices_pipelined_megacore(x, y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xG51AiUC-8cl"
|
|
},
|
|
"source": [
"`dimension_semantics` should be a tuple of the same length as `grid` where each entry is either `\"parallel\"` or `\"arbitrary\"`. `\"parallel\"` indicates to Pallas that the iterations of the loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `\"arbitrary\"` indicates to Pallas that no assumptions can be made about this grid dimension and it therefore cannot be parallelized.\n",
|
|
"\n",
|
|
"By specifying `dimension_semantics`, we now execute the kernel simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically.\n",
|
|
"\n",
"> Note that Megacore is currently only available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if more than one is available)."
|
|
]
|
|
},
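{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, here is an untested sketch of the same annotation applied to the parameterized two-dimensional kernel from earlier, marking both grid dimensions as parallel:\n",
"\n",
"```python\n",
"def add_matrices_pipelined_2d_megacore(\n",
"    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n",
") -> jax.Array:\n",
"  m, n = x.shape\n",
"  block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))\n",
"  return pl.pallas_call(\n",
"      add_matrices_kernel,\n",
"      out_shape=x,\n",
"      in_specs=[block_spec, block_spec],\n",
"      out_specs=block_spec,\n",
"      grid=(m // bm, n // bn),\n",
"      mosaic_params=dict(dimension_semantics=(\"parallel\", \"parallel\")),\n",
"  )(x, y)\n",
"```"
]
},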
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "1ZJ2rV5W8FAe"
|
|
},
|
|
"source": [
|
|
"## Conclusion\n",
|
|
"\n",
"In this guide we covered how to express TPU pipelines using `pallas_call`, `grid` and `BlockSpec`s. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initializing our accumulators at the beginning of the reduction. We also learned how to utilize Megacore by adding `dimension_semantics` annotations to `pallas_call`.\n",
|
|
"\n",
|
|
"Exercises left to the reader:\n",
"* Try implementing a `sum` kernel that pipelines the other dimensions as well.\n",
"* Add Megacore support to the `add` and `sum` kernels."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"formats": "ipynb,md:myst"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|