{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "-z6pOJwvn-_j" }, "source": [ "# Matrix Multiplication\n", "\n", "In this guide, we'll write a matrix multiplication routine using Pallas. We'll also go over how to think about matmul performance on TPU and how to template a matmul kernel to fuse in operations." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ejAVO6ikUUuF" }, "outputs": [], "source": [ "#@title Imports\n", "import functools\n", "from typing import Callable\n", "\n", "import jax\n", "from jax.experimental import pallas as pl\n", "from jax.experimental.pallas import tpu as pltpu\n", "from jax import random\n", "import jax.numpy as jnp\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": { "id": "58plJlycoPmT" }, "source": [ "## Background\n", "\n", "Matrix multiplication is a fundamental linear algebra operation at the heart of modern deep learning and language modeling. We'd like to make matmuls as fast as possible using specialized accelerators like TPUs and GPUs, which both have dedicated units for fast matrix multiplication.\n", "\n", "To effectively utilize TPUs for matrix multiplication, we'll need to cover a few background concepts: block matrix multiplication, tiling, and pipelining.\n", "\n", "### Block Matrix Multiplication\n", "\n", "Let's say we want to implement `matmul(x, y)`, which generically multiplies an `(m, k)` array with a `(k, n)` array, but with a twist: we're only allowed to use the primitive `matmul_small`, which multiplies small matrices (say `m, k, n <= 256`). 
How could we do it?\n", "\n", "A nice property of matrix multiplication is that each block of the output can be expressed as the sum of several smaller matrix multiplications of row blocks and column blocks of the inputs.\n", "Formally, if we have input arrays $x \\in \\mathbb{R}^{m \\times k}$ and $y \\in \\mathbb{R}^{k \\times n}$ and output $z \\in \\mathbb{R}^{m \\times n}$, we decompose them into blocks along the dimensions of size $b_m, b_k, b_n$.\n", "\n", "For example, $x$ would be decomposed as:\n", "\n", "$$\n", "\\begin{bmatrix}\n", "x_{0, 0} & \\cdots & x_{0, i_k} \\\\\n", "x_{1, 0} & \\cdots & x_{1, i_k} \\\\\n", "\\vdots & \\ddots & \\vdots \\\\\n", "x_{i_m, 0} & \\cdots & x_{i_m, i_k} \\\\\n", "\\end{bmatrix}\n", "$$\n", "\n", "where $x_{ik} \\in \\mathbb{R}^{b_m \\times b_k}$. (We can similarly decompose $y$ and $z$.)\n", "\n", "For a particular output block $z_{ij}$, we can compute it as\n", "\n", "$$\n", "z_{ij} = \\sum_k x_{ik} y_{kj}\n", "$$\n", "\n", "Therefore, each output block $z_{ij}$ is the sum of several smaller block matrix multiplications $x_{ik} y_{kj}$. 
Here's how we'd implement this algorithm in NumPy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PACqDMtQrMOL" }, "outputs": [], "source": [ "def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:\n", " m, k, n = x.shape[0], x.shape[1], y.shape[1]\n", " assert m <= 256\n", " assert k <= 256\n", " assert n <= 256\n", " return np.matmul(x, y)\n", "\n", "def block_matmul(\n", " x: np.ndarray,\n", " y: np.ndarray,\n", " *,\n", " bm: int = 256,\n", " bk: int = 256,\n", " bn: int = 256,\n", ") -> np.ndarray:\n", " m, k = x.shape\n", " _, n = y.shape\n", " # This simple version assumes each dimension is a multiple of its block size.\n", " assert m % bm == 0 and k % bk == 0 and n % bn == 0\n", "\n", " z = np.zeros((m, n), dtype=x.dtype)\n", " # Loop over output blocks (m_i, n_i) and accumulate over the k blocks.\n", " for m_i in range(m // bm):\n", " for n_i in range(n // bn):\n", " for k_i in range(k // bk):\n", " m_slice = slice(m_i * bm, (m_i + 1) * bm)\n", " k_slice = slice(k_i * bk, (k_i + 1) * bk)\n", " n_slice = slice(n_i * bn, (n_i + 1) * bn)\n", " x_block = x[m_slice, k_slice]\n", " y_block = y[k_slice, n_slice]\n", " z[m_slice, n_slice] += matmul_small(x_block, y_block)\n", " return z" ] }, { "cell_type": "markdown", "metadata": { "id": "TP49TV6q8so9" }, "source": [ "Our `block_matmul` function should now work on inputs larger than 256 (though we assume that each input dimension is evenly divisible by its block size, 256 by default)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2SZFnWnurzEC" }, "outputs": [], "source": [ "m, k, n = 4096, 4096, 4096\n", "x = np.random.uniform(size=(m, k)).astype(np.float32)\n", "y = np.random.uniform(size=(k, n)).astype(np.float32)\n", "np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)" ] }, { "cell_type": "markdown", "metadata": { "id": "GXtjEtEhtARN" }, "source": [ "`block_matmul` decomposes a matrix multiplication into many smaller ones by observing that each output chunk of size `(bm, bn)` can be computed by accumulating several `(bm, bk) x (bk, bn)` size matrix multiplications.\n", "\n", "TPUs and GPUs do matmuls just like this! 
They natively support small matrix multiplication akin to `matmul_small`, so to utilize this hardware when doing bigger matrix multiplications, we will apply the `block_matmul` decomposition." ] }, { "cell_type": "markdown", "metadata": { "id": "a0ESFoX1ID0z" }, "source": [ "### Tiling and Pipelining\n", "\n", "In [the previous guide](pipelining), we covered how tiling up computations and pipelining in Pallas works. To make sure our compute units are always working and never stalled by memory transfers, we overlap the memory transfers for the next iteration of a kernel with the current one.\n", "\n", "In Pallas, we specify this with `BlockSpec`s and a `grid`. The nested for loop in the block matrix multiplication algorithm maps directly onto a `grid`, and the slices it takes map onto `BlockSpec`s." ] }, { "cell_type": "markdown", "metadata": { "id": "FvYoyqlyIqo6" }, "source": [ "## Your first matrix multiplication kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "umKZAlSvIt7x" }, "source": [ "Putting it all together, here's an implementation of a block matrix multiplication kernel that pipelines the memory transfers with the compute. We create a 3-d grid, corresponding to the triply nested loop in the NumPy code. Note that while MXUs (the TPU's matrix multiply units) can only multiply small blocks, Pallas will take bigger blocks and automatically tile them over the MXUs.\n", "\n", "The last dimension of the grid corresponds to the contraction dimension of the matrix multiply and is a reduction dimension, so we need to initialize the accumulator (here, the output block `z_ref` itself) on the first step along it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "75FBANKFbmQ5" }, "outputs": [], "source": [ "def matmul_kernel(x_ref, y_ref, z_ref):\n", " # Zero the output block on the first step of the reduction over k.\n", " @pl.when(pl.program_id(2) == 0)\n", " def _():\n", " z_ref[...] = jnp.zeros_like(z_ref)\n", "\n", " z_ref[...] += x_ref[...] 
@ y_ref[...]\n", "\n", "def matmul(\n", " x: jax.Array,\n", " y: jax.Array,\n", " *,\n", " bm: int = 128,\n", " bk: int = 128,\n", " bn: int = 128,\n", "):\n", " m, k = x.shape\n", " _, n = y.shape\n", " return pl.pallas_call(\n", " matmul_kernel,\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", " in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),\n", " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " grid=(m // bm, n // bn, k // bk),\n", " compiler_params=pltpu.TPUCompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0e8qTsimccGV" }, "outputs": [], "source": [ "m, k, n = 4096, 4096, 4096\n", "k1, k2 = random.split(random.key(0), 2)\n", "x = random.normal(k1, (m, k), dtype=jnp.float32)\n", "y = random.normal(k2, (k, n), dtype=jnp.float32)\n", "np.testing.assert_array_equal(x @ y, matmul(x, y))" ] }, { "cell_type": "markdown", "metadata": { "id": "DycJX_-PJnnB" }, "source": [ "## Matrix multiplication performance\n", "\n", "Let's think about how to analyze matrix multiplication performance. When we think about matmul performance, we typically care about two things: the total number of floating point operations (FLOPs) and the amount of memory bandwidth usage. From the [guide on TPUs and pipelining](pipelining), we see that in order to use the efficient compute units on TPUs (and on ML accelerators in general), we need to copy our inputs from HBM into VMEM, closer to the compute units. This copying to and from HBM takes time, and an efficient kernel spends most of its time actually computing rather than waiting for these transfers. 
Memory bandwidth measures the rate of this data transfer.\n", "\n", "> Quick note: in this guide, we'll be discussing floating point operations, but we want to distinguish FLOPs from FLOP/s.\n", " When we say \"FLOPs\" we mean \"floating point operations\", as in a number of operations. When we say \"FLOP/s\", we refer to \"floating point operations *per second*\", as in a *rate* of performing floating point operations.\n", "\n", "The number of FLOPs in an `(m, k) x (k, n)` matrix multiplication is approximately `2 * m * k * n`: each of the `m * n` outputs needs `k` multiplies and `k - 1` additions. (Technically it is `n * m * (2k - 1)`, but for large enough `k` our approximation is sufficient.)\n", "\n", "The minimum amount of memory bandwidth usage for a matrix multiply (assuming float32) is the total size of the inputs (copying into VMEM) plus the size of the output (copying into HBM). Thus the minimum bandwidth usage is `(m * k + k * n + m * n) * 4 bytes/float32`. Bandwidth usage can be greater if we re-read the inputs multiple times, which is often the case.\n", "\n", "One observation is that the number of matmul FLOPs is cubic in the matrix dimensions whereas the minimum bandwidth usage is only quadratic. Intuitively, this means that FLOPs grow faster than bandwidth usage: the bigger our matmul is, the more compute we have relative to copying."
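] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check of these counts, here is a small, self-contained sketch in plain Python. The helper names `exact_matmul_flops`, `approx_matmul_flops` and `min_matmul_bytes` are illustrative, not part of JAX or Pallas. It shows that the `2 * m * k * n` approximation overestimates the exact count by a factor of `2k / (2k - 1)`, and that the FLOPs-per-byte ratio of a square float32 matmul grows linearly with its size (it equals `size / 6`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative helpers (hypothetical names) for the counts described above.\n", "def exact_matmul_flops(m: int, k: int, n: int) -> int:\n", "    # Each of the m * n outputs needs k multiplies and k - 1 additions.\n", "    return m * n * (2 * k - 1)\n", "\n", "def approx_matmul_flops(m: int, k: int, n: int) -> int:\n", "    return 2 * m * k * n\n", "\n", "def min_matmul_bytes(m: int, k: int, n: int, itemsize: int = 4) -> int:\n", "    # Read x and y once and write z once (itemsize=4 for float32).\n", "    return (m * k + k * n + m * n) * itemsize\n", "\n", "for size in (128, 1024, 8192):\n", "    exact = exact_matmul_flops(size, size, size)\n", "    approx = approx_matmul_flops(size, size, size)\n", "    overestimate = approx / exact - 1  # equals 1 / (2k - 1)\n", "    ratio = approx / min_matmul_bytes(size, size, size)  # equals size / 6\n", "    print(f'{size:5d}: approx overestimates by {overestimate:.2e}, '\n", "          f'{ratio:.1f} FLOPs/byte')"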
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HZwmYZ61QZ5L", "outputId": "18505741-9254-4738-ec64-1660f6733d77" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2147483648\n", "12582912\n" ] } ], "source": [ "def matmul_flops(m: int, k: int, n: int):\n", " return 2 * m * k * n\n", "\n", "def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):\n", " return (m * k + k * n + m * n) * np.dtype(dtype).itemsize\n", "\n", "print(matmul_flops(1024, 1024, 1024))\n", "print(matmul_membw(1024, 1024, 1024, jnp.float32))" ] }, { "cell_type": "markdown", "metadata": { "id": "agCtb2GMQazl" }, "source": [ "Now that we can calculate the total number of FLOPs and (minimum) memory bandwidth usage of a matrix multiplication, let's see what a real TPU can handle.\n", "\n", "This notebook was run on a TPU v5e chip, so we'll use the v5e numbers (if you are running this notebook on different hardware, your numbers may differ). TPU v5es have [197 TFLOP/s of bf16/f32 compute and 819 GB/s of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). The ratio of these two numbers (about 240 FLOPs/byte on TPU v5e) is the chip's arithmetic intensity: it bounds how low a matmul's \"FLOPs / bytes transferred\" ratio can get before the matmul becomes IO bound." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "WUydNX2-K6Oy" }, "outputs": [], "source": [ "v5e_flops = 197e12\n", "v5e_membw = 819e9\n", "v5e_op_intensity = v5e_flops / v5e_membw # ~240.5" ] }, { "cell_type": "markdown", "metadata": { "id": "UjQIWq-9RJue" }, "source": [ "Roughly, these numbers tell us that the FLOPs of a float32 matmul take at least `2 * m * k * n / (197 TFLOP/s)` seconds and the copies between HBM and VMEM take at least `(m*k + k*n + m*n) * 4 bytes / (819 GB/s)` seconds."
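] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make this concrete, the sketch below (plain Python, with a hypothetical helper `roofline_estimate`) turns the two formulas into time estimates. It assumes that compute and HBM-VMEM copies overlap perfectly and that every element is transferred exactly once, so the larger of the two times is a rough lower bound on runtime, and whichever is larger indicates the resource that limits the matmul." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Peak numbers quoted above for TPU v5e; substitute your own chip's specs.\n", "PEAK_FLOPS = 197e12  # FLOP/s\n", "PEAK_MEMBW = 819e9   # bytes/s\n", "\n", "def roofline_estimate(m: int, k: int, n: int, itemsize: int = 4):\n", "    # Idealized model: compute and copies overlap perfectly, and each input\n", "    # and output element moves between HBM and VMEM exactly once.\n", "    compute_s = 2 * m * k * n / PEAK_FLOPS\n", "    memory_s = (m * k + k * n + m * n) * itemsize / PEAK_MEMBW\n", "    bound = 'compute' if compute_s >= memory_s else 'memory'\n", "    return compute_s, memory_s, bound\n", "\n", "for size in (1024, 8192):\n", "    compute_s, memory_s, bound = roofline_estimate(size, size, size)\n", "    print(f'{size}: compute {compute_s * 1e6:.1f} us, '\n", "          f'copies {memory_s * 1e6:.1f} us -> {bound} bound')"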
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "PiYobLc-RQSB" }, "outputs": [], "source": [ "def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):\n", " flops = matmul_flops(m, k, n)\n", " membw = matmul_membw(m, k, n, dtype)\n", " return flops / membw" ] }, { "cell_type": "markdown", "metadata": { "id": "q1y6dP00Sv9S" }, "source": [ "This basic calculation tells us roughly how efficiently we'll be able to use our MXUs. If our matmul's arithmetic intensity is below the chip's, then our computation will be *memory bound*, i.e. our compute units will sit idle while waiting for values to be transferred. If the matmul's intensity is higher than the chip's, then we will be *compute bound*.\n", "\n", "Because matmul FLOPs are cubic in the matrix dimensions and memory bandwidth usage is quadratic, we expect to become compute bound as our matmuls get bigger, but where this crossover happens really matters! Let's say we are doing a `(1024, 1024) x (1024, 1024)` float32 matrix multiplication." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "NMcretZoTPjj", "outputId": "1a03e351-abcf-48d4-f81d-b8fcbe056619" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "170.66666666666666 flops/byte\n" ] } ], "source": [ "print(f\"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte\")" ] }, { "cell_type": "markdown", "metadata": { "id": "U0CZSKwdTbqE" }, "source": [ "Our matmul's intensity (about 171 FLOPs/byte) is below the chip's (about 240 FLOPs/byte). That's not good! We are likely going to be memory bound with this type of matrix multiplication. However, what if our inputs and outputs were bigger? At some point, when our matmuls get big enough, we will cross over from memory bound into compute bound. For example, for a float32 matmul where `m = k = n`, we will cross over (on TPU v5e) when `2m**3 / (12m**2) > 240`, i.e. when `m = k = n > 1440`."
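] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same crossover can also be found numerically. The sketch below (plain Python; `crossover_size` is a hypothetical helper) searches for the smallest square size whose matmul intensity exceeds the chip's ratio. Using the unrounded v5e ratio of about 240.5 FLOPs/byte instead of 240 moves the float32 crossover slightly above 1440, and halving the bytes per element (as a 2-byte dtype does) halves the crossover size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def square_matmul_intensity(size: int, itemsize: int) -> float:\n", "    # FLOPs per byte for a (size, size) x (size, size) matmul that moves\n", "    # each of the three matrices between HBM and VMEM exactly once.\n", "    flops = 2 * size ** 3\n", "    bytes_moved = 3 * size ** 2 * itemsize\n", "    return flops / bytes_moved\n", "\n", "def crossover_size(chip_intensity: float, itemsize: int) -> int:\n", "    # Smallest square size whose intensity exceeds the chip's FLOPs/byte.\n", "    size = 1\n", "    while square_matmul_intensity(size, itemsize) <= chip_intensity:\n", "        size += 1\n", "    return size\n", "\n", "chip_intensity = 197e12 / 819e9  # about 240.5 FLOPs/byte on TPU v5e\n", "print('float32 crossover:', crossover_size(chip_intensity, itemsize=4))\n", "print('2-byte dtype crossover:', crossover_size(chip_intensity, itemsize=2))"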
] }, { "cell_type": "markdown", "id": "5bd00f91", "metadata": { "id": "iw4c_CZIdSeV" }, "source": [ "### `bfloat16` matrix multiplication" ] }, { "cell_type": "markdown", "id": "f05f0c15", "metadata": { "id": "7tACYDKIT3lq" }, "source": [ "To make it easier for matrix multiplication to be compute bound on TPU, we could also use a smaller dtype for our inputs and outputs. Our previous example used `float32` inputs and outputs, but TPU v5e also supports the `bfloat16` data type (a 16-bit floating point format, also called `bf16`) for matrix multiplication. On TPU v5e, we will have the same FLOP/s but will *halve our memory bandwidth usage*. This makes it much easier to be compute bound for smaller matrices. Let's see what our intensity is with a 1024 x 1024 x 1024 `bf16` matrix multiply:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "mcuLdyDoTmnO", "outputId": "10c3dcf0-7421-49f5-a38e-e5772d791bc2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "341.3333333333333 flops/byte\n" ] } ], "source": [ "print(f\"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte\")" ] }, { "cell_type": "markdown", "metadata": { "id": "XPPil1YSTn9Z" }, "source": [ "At about 341 FLOPs/byte, this is above the chip's ~240 FLOPs/byte, so by this estimate we now have a matmul that is compute bound!\n", "\n", "Let's add `bf16` support to our matrix multiplication kernel.\n", "\n", "The native MXU `bf16` matmul routine takes two `bf16` input matrices and accumulates the result in `f32`. We will trigger this routine by passing `preferred_element_type=jnp.float32` into `jnp.dot`. We will also need an accumulator `Ref` in `f32`. After the last reduction step, we downcast the accumulated result to `bf16` before writing it back to HBM. This way we avoid the precision loss of accumulating in `bf16`, cast only once per output block, and still retain the `bf16` memory bandwidth savings.\n", "\n", "> Note that the only way of allocating scratch spaces right now is via `pltpu.PrefetchScalarGridSpec`. 
Don't worry about exactly what it does for now: all you need to know is that it allows you to allocate scratch spaces in VMEM." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tyMcZtA6dWDP" }, "outputs": [], "source": [ "def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):\n", " # Zero the f32 accumulator on the first step of the reduction over k.\n", " @pl.when(pl.program_id(2) == 0)\n", " def _():\n", " acc_ref[...] = jnp.zeros_like(acc_ref)\n", "\n", " # With bf16 inputs, this is a bf16 x bf16 -> f32 matmul.\n", " acc_ref[...] += jnp.dot(\n", " x_ref[...], y_ref[...], preferred_element_type=jnp.float32\n", " )\n", "\n", " # On the last reduction step, downcast once and write the output block.\n", " @pl.when(pl.program_id(2) == nsteps - 1)\n", " def _():\n", " z_ref[...] = acc_ref[...].astype(z_ref.dtype)\n", "\n", "\n", "@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])\n", "def matmul(\n", " x: jax.Array,\n", " y: jax.Array,\n", " *,\n", " bm: int = 128,\n", " bk: int = 128,\n", " bn: int = 128,\n", "):\n", " m, k = x.shape\n", " _, n = y.shape\n", " return pl.pallas_call(\n", " functools.partial(matmul_kernel, nsteps=k // bk),\n", " grid_spec=pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", " pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),\n", " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),\n", " ],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],\n", " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", " compiler_params=pltpu.TPUCompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G3uHKEabVXep" }, "outputs": [], "source": [ "m, k, n = 4096, 4096, 4096\n", "k1, k2 = random.split(random.key(0), 2)\n", "x = random.normal(k1, (m, k), dtype=jnp.bfloat16)\n", "y = 
random.normal(k2, (k, n), dtype=jnp.bfloat16)\n", "np.testing.assert_array_equal(x @ y, matmul(x, y))" ] }, { "cell_type": "markdown", "metadata": { "id": "fBL1NwXzVlWa" }, "source": [ 
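"Why keep the accumulator in `f32`? The CPU-runnable sketch below uses plain JAX (no Pallas), and `blocked_dot` is a hypothetical helper. It mimics the kernel's loop over the contraction dimension and compares keeping the running sum in `float32` against keeping it in `bfloat16`, using a float64 NumPy product of the same `bf16` inputs as the reference. With these shapes, the `bfloat16` accumulator should show a noticeably larger error, since the running sum is rounded to `bf16` precision after every step." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "def blocked_dot(x, y, bk, acc_dtype):\n", "    # Mimics the kernel: walk the contraction dimension in chunks of bk and\n", "    # add each partial product into an accumulator stored in acc_dtype.\n", "    acc = jnp.zeros((x.shape[0], y.shape[1]), dtype=acc_dtype)\n", "    for start in range(0, x.shape[1], bk):\n", "        partial = jnp.dot(x[:, start:start + bk], y[start:start + bk, :],\n", "                          preferred_element_type=jnp.float32)\n", "        acc = acc + partial.astype(acc_dtype)\n", "    # Downcast once at the end, like the final grid step of the kernel.\n", "    return acc.astype(jnp.bfloat16)\n", "\n", "k1, k2 = jax.random.split(jax.random.key(0))\n", "x = jax.random.normal(k1, (256, 4096), dtype=jnp.bfloat16)\n", "y = jax.random.normal(k2, (4096, 256), dtype=jnp.bfloat16)\n", "\n", "# Reference: the same bf16 inputs multiplied and summed in float64.\n", "ref = np.asarray(x.astype(jnp.float32), dtype=np.float64) @ np.asarray(\n", "    y.astype(jnp.float32), dtype=np.float64)\n", "\n", "errors = {}\n", "for acc_dtype in (jnp.float32, jnp.bfloat16):\n", "    out = np.asarray(blocked_dot(x, y, 128, acc_dtype).astype(jnp.float32))\n", "    name = jnp.dtype(acc_dtype).name\n", "    errors[name] = float(np.mean(np.abs(out - ref)))\n", "    print(f'{name} accumulator: mean abs error {errors[name]:.4f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 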
"## Performance of pipelined kernels\n", "\n", "Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.\n", "\n", "This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n", "\n", "The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints:\n", "\n", "1. VMEM usage: The bigger our blocks, the more VMEM we use. With large enough blocks, we will run out.\n", "2. Pipeline bubbles: The larger our blocks are relative to the matrix size, the fewer loop iterations we will have in our pipeline. This will make the size of the bubbles at the beginning and end of the pipeline larger relative to the total pipeline and this overhead can be nontrivial.\n", "\n", "Getting good matrix multiplication performance in Pallas boils down to picking good block sizes to balance this optimization problem. In practice, we often sweep over a large set of candidate block sizes, profile the kernel, and pick the best one.\n", "\n", "For now, let's do some very simple timing experiments. We'll use `timeit` to measure the amount of time running each kernel takes. Note that this is a upper bound on the actual runtime of the kernel since we are measuring Python dispatch and other overheads using `timeit`. 
From these timings, we'll compute the achieved FLOP/s and the percentage utilization relative to the chip's peak, using a couple of reasonable block sizes to check our intuition." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RjU3sSTUWzIk", "outputId": "02b5793e-1ff3-41f4-ab45-4cf1393885ba" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================bm=128, bk=128, bn=128===================\n", "----- 1024 x 1024 x 1024 -----\n", "Matmul time: 0.00029766598949208854\n", "Matmul FLOP/s: 7214407167121.377\n", "FLOP/s utilization: 3.6621%\n", "\n", "----- 4096 x 4096 x 4096 -----\n", "Matmul time: 0.011771515250438824\n", "Matmul FLOP/s: 11675553278230.387\n", "FLOP/s utilization: 5.9267%\n", "\n", "----- 8192 x 8192 x 8192 -----\n", "Matmul time: 0.09183577066054567\n", "Matmul FLOP/s: 11972585626140.668\n", "FLOP/s utilization: 6.0775%\n", "\n", "================bm=512, bk=1024, bn=1024===================\n", "----- 1024 x 1024 x 1024 -----\n", "Matmul time: 0.00012708659982308746\n", "Matmul FLOP/s: 16897797651282.135\n", "FLOP/s utilization: 8.5776%\n", "\n", "----- 4096 x 4096 x 4096 -----\n", "Matmul time: 0.00088908776990138\n", "Matmul FLOP/s: 154584235803001.88\n", "FLOP/s utilization: 78.4692%\n", "\n", "----- 8192 x 8192 x 8192 -----\n", "Matmul time: 0.006099433819763363\n", "Matmul FLOP/s: 180264539343531.62\n", "FLOP/s utilization: 91.5048%\n", "\n" ] } ], "source": [ "import timeit\n", "\n", "def benchmark(f, ntrials: int = 100):\n", " def run(*args, **kwargs):\n", " # Warm up: the first call compiles the function.\n", " jax.block_until_ready(f(*args, **kwargs))\n", " # Average the wall-clock time over ntrials calls.\n", " result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),\n", " number=ntrials)\n", " time = result / ntrials\n", " return time\n", " return run\n", "\n", "def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,\n", " mm_func):\n", " x = jnp.ones((m, k), 
dtype=dtype)\n", " y = jnp.ones((k, n), dtype=dtype)\n", " time = benchmark(mm_func)(x, y)\n", " print(f\"----- {m} x {k} x {n} -----\")\n", " print(\"Matmul time: \", time)\n", " mm_flops = matmul_flops(m, k, n) / time\n", " print(\"Matmul FLOP/s: \", mm_flops)\n", " print(f\"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%\")\n", " print()\n", "\n", "print(\"================bm=128, bk=128, bn=128===================\")\n", "mm = functools.partial(matmul, bm=128, bk=128, bn=128)\n", "analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)\n", "analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)\n", "analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)\n", "\n", "print(\"================bm=512, bk=1024, bn=1024===================\")\n", "mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)\n", "analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)\n", "analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)\n", "analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)" ] }, { "cell_type": "markdown", "metadata": { "id": "mg1GMqcVan70" }, "source": [ "Bigger block sizes help a lot! We get good utilization (about 78% and 92%) on the two bigger matmuls, but the smallest matmul is much harder to make efficient.\n", "\n", "Let's compare this with XLA's matmuls. We don't expect Pallas to beat XLA, which is *very* good at generating matmuls, but hopefully we are close. With more careful block size tuning (left as future work), we may be able to match XLA's performance."
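] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before running that comparison, a rough per-step estimate helps explain why the bigger blocks win. The sketch below is plain Python, and `block_intensity` and `block_vmem_bytes` are hypothetical helpers. It assumes `bf16` inputs, counts only the `x` and `y` blocks copied on each grid step (the output block is written once per output tile), and, for the VMEM estimate, assumes every pipelined block is double-buffered. Under these assumptions, one grid step performs `bm * bn / (bm + bn)` FLOPs per byte copied, independent of `bk`: 64 for 128-sized blocks, well below the chip's ~240, versus about 341 for `bm=512, bn=1024`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def block_intensity(bm: int, bk: int, bn: int, itemsize: int = 2) -> float:\n", "    # FLOPs per byte for one grid step. Assumption: each step copies one x\n", "    # block and one y block from HBM; the output block is written only once\n", "    # per (i, j) output tile, so it is left out here.\n", "    flops = 2 * bm * bk * bn\n", "    bytes_in = (bm * bk + bk * bn) * itemsize\n", "    return flops / bytes_in\n", "\n", "def block_vmem_bytes(bm: int, bk: int, bn: int, itemsize: int = 2,\n", "                     acc_itemsize: int = 4, buffers: int = 2) -> int:\n", "    # Assumption: the x, y and output blocks are each double-buffered, plus\n", "    # one float32 (bm, bn) accumulator scratch.\n", "    blocks = (bm * bk + bk * bn + bm * bn) * itemsize * buffers\n", "    return blocks + bm * bn * acc_itemsize\n", "\n", "chip_intensity = 197e12 / 819e9  # about 240.5 FLOPs/byte on TPU v5e\n", "for bm, bk, bn in [(128, 128, 128), (512, 1024, 1024)]:\n", "    intensity = block_intensity(bm, bk, bn)\n", "    regime = 'compute' if intensity > chip_intensity else 'memory'\n", "    vmem_mib = block_vmem_bytes(bm, bk, bn) / 2**20\n", "    print(f'bm={bm}, bk={bk}, bn={bn}: {intensity:.1f} FLOPs/byte per step '\n", "          f'({regime} bound), about {vmem_mib:.2f} MiB of VMEM')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is only a first-order model, but it points the same way as the measurements above: 128-sized blocks come out memory bound, and with them even the 8192-sized matmul stays around 6% utilization. It also hints at why the 1024-sized matmul stays slow with the larger blocks: its grid has only `(1024 // 512) * (1024 // 1024) * (1024 // 1024) = 2` steps, so pipeline bubbles likely dominate."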
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OpU7I7BNXQYg", "outputId": "28c2c3cf-759e-465c-f969-0e2c9607b8a5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================ XLA matmul ===================\n", "----- 1024 x 1024 x 1024 -----\n", "Matmul time: 0.00011943008983507753\n", "Matmul FLOP/s: 17981093801113.996\n", "FLOP/s utilization: 9.1275%\n", "\n", "----- 4096 x 4096 x 4096 -----\n", "Matmul time: 0.0008272899803705514\n", "Matmul FLOP/s: 166131533963991.34\n", "FLOP/s utilization: 84.3307%\n", "\n", "----- 8192 x 8192 x 8192 -----\n", "Matmul time: 0.006047147869830951\n", "Matmul FLOP/s: 181823175395037.44\n", "FLOP/s utilization: 92.2960%\n", "\n" ] } ], "source": [ "print(\"================ XLA matmul ===================\")\n", "mm = jnp.matmul\n", "analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)\n", "analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)\n", "analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)" ] }, { "cell_type": "markdown", "metadata": { "id": "L-KUG3lha-jm" }, "source": [ "Pallas, with some very basic tuning, gets pretty close to XLA's performance numbers! By trying out more block sizes, we may be able to close the gap entirely." ] }, { "cell_type": "markdown", "metadata": { "id": "nbdHMJRObnZa" }, "source": [ "## Templating the matrix multiplication" ] }, { "cell_type": "markdown", "metadata": { "id": "qSfcMwtDg7Vn" }, "source": [ "Now that we have a basic matrix multiplication kernel, we can try fusing operations into it.\n", "\n", "### Fused right-hand-side transpose\n", "\n", "A common first candidate is a transpose. What do we mean by that? Suppose we wanted to compute `x @ y.T` instead of `x @ y`. Naively, we could first compute `y.T` and then pass it into our efficient matrix multiply kernel. However, the operation `y.T` is not free on its own: it reads and writes all of `y` (`O(nk)` elements). 
Ideally, we could compute the transpose *while* doing the matrix multiply in just one kernel, i.e. \"fusing\" it with the matmul.\n", "\n", "Accelerators often support native matrix multiplication routines that fuse a right-hand-side (RHS) transpose. For instance, on TPU v5e the MXU lets us compute `x @ y.T` directly for small arrays. We can invoke this routine with `jax.lax.dot_general`, which should be more efficient than doing a separate transpose followed by a matmul." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1_6S_QnMbHAQ" }, "outputs": [], "source": [ "def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):\n", " @pl.when(pl.program_id(2) == 0)\n", " def _():\n", " acc_ref[...] = jnp.zeros_like(acc_ref)\n", "\n", " # dot_general takes dimension numbers of the form (contracting_dims, batch_dims),\n", " # where contracting_dims = (lhs_dims, rhs_dims) lists the LHS and RHS dimensions\n", " # that are contracted (summed over) in the matmul, and batch_dims lists the\n", " # dimensions that are looped over. The remaining (free) dimensions of the LHS\n", " # and RHS become the output dimensions of the matmul.\n", " if transpose_rhs:\n", " # The y block is (bn, bk): contract dim 1 of x with dim 1 of y.\n", " dims = ((1,), (1,)), ((), ())\n", " else:\n", " # The y block is (bk, bn): contract dim 1 of x with dim 0 of y.\n", " dims = ((1,), (0,)), ((), ())\n", "\n", " acc_ref[...] += jax.lax.dot_general(\n", " x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,\n", " )\n", "\n", " @pl.when(pl.program_id(2) == nsteps - 1)\n", " def _():\n", " z_ref[...] 
= acc_ref[...].astype(z_ref.dtype)\n", "\n", "\n", "@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])\n", "def matmul(\n", " x: jax.Array,\n", " y: jax.Array,\n", " *,\n", " bm: int = 128,\n", " bk: int = 128,\n", " bn: int = 128,\n", " transpose_rhs: bool = False,\n", "):\n", " # y has logical shape (k, n); read the shapes before transposing it.\n", " m, k = x.shape\n", " _, n = y.shape\n", " if transpose_rhs:\n", " y = y.swapaxes(0, 1)\n", " y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))\n", " else:\n", " y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))\n", " return pl.pallas_call(\n", " functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),\n", " grid_spec=pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", " pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),\n", " y_block_spec,\n", " ],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],\n", " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", " compiler_params=pltpu.TPUCompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, { "cell_type": "markdown", "metadata": { "id": "eSmPJHSchuGX" }, "source": [ "We transpose `y` inside the `matmul` function (`y = y.swapaxes(0, 1)`). This is because inside a JIT-ted JAX computation, dimension ordering is purely *logical*, not physical, so rearranging dimensions does not imply a\n", "physical layout difference. However, when we pass an array into a `pallas_call`, we do enforce a major-to-minor dimension ordering constraint. By transposing `y` inside the `matmul` function, we are requesting that `y` be in a\n", "transposed layout `(n, k)` instead of the usual `(k, n)`. 
The user still passes in `y` with the (logical) `(k, n)` shape, however.\n", "\n", "Note: to benchmark the fused transpose, we want `y` to already be in the physical transposed layout when we pass it into the kernel, so that we don't measure relayout time. In the wrapper function, we (logically) transpose it back to `(k, n)`\n", "before passing it into `matmul`, because `matmul` expects a logical `(k, n)` dimension ordering." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AcBMHhKLhkDp", "outputId": "48f2f70b-c94d-44eb-c781-871c36cf457f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================bm=128, bk=128, bn=128===================\n", "----- 1024 x 1024 x 1024 -----\n", "Matmul time: 0.0003029372810851783\n", "Matmul FLOP/s: 7088872126624.065\n", "FLOP/s utilization: 3.5984%\n", "\n", "----- 4096 x 4096 x 4096 -----\n", "Matmul time: 0.012017967159627005\n", "Matmul FLOP/s: 11436123235026.848\n", "FLOP/s utilization: 5.8051%\n", "\n", "----- 8192 x 8192 x 8192 -----\n", "Matmul time: 0.09500920018996112\n", "Matmul FLOP/s: 11572685861765.383\n", "FLOP/s utilization: 5.8745%\n", "\n", "================bm=512, bk=1024, bn=1024===================\n", "----- 1024 x 1024 x 1024 -----\n", "Matmul time: 0.00012131539988331496\n", "Matmul FLOP/s: 17701657415839.363\n", "FLOP/s utilization: 8.9856%\n", "\n", "----- 4096 x 4096 x 4096 -----\n", "Matmul time: 0.0008790623804088682\n", "Matmul FLOP/s: 156347213275211.03\n", "FLOP/s utilization: 79.3641%\n", "\n", "----- 8192 x 8192 x 8192 -----\n", "Matmul time: 0.006107717020204291\n", "Matmul FLOP/s: 180020067095253.78\n", "FLOP/s utilization: 91.3807%\n", "\n" ] } ], "source": [ "def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,\n", " mm_func, transpose_rhs: bool = False):\n", " x = jnp.ones((m, k), dtype=dtype)\n", " if transpose_rhs:\n", " # y is materialized in an (n, k) layout.\n", " y = jnp.ones((n, k), dtype=dtype)\n", " @jax.jit\n", " def _wrapper(x, y):\n", " # Logically swap back to (k, n), which is what matmul expects.\n", " y = y.swapaxes(0, 1)\n", " 
return mm_func(x, y, transpose_rhs=True)\n", " else:\n", " y = jnp.ones((k, n), dtype=dtype)\n", " _wrapper = mm_func\n", " time = benchmark(_wrapper)(x, y)\n", " print(f\"----- {m} x {k} x {n} -----\")\n", " print(\"Matmul time: \", time)\n", " mm_flops = matmul_flops(m, k, n) / time\n", " print(\"Matmul FLOP/s: \", mm_flops)\n", " print(f\"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%\")\n", " print()\n", "\n", "print(\"================bm=128, bk=128, bn=128===================\")\n", "mm = functools.partial(matmul, bm=128, bk=128, bn=128)\n", "analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)\n", "analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)\n", "analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)\n", "\n", "print(\"================bm=512, bk=1024, bn=1024===================\")\n", "mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)\n", "analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)\n", "analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)\n", "analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "E0P8lWhskn3j" }, "source": [ "See how we get the same utilization despite the extra transpose!" ] }, { "cell_type": "markdown", "metadata": { "id": "DUYGnu7zkz8v" }, "source": [ "### Fused activation function\n", "\n", "Fusing in an activation is also really common. This makes sure we don't follow an efficient, compute bound matmul kernel with a slow memory bound activation kernel." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SANr6fyBiso_" }, "outputs": [], "source": [ "def matmul_kernel(\n", " x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation\n", "):\n", " @pl.when(pl.program_id(2) == 0)\n", " def _():\n", " acc_ref[...] 
= jnp.zeros_like(acc_ref)\n", "\n", " if transpose_rhs:\n", " dims = ((1,), (1,)), ((), ())\n", " else:\n", " dims = ((1,), (0,)), ((), ())\n", "\n", " acc_ref[...] += jax.lax.dot_general(\n", " x_ref[...],\n", " y_ref[...],\n", " dims,\n", " preferred_element_type=jnp.float32,\n", " )\n", "\n", " @pl.when(pl.program_id(2) == nsteps - 1)\n", " def _():\n", " # Apply the activation to the f32 accumulator, then downcast.\n", " z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)\n", "\n", "\n", "@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs', 'activation'])\n", "def matmul(\n", " x: jax.Array,\n", " y: jax.Array,\n", " *,\n", " bm: int = 128,\n", " bk: int = 128,\n", " bn: int = 128,\n", " transpose_rhs: bool = False,\n", " activation: Callable[[jax.Array], jax.Array] = lambda x: x,\n", "):\n", " # y has logical shape (k, n); read the shapes before transposing it.\n", " m, k = x.shape\n", " _, n = y.shape\n", " if transpose_rhs:\n", " y = y.swapaxes(0, 1)\n", " y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))\n", " else:\n", " y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))\n", " return pl.pallas_call(\n", " functools.partial(\n", " matmul_kernel,\n", " nsteps=k // bk,\n", " transpose_rhs=transpose_rhs,\n", " activation=activation,\n", " ),\n", " grid_spec=pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", " pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),\n", " y_block_spec,\n", " ],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],\n", " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", " compiler_params=pltpu.TPUCompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BOu7WBCBlHpN", "outputId": "4b7f72c4-f562-4a49-cc48-17bf0c845434" }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "================bm=128, bk=128, bn=128===================\n", "----- 1024 x 1024 x 1024 
-----\n", "Matmul time: 0.00030103540048003196\n", "Matmul FLOP/s: 7133658182976.541\n", "FLOP/s utilization: 3.6211%\n", "\n", "----- 4096 x 4096 x 4096 -----\n", "Matmul time: 0.011807117109419778\n", "Matmul FLOP/s: 11640348122095.826\n", "FLOP/s utilization: 5.9088%\n", "\n", "----- 8192 x 8192 x 8192 -----\n", "Matmul time: 0.09181861146935262\n", "Matmul FLOP/s: 11974823079773.941\n", "FLOP/s utilization: 6.0786%\n", "\n", "================bm=512, bk=1024, bn=1024===================\n", "----- 1024 x 1024 x 1024 -----\n", "Matmul time: 0.00012622540001757442\n", "Matmul FLOP/s: 17013086492108.6\n", "FLOP/s utilization: 8.6361%\n", "\n", "----- 4096 x 4096 x 4096 -----\n", "Matmul time: 0.000896632740041241\n", "Matmul FLOP/s: 153283442968721.44\n", "FLOP/s utilization: 77.8089%\n", "\n", "----- 8192 x 8192 x 8192 -----\n", "Matmul time: 0.006130605939542875\n", "Matmul FLOP/s: 179347953304919.88\n", "FLOP/s utilization: 91.0396%\n", "\n" ] } ], "source": [ "def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,\n", "                   mm_func, transpose_rhs: bool = False,\n", "                   activation=lambda x: x):\n", "  x = jnp.ones((m, k), dtype=dtype)\n", "  if transpose_rhs:\n", "    y = jnp.ones((n, k), dtype=dtype)\n", "    @jax.jit\n", "    def _wrapper(x, y):\n", "      y = y.swapaxes(0, 1)\n", "      return mm_func(x, y, transpose_rhs=True, activation=activation)\n", "  else:\n", "    y = jnp.ones((k, n), dtype=dtype)\n", "    _wrapper = functools.partial(mm_func, activation=activation)\n", "  time = benchmark(_wrapper)(x, y)\n", "  print(f\"----- {m} x {k} x {n} -----\")\n", "  print(\"Matmul time: \", time)\n", "  mm_flops = matmul_flops(m, k, n) / time\n", "  print(\"Matmul FLOP/s: \", mm_flops)\n", "  print(f\"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%\")\n", "  print()\n", "\n", "\n", "activation = jax.nn.relu\n", "print(\"================bm=128, bk=128, bn=128===================\")\n", "mm = functools.partial(matmul, bm=128, bk=128, bn=128)\n", "analyze_matmul(1024, 1024, 1024, 
jnp.bfloat16, mm, activation=activation)\n", "analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)\n", "analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)\n", "\n", "print(\"================bm=512, bk=1024, bn=1024===================\")\n", "mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)\n", "analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)\n", "analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)\n", "analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)" ] }, { "cell_type": "markdown", "metadata": { "id": "tIekGWFLmgtS" }, "source": [ "The additional fused activation barely affects our utilization at all!" ] }, { "cell_type": "markdown", "metadata": { "id": "faNZwx20mpJi" }, "source": [ "## Conclusion\n", "\n", "In this guide, we covered how to write efficient matrix multiplications on TPU using Pallas. We discussed blocked matrix multiplication and pipelining, how to analyze the performance of a TPU matmul, and how to write an efficient `bf16` matrix multiplication. Finally, we templated the matrix multiplication to support a fused transpose and a fused activation function.\n", "\n", "Exercises left to the reader:\n", "* Add support for input fusions. Sometimes we want to fuse an operation into the inputs of the matmul. Try templating the matrix multiplication even more to support this.\n", "* Add support for `int8` matrix multiplication. TPU v5 supports native `int8` matrix multiplication at twice the peak throughput of `bf16`. Try adding support for that and see what utilization is possible.\n", "* Add backward pass support for the `matmul` function. You can do this with `jax.custom_vjp`."
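] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a possible starting point for the last exercise, here is a minimal sketch of the `jax.custom_vjp` plumbing. It assumes only that the wrapped function `mm(x, y)` computes `x @ y` for 2D arrays. `make_differentiable_matmul` is a hypothetical helper, not a Pallas API. For `z = x @ y` with output cotangent `g`, the gradients are `dx = g @ y.T` and `dy = x.T @ g`, so the backward pass is itself two matmuls that can reuse the same forward function.\n", "\n", "The cell checks itself against JAX autodiff, using `jnp.matmul` as a stand-in for `mm`. Plugging in the Pallas `matmul` above would also require shapes divisible by the block sizes and some care with dtypes, since that kernel returns `x.dtype`. If a fused activation is used, its derivative would also have to be folded into the backward pass." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "\n", "def make_differentiable_matmul(mm):\n", "  # Hypothetical helper: wraps a forward-only `mm(x, y)`, assumed to compute\n", "  # x @ y for 2D arrays, with a hand-written backward pass.\n", "\n", "  @jax.custom_vjp\n", "  def f(x, y):\n", "    return mm(x, y)\n", "\n", "  def f_fwd(x, y):\n", "    # Save both inputs as residuals; the backward pass needs them.\n", "    return mm(x, y), (x, y)\n", "\n", "  def f_bwd(residuals, g):\n", "    x, y = residuals\n", "    # For z = x @ y with output cotangent g: dx = g @ y.T, dy = x.T @ g.\n", "    return mm(g, y.T), mm(x.T, g)\n", "\n", "  f.defvjp(f_fwd, f_bwd)\n", "  return f\n", "\n", "\n", "# Compare against JAX's own autodiff, with jnp.matmul standing in for `mm`.\n", "f = make_differentiable_matmul(jnp.matmul)\n", "x = jax.random.normal(jax.random.PRNGKey(0), (128, 256), jnp.float32)\n", "y = jax.random.normal(jax.random.PRNGKey(1), (256, 64), jnp.float32)\n", "\n", "dx, dy = jax.grad(lambda a, b: jnp.sum(jnp.sin(f(a, b))), argnums=(0, 1))(x, y)\n", "ref_dx, ref_dy = jax.grad(\n", "    lambda a, b: jnp.sum(jnp.sin(a @ b)), argnums=(0, 1))(x, y)\n", "\n", "np.testing.assert_allclose(dx, ref_dx, rtol=1e-3, atol=1e-3)\n", "np.testing.assert_allclose(dy, ref_dy, rtol=1e-3, atol=1e-3)\n", "print(\"custom_vjp gradients match autodiff\")"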
] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 4 }