mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Address changes
This commit is contained in:
parent
39ec5dacb4
commit
ff62d5e229
@ -39,13 +39,13 @@
|
||||
"source": [
|
||||
"## Background\n",
|
||||
"\n",
|
||||
"Matrix multiplication is a fundamental linear algebra operation at heart of modern deep learning and language modeling. We'd like to make matmuls as speedy as possible using specialized accelerators like TPUs and GPUs. Both TPUs and GPUs have specialized units for fast matrix multiplication though they work differently.\n",
|
||||
"Matrix multiplication is a fundamental linear algebra operation at heart of modern deep learning and language modeling. We'd like to make matmuls as speedy as possible using specialized accelerators like TPUs and GPUs, which both have specialized units for fast matrix multiplication.\n",
|
||||
"\n",
|
||||
"To effectively utilize TPUs for matrix multiplication, we'll need to cover two background concepts: block matrix multiplication, and pipelining.\n",
|
||||
"To effectively utilize TPUs for matrix multiplication, we'll need to cover a few background concepts: block matrix multiplication, tiling and pipelining.\n",
|
||||
"\n",
|
||||
"### Block Matrix Multiplication\n",
|
||||
"\n",
|
||||
"Let's implement `matmul(x, y)` function that multiplies an `(m, k)` array with a `(k, n)` array, but let's say that we also have a function `matmul_small` that does matrix multiplication for us but only with small sizes (say `m, k, n <= 256`). Could we implement our our `matmul` function in terms of `matmul_small`?\n",
|
||||
"Let's say we'd like to implement a `matmul(x, y)` function that multiplies an `(m, k)` array with a `(k, n)` array using a provided function `matmul_small` that does matrix multiplication for us but only with small sizes (say `m, k, n <= 256`) to do the compute. Could we implement our our `matmul` function in terms of `matmul_small`?\n",
|
||||
"\n",
|
||||
"Yes! The answer is by decomposing the matmul into one that operates on small chunks of our inputs (a.k.a. blocks).\n",
|
||||
"\n",
|
||||
@ -134,7 +134,7 @@
|
||||
"source": [
|
||||
"`block_matmul` decomposes a matrix multiplication into many smaller ones by observing that each output chunk of size `(bm, bn)` can be computed by accumulating several `(bm, bk) x (bk, bn)` size matrix multiplications.\n",
|
||||
"\n",
|
||||
"TPUs have native hardware capable of small matrix multiplication akin to `matmul_small`, so to utilize this hardware when doing bigger matrix multiplications, we will apply the `block_matmul` decomposition."
|
||||
"TPUs and GPUs have native hardware capable of small matrix multiplication akin to `matmul_small`, so to utilize this hardware when doing bigger matrix multiplications, we will apply the `block_matmul` decomposition."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -143,9 +143,9 @@
|
||||
"id": "a0ESFoX1ID0z"
|
||||
},
|
||||
"source": [
|
||||
"### Pipelining\n",
|
||||
"### Tiling and Pipelining\n",
|
||||
"\n",
|
||||
"In [the previous guide](pipelining), we covered how pipelining in Pallas works. To make sure our compute units are always working and never waiting for memory transfers, we overlap the memory transfers for the next iteration of a kernel with the current one.\n",
|
||||
"In [the previous guide](pipelining), we covered how tiling up computations and pipelining in Pallas works. To make sure our compute units are always working and never stalled by memory transfers, we overlap the memory transfers for the next iteration of a kernel with the current one.\n",
|
||||
"\n",
|
||||
"In Pallas, we specify that via `BlockSpec`s and a `grid`. Note that we already have a nested for loop in the block matrix multiplication algorithm. We can specify that in Pallas via a `grid`. The slices in the block matrix multiplication can also be specified via `BlockSpec`s."
|
||||
]
|
||||
@ -201,7 +201,9 @@
|
||||
" in_specs=[pl.BlockSpec(lambda i, j, k: (i, k), (bm, bk)),\n",
|
||||
" pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn))],\n",
|
||||
" out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (bm, bn)),\n",
|
||||
" grid=(m // bm, n // bn, k // bk)\n",
|
||||
" grid=(m // bm, n // bn, k // bk),\n",
|
||||
" compiler_params=dict(mosaic=dict(\n",
|
||||
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
|
||||
" )(x, y)"
|
||||
]
|
||||
},
|
||||
@ -232,7 +234,7 @@
|
||||
"\n",
|
||||
"The number of FLOPs in a `(m, k) x (k, n)` matrix multiplication are (approximately) `2 * m * k * n`. (Technically it is `n * m * (2k - 1)` but for large enough `k` our approximation is sufficient.)\n",
|
||||
"\n",
|
||||
"The minimum amount of memory usage for a matrix multiply (assuming float32) is the total size of the inputs (copying into VMEM) plus the size of the output (copying into HBM). Thus the minimum bandwidth usage is `(m * k + k * n + m * n) * 4 bytes/float32`.\n",
|
||||
"The minimum amount of memory usage for a matrix multiply (assuming float32) is the total size of the inputs (copying into VMEM) plus the size of the output (copying into HBM). Thus the minimum bandwidth usage is `(m * k + k * n + m * n) * 4 bytes/float32`. Memory usage can be greater if we re-read the inputs multiple times, which is often the case.\n",
|
||||
"\n",
|
||||
"One observation is that the number of matmul FLOPs is cubic in its inputs whereas the minimum bandwidth usage is quadratic in its inputs. Intuitively, this means that FLOPs grow faster than bandwidth usage, meaning that the bigger our matmul is, the more compute we have relative to copying."
|
||||
]
|
||||
@ -271,9 +273,9 @@
|
||||
"id": "agCtb2GMQazl"
|
||||
},
|
||||
"source": [
|
||||
"Now that we can calculate the amount of FLOPs and memory bandwidth usage of a matrix multiplication, let's now see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.\n",
|
||||
"Now that we can calculate the amount of FLOPs and (minimum) memory bandwidth usage of a matrix multiplication, let's now see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.\n",
|
||||
"\n",
|
||||
"This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOPs of bf16 compute and 819 GBps of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the arithmetic intensity that the chip is capable of by taking their ratio (about 240 on TPU v5e)."
|
||||
"This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOPs of bf16 compute and 819 GB/s of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the minimum arithmetic intensity (ratio of FLOPs to memory bandwidth) needed to be compute bound in this analytical model on this chip by taking their ratio (about 240 FLOPs/byte on TPU v5e)."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -477,9 +479,9 @@
|
||||
"\n",
|
||||
"Our above analysis about FLOPS vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.\n",
|
||||
"\n",
|
||||
"This means that we actually care about the FLOPS vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPS vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for matrix multiplication performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n",
|
||||
"This means that we actually care about the FLOPS vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPS vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n",
|
||||
"\n",
|
||||
"The intuition should therefore be: make the blocks as big as possible! There are two main constraints:\n",
|
||||
"The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints:\n",
|
||||
"\n",
|
||||
"1. VMEM usage: The bigger our blocks, the more VMEM we use. With large enough blocks, we will run out.\n",
|
||||
"2. Pipeline bubbles: The larger our blocks are relative to the matrix size, the fewer loop iterations we will have in our pipeline. This will make the size of the bubbles at the beginning and end of the pipeline larger relative to the total pipeline and this overhead can be nontrivial.\n",
|
||||
@ -584,7 +586,8 @@
|
||||
"source": [
|
||||
"Bigger block sizes help a lot! We get pretty good utilization (80-90%) in the bigger matmuls, but the smallest matmul seems pretty hard to get good performance with.\n",
|
||||
"\n",
|
||||
"Let's compare this with XLA's matmuls. We don't expect Pallas to do better than XLA because XLA is *very* good at generating matmuls but hopefully we are close."
|
||||
"Let's compare this with XLA's matmuls. We don't expect Pallas to do better than XLA because XLA is *very* good at generating matmuls but hopefully we are close.\n",
|
||||
"With more careful block size tuning (left as future work), we can also reach XLA performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -654,7 +657,7 @@
|
||||
"\n",
|
||||
"### Fused right-hand-side transpose\n",
|
||||
"\n",
|
||||
"A common first thing to do is to fuse a transpose (e.g. instead of doing `x @ y` we do `x @ y.T`). On TPUs, the MXU supports matmul routines with the right hand side transposed natively, so there is no additional cost to fusing in a transpose. We can use this routine with `jax.lax.dot_general`."
|
||||
"A common first thing to do is to fuse a transpose (e.g. instead of doing `x @ y` we do `x @ y.T`). Accelerators often supports a matrix multiplication routine that fuses an RHS transpose (for example, TPU v5e) so there is no additional cost to fusing in a transpose. We can use this routine with `jax.lax.dot_general`, which when natively supported will be more efficient than doing a transpose and a matmul separately."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -670,6 +673,11 @@
|
||||
" def _():\n",
|
||||
" acc_ref[...] = jnp.zeros_like(acc_ref)\n",
|
||||
"\n",
|
||||
" # dot_general expects a data structure (contraction_dims, batch_dims),\n",
|
||||
" # where contraction_dims are the set of dimensions for LHS and RHS that will\n",
|
||||
" # be contracted (reduced) in the matmul; batch_dims, on the other hand, are\n",
|
||||
" # looped over. The remaining dimensions will be the input and output dimension\n",
|
||||
" # of the matmul.\n",
|
||||
" if transpose_rhs:\n",
|
||||
" dims = ((1,), (1,)), ((), ())\n",
|
||||
" else:\n",
|
||||
@ -723,9 +731,12 @@
|
||||
"id": "eSmPJHSchuGX"
|
||||
},
|
||||
"source": [
|
||||
"Note that we do a transpose inside of the `matmul` function (`y = y.swapaxes(0, 1)`). This is because transposes in XLA are *logical*, not physical. It informs XLA that the custom call would like `y` to be laid out differently before the matmul.\n",
|
||||
"We do a transpose inside of the `matmul` function (`y = y.swapaxes(0, 1)`). This is because inside of a JIT-ted JAX computation, dimension ordering is purely *logical*, not physical, so rearranging dimensions does not imply a\n",
|
||||
"physical layout difference. However, when we pass an array into a `pallas_call`, we do enforce a major-to-minor dimension ordering constraint. By transposing `y` inside of the `matmul` function, we are requesting that `y` be in a\n",
|
||||
"transposed layout `(n, k)` instead of the usual `(k, n)`. The user will still pass in the array in the (logical) `(n, k)` dimension, however.\n",
|
||||
"\n",
|
||||
"To benchmark the transpose, we actually want `y` to be in the transposed layout beforehand (and we will untranspose it in the benchmark)."
|
||||
"Note: to benchmark the transpose, we actually want `y` to be in the physical transposed layout when we pass it into the kernel, so we don't measure relayout time. In the wrapper function, we will (logically) transpose it back to `(n, k)`\n",
|
||||
"before passing it into `matmul` because `matmul` expects a logical `(n, k)` dimension ordering."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -897,6 +908,8 @@
|
||||
" grid=(m // bm, n // bn, k // bk),\n",
|
||||
" ),\n",
|
||||
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
|
||||
" compiler_params=dict(mosaic=dict(\n",
|
||||
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
|
||||
" )(x, y)"
|
||||
]
|
||||
},
|
||||
@ -1004,7 +1017,6 @@
|
||||
"In this guide, we covered how to write efficient matrix multiplications on TPU using Pallas. We discussed blocked matrix multiplication and pipelining, how to analyze the performance of a TPU matmul, and how to write an efficient `bf16` matrix multiplication. We concluded with templating the matrix multiplication to support a fused transpose and fused activation functions.\n",
|
||||
"\n",
|
||||
"Exercises left to the reader:\n",
|
||||
"* Add Megacore support to the kernel\n",
|
||||
"* Add support for input fusions. Sometimes we want to fuse an operation into the inputs of the matmul. Try templating the matrix multiplication even more to support this.\n",
|
||||
"* Add support for `int8` matrix multiplication. TPU v5 supports native `int8` matrix multiplication at twice the FLOPs of `bf16`. Try adding support for that and see what utilization is possible.\n",
|
||||
"* Add backwards pass support for the `matmul` function. You can do this with `jax.custom_vjp`."
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.14.7
|
||||
kernelspec:
|
||||
display_name: Python 3 (ipykernel)
|
||||
language: python
|
||||
@ -37,13 +37,13 @@ import numpy as np
|
||||
|
||||
## Background
|
||||
|
||||
Matrix multiplication is a fundamental linear algebra operation at heart of modern deep learning and language modeling. We'd like to make matmuls as speedy as possible using specialized accelerators like TPUs and GPUs. Both TPUs and GPUs have specialized units for fast matrix multiplication though they work differently.
|
||||
Matrix multiplication is a fundamental linear algebra operation at heart of modern deep learning and language modeling. We'd like to make matmuls as speedy as possible using specialized accelerators like TPUs and GPUs, which both have specialized units for fast matrix multiplication.
|
||||
|
||||
To effectively utilize TPUs for matrix multiplication, we'll need to cover two background concepts: block matrix multiplication, and pipelining.
|
||||
To effectively utilize TPUs for matrix multiplication, we'll need to cover a few background concepts: block matrix multiplication, tiling and pipelining.
|
||||
|
||||
### Block Matrix Multiplication
|
||||
|
||||
Let's implement `matmul(x, y)` function that multiplies an `(m, k)` array with a `(k, n)` array, but let's say that we also have a function `matmul_small` that does matrix multiplication for us but only with small sizes (say `m, k, n <= 256`). Could we implement our our `matmul` function in terms of `matmul_small`?
|
||||
Let's say we'd like to implement a `matmul(x, y)` function that multiplies an `(m, k)` array with a `(k, n)` array using a provided function `matmul_small` that does matrix multiplication for us but only with small sizes (say `m, k, n <= 256`) to do the compute. Could we implement our our `matmul` function in terms of `matmul_small`?
|
||||
|
||||
Yes! The answer is by decomposing the matmul into one that operates on small chunks of our inputs (a.k.a. blocks).
|
||||
|
||||
@ -112,13 +112,13 @@ np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)
|
||||
|
||||
`block_matmul` decomposes a matrix multiplication into many smaller ones by observing that each output chunk of size `(bm, bn)` can be computed by accumulating several `(bm, bk) x (bk, bn)` size matrix multiplications.
|
||||
|
||||
TPUs have native hardware capable of small matrix multiplication akin to `matmul_small`, so to utilize this hardware when doing bigger matrix multiplications, we will apply the `block_matmul` decomposition.
|
||||
TPUs and GPUs have native hardware capable of small matrix multiplication akin to `matmul_small`, so to utilize this hardware when doing bigger matrix multiplications, we will apply the `block_matmul` decomposition.
|
||||
|
||||
+++ {"id": "a0ESFoX1ID0z"}
|
||||
|
||||
### Pipelining
|
||||
### Tiling and Pipelining
|
||||
|
||||
In [the previous guide](pipelining), we covered how pipelining in Pallas works. To make sure our compute units are always working and never waiting for memory transfers, we overlap the memory transfers for the next iteration of a kernel with the current one.
|
||||
In [the previous guide](pipelining), we covered how tiling up computations and pipelining in Pallas works. To make sure our compute units are always working and never stalled by memory transfers, we overlap the memory transfers for the next iteration of a kernel with the current one.
|
||||
|
||||
In Pallas, we specify that via `BlockSpec`s and a `grid`. Note that we already have a nested for loop in the block matrix multiplication algorithm. We can specify that in Pallas via a `grid`. The slices in the block matrix multiplication can also be specified via `BlockSpec`s.
|
||||
|
||||
@ -158,7 +158,9 @@ def matmul(
|
||||
in_specs=[pl.BlockSpec(lambda i, j, k: (i, k), (bm, bk)),
|
||||
pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn))],
|
||||
out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (bm, bn)),
|
||||
grid=(m // bm, n // bn, k // bk)
|
||||
grid=(m // bm, n // bn, k // bk),
|
||||
compiler_params=dict(mosaic=dict(
|
||||
dimension_semantics=("parallel", "parallel", "arbitrary"))),
|
||||
)(x, y)
|
||||
```
|
||||
|
||||
@ -180,7 +182,7 @@ Let's think about how to analyze matrix multiplication performance.
|
||||
|
||||
The number of FLOPs in a `(m, k) x (k, n)` matrix multiplication are (approximately) `2 * m * k * n`. (Technically it is `n * m * (2k - 1)` but for large enough `k` our approximation is sufficient.)
|
||||
|
||||
The minimum amount of memory usage for a matrix multiply (assuming float32) is the total size of the inputs (copying into VMEM) plus the size of the output (copying into HBM). Thus the minimum bandwidth usage is `(m * k + k * n + m * n) * 4 bytes/float32`.
|
||||
The minimum amount of memory usage for a matrix multiply (assuming float32) is the total size of the inputs (copying into VMEM) plus the size of the output (copying into HBM). Thus the minimum bandwidth usage is `(m * k + k * n + m * n) * 4 bytes/float32`. Memory usage can be greater if we re-read the inputs multiple times, which is often the case.
|
||||
|
||||
One observation is that the number of matmul FLOPs is cubic in its inputs whereas the minimum bandwidth usage is quadratic in its inputs. Intuitively, this means that FLOPs grow faster than bandwidth usage, meaning that the bigger our matmul is, the more compute we have relative to copying.
|
||||
|
||||
@ -200,9 +202,9 @@ print(matmul_membw(1024, 1024, 1024, jnp.float32))
|
||||
|
||||
+++ {"id": "agCtb2GMQazl"}
|
||||
|
||||
Now that we can calculate the amount of FLOPs and memory bandwidth usage of a matrix multiplication, let's now see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.
|
||||
Now that we can calculate the amount of FLOPs and (minimum) memory bandwidth usage of a matrix multiplication, let's now see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.
|
||||
|
||||
This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOPs of bf16 compute and 819 GBps of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the arithmetic intensity that the chip is capable of by taking their ratio (about 240 on TPU v5e).
|
||||
This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOPs of bf16 compute and 819 GB/s of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the minimum arithmetic intensity (ratio of FLOPs to memory bandwidth) needed to be compute bound in this analytical model on this chip by taking their ratio (about 240 FLOPs/byte on TPU v5e).
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: WUydNX2-K6Oy
|
||||
@ -325,9 +327,9 @@ np.testing.assert_array_equal(x @ y, matmul(x, y))
|
||||
|
||||
Our above analysis about FLOPS vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.
|
||||
|
||||
This means that we actually care about the FLOPS vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPS vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for matrix multiplication performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.
|
||||
This means that we actually care about the FLOPS vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPS vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.
|
||||
|
||||
The intuition should therefore be: make the blocks as big as possible! There are two main constraints:
|
||||
The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints:
|
||||
|
||||
1. VMEM usage: The bigger our blocks, the more VMEM we use. With large enough blocks, we will run out.
|
||||
2. Pipeline bubbles: The larger our blocks are relative to the matrix size, the fewer loop iterations we will have in our pipeline. This will make the size of the bubbles at the beginning and end of the pipeline larger relative to the total pipeline and this overhead can be nontrivial.
|
||||
@ -384,6 +386,7 @@ analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
|
||||
Bigger block sizes help a lot! We get pretty good utilization (80-90%) in the bigger matmuls, but the smallest matmul seems pretty hard to get good performance with.
|
||||
|
||||
Let's compare this with XLA's matmuls. We don't expect Pallas to do better than XLA because XLA is *very* good at generating matmuls but hopefully we are close.
|
||||
With more careful block size tuning (left as future work), we can also reach XLA performance.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: OpU7I7BNXQYg
|
||||
@ -410,7 +413,7 @@ Now that we have a basic matrix multiplication kernel, we can now try fusing ope
|
||||
|
||||
### Fused right-hand-side transpose
|
||||
|
||||
A common first thing to do is to fuse a transpose (e.g. instead of doing `x @ y` we do `x @ y.T`). On TPUs, the MXU supports matmul routines with the right hand side transposed natively, so there is no additional cost to fusing in a transpose. We can use this routine with `jax.lax.dot_general`.
|
||||
A common first thing to do is to fuse a transpose (e.g. instead of doing `x @ y` we do `x @ y.T`). Accelerators often supports a matrix multiplication routine that fuses an RHS transpose (for example, TPU v5e) so there is no additional cost to fusing in a transpose. We can use this routine with `jax.lax.dot_general`, which when natively supported will be more efficient than doing a transpose and a matmul separately.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 1_6S_QnMbHAQ
|
||||
@ -420,6 +423,11 @@ def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
|
||||
def _():
|
||||
acc_ref[...] = jnp.zeros_like(acc_ref)
|
||||
|
||||
# dot_general expects a data structure (contraction_dims, batch_dims),
|
||||
# where contraction_dims are the set of dimensions for LHS and RHS that will
|
||||
# be contracted (reduced) in the matmul; batch_dims, on the other hand, are
|
||||
# looped over. The remaining dimensions will be the input and output dimension
|
||||
# of the matmul.
|
||||
if transpose_rhs:
|
||||
dims = ((1,), (1,)), ((), ())
|
||||
else:
|
||||
@ -469,9 +477,12 @@ def matmul(
|
||||
|
||||
+++ {"id": "eSmPJHSchuGX"}
|
||||
|
||||
Note that we do a transpose inside of the `matmul` function (`y = y.swapaxes(0, 1)`). This is because transposes in XLA are *logical*, not physical. It informs XLA that the custom call would like `y` to be laid out differently before the matmul.
|
||||
We do a transpose inside of the `matmul` function (`y = y.swapaxes(0, 1)`). This is because inside of a JIT-ted JAX computation, dimension ordering is purely *logical*, not physical, so rearranging dimensions does not imply a
|
||||
physical layout difference. However, when we pass an array into a `pallas_call`, we do enforce a major-to-minor dimension ordering constraint. By transposing `y` inside of the `matmul` function, we are requesting that `y` be in a
|
||||
transposed layout `(n, k)` instead of the usual `(k, n)`. The user will still pass in the array in the (logical) `(n, k)` dimension, however.
|
||||
|
||||
To benchmark the transpose, we actually want `y` to be in the transposed layout beforehand (and we will untranspose it in the benchmark).
|
||||
Note: to benchmark the transpose, we actually want `y` to be in the physical transposed layout when we pass it into the kernel, so we don't measure relayout time. In the wrapper function, we will (logically) transpose it back to `(n, k)`
|
||||
before passing it into `matmul` because `matmul` expects a logical `(n, k)` dimension ordering.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: AcBMHhKLhkDp
|
||||
@ -583,6 +594,8 @@ def matmul(
|
||||
grid=(m // bm, n // bn, k // bk),
|
||||
),
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
|
||||
compiler_params=dict(mosaic=dict(
|
||||
dimension_semantics=("parallel", "parallel", "arbitrary"))),
|
||||
)(x, y)
|
||||
```
|
||||
|
||||
@ -637,7 +650,6 @@ The additional fused activation barely affects our utilization at all!
|
||||
In this guide, we covered how to write efficient matrix multiplications on TPU using Pallas. We discussed blocked matrix multiplication and pipelining, how to analyze the performance of a TPU matmul, and how to write an efficient `bf16` matrix multiplication. We concluded with templating the matrix multiplication to support a fused transpose and fused activation functions.
|
||||
|
||||
Exercises left to the reader:
|
||||
* Add Megacore support to the kernel
|
||||
* Add support for input fusions. Sometimes we want to fuse an operation into the inputs of the matmul. Try templating the matrix multiplication even more to support this.
|
||||
* Add support for `int8` matrix multiplication. TPU v5 supports native `int8` matrix multiplication at twice the FLOPs of `bf16`. Try adding support for that and see what utilization is possible.
|
||||
* Add backwards pass support for the `matmul` function. You can do this with `jax.custom_vjp`.
|
||||
|
Loading…
x
Reference in New Issue
Block a user