{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pallas Quickstart\n",
    "\n",
    "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
    "\n",
    "Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.\n",
    "Pallas allows you to use the same JAX functions and APIs but operates at a\n",
    "*lower* level of abstraction.\n",
    "\n",
    "Specifically, Pallas requires users to think about memory access and how to\n",
    "divide up computations across multiple compute units in a hardware accelerator.\n",
    "On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n",
    "\n",
    "Let's dive into some examples.\n",
    "\n",
    "> Note: Pallas is still an experimental API and your code may be broken by changes!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hello world in Pallas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import jax\n",
    "from jax.experimental import pallas as pl\n",
    "import jax.numpy as jnp\n",
    "import numpy as np"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll first write the \"hello world\" in Pallas, a kernel that adds two vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_vectors_kernel(x_ref, y_ref, o_ref):\n",
    "  x, y = x_ref[...], y_ref[...]\n",
    "  o_ref[...] = x + y"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**`Ref` types**\n",
    "\n",
    "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
    "it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
    "Instead, it takes in *`Ref`* objects as inputs. Note that it also doesn't produce\n",
    "outputs directly; instead, we are given an `o_ref`, which corresponds to the desired output.\n",
    "\n",
    "**Reading from `Ref`s**\n",
    "\n",
    "In the body, we first read from `x_ref` and `y_ref`, indicated by the `[...]`\n",
    "(the ellipsis means we are reading the whole `Ref`;\n",
    "alternatively, we could have used `x_ref[:]`).\n",
    "Reading from a `Ref` like this returns a `jax.Array`.\n",
    "\n",
    "**Writing to `Ref`s**\n",
    "\n",
    "We then write `x + y` to `o_ref`.\n",
    "Mutation has not historically been supported in JAX -- `jax.Array`s are immutable!\n",
    "`Ref`s are new (experimental) types that allow mutation under certain circumstances.\n",
    "We can interpret writing to a `Ref` as mutating its underlying buffer."
   ]
  },
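  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch (not run here, and assuming the same 8-element inputs as\n",
    "above), `Ref`s also support NumPy-style slicing, so the same addition could be\n",
    "written half a vector at a time:\n",
    "\n",
    "```python\n",
    "def add_vectors_in_halves_kernel(x_ref, y_ref, o_ref):\n",
    "  o_ref[:4] = x_ref[:4] + y_ref[:4]  # read and write the first half\n",
    "  o_ref[4:] = x_ref[4:] + y_ref[4:]  # read and write the second half\n",
    "```"
   ]
  },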
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So we've written what we call a \"kernel\", which we define as a program that will\n",
    "run as an atomic unit of execution on an accelerator,\n",
    "without any interaction with the host.\n",
    "How do we invoke it from a JAX computation?\n",
    "We use the `pallas_call` higher-order function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@jax.jit\n",
    "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n",
    "  return pl.pallas_call(add_vectors_kernel,\n",
    "                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
    "                        )(x, y)\n",
    "add_vectors(jnp.arange(8), jnp.arange(8))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pallas_call` lifts the Pallas kernel function into an operation that can be called\n",
    "as part of a larger JAX program. But, to do so, it needs a few more details.\n",
    "Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list\n",
    "thereof).\n",
    "`out_shape` determines the shape/dtype of `o_ref` in our `add_vectors_kernel`.\n",
    "\n",
    "`pallas_call` returns a function that takes in and returns `jax.Array`s."
   ]
  },
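  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the \"list thereof\" case (the names here are our own, for\n",
    "illustration): with two entries in `out_shape`, the kernel receives two output\n",
    "`Ref`s after its inputs, and `pallas_call` returns a pair of arrays.\n",
    "\n",
    "```python\n",
    "def add_and_mul_kernel(x_ref, y_ref, sum_ref, prod_ref):\n",
    "  sum_ref[...] = x_ref[...] + y_ref[...]\n",
    "  prod_ref[...] = x_ref[...] * y_ref[...]\n",
    "\n",
    "def add_and_mul(x: jax.Array, y: jax.Array):\n",
    "  out = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
    "  return pl.pallas_call(add_and_mul_kernel, out_shape=[out, out])(x, y)\n",
    "```"
   ]
  },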
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**What's actually happening here?**\n",
    "\n",
    "Thus far we've described how to think about Pallas kernels, but what we've actually\n",
    "accomplished is writing a function that's executed very close to the compute units.\n",
    "\n",
    "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n",
    "we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n",
    "(generally speaking, this is a costly operation!).\n",
    "We then use GPU vector compute to execute the addition, then copy the resulting value\n",
    "in SRAM back to HBM.\n",
    "\n",
    "On TPU, we do something slightly different. Before the kernel is ever executed,\n",
    "we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in\n",
    "SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register.\n",
    "We then use TPU vector compute to execute the addition, then copy the resulting\n",
    "value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.\n",
    "\n",
    "We are in the process of writing backend-specific Pallas guides. Coming soon!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pallas programming model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In our \"hello world\" example, we wrote a very simple kernel.\n",
    "It takes advantage of the fact that our 8-element arrays can comfortably fit inside\n",
    "the SRAM of hardware accelerators.\n",
    "In most real-world applications, this will not be the case!"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Part of writing Pallas kernels is thinking about how to take big arrays that\n",
    "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
    "that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
    "\n",
    "### Grids by example\n",
    "\n",
    "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
    "`BlockSpec`s to `pallas_call`.\n",
    "\n",
    "A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies\n",
    "an iteration space.\n",
    "For example, a grid `(4, 5)` would have 20 elements:\n",
    "`(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`.\n",
    "We run the kernel function once for each element, a style of single-program\n",
    "multiple-data (SPMD) programming.\n",
    "\n",
    "<center>\n",
    "\n",
    "A 2D grid\n",
    "</center>\n",
    "\n",
    "When we provide a `grid` to `pallas_call`, the kernel is executed as many times\n",
    "as `prod(grid)`. Each of these invocations is referred to as a \"program\".\n",
    "To access which program (i.e. which element of the grid) the kernel is currently\n",
    "executing, we use `program_id(axis=...)`.\n",
    "For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and\n",
    "`program_id(axis=1)` returns `2`.\n",
    "\n",
    "Here's an example kernel that uses a `grid` and `program_id`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iota_kernel(o_ref):\n",
    "  i = pl.program_id(0)\n",
    "  o_ref[i] = i"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now execute it using `pallas_call` with an additional `grid` argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def iota(size: int):\n",
    "  return pl.pallas_call(iota_kernel,\n",
    "                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
    "                        grid=(size,))()\n",
    "iota(8)"
   ]
  },
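  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged sketch in the same style (not run here; `iota2d_kernel` is our own\n",
    "name), a 2D grid works analogously: each of the `2 * 4 = 8` programs writes one\n",
    "element, indexed by both program ids.\n",
    "\n",
    "```python\n",
    "def iota2d_kernel(o_ref):\n",
    "  i, j = pl.program_id(0), pl.program_id(1)\n",
    "  o_ref[i, j] = i * pl.num_programs(1) + j  # row-major counter\n",
    "\n",
    "def iota2d():\n",
    "  return pl.pallas_call(iota2d_kernel,\n",
    "                        out_shape=jax.ShapeDtypeStruct((2, 4), jnp.int32),\n",
    "                        grid=(2, 4))()\n",
    "```"
   ]
  },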
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On GPUs, each program is executed in parallel on separate threads.\n",
    "Thus, we need to think about race conditions on writes to HBM.\n",
    "A reasonable approach is to write our kernels in such a way that different\n",
    "programs write to disjoint places in HBM to avoid these parallel writes.\n",
    "On the other hand, parallelizing the computation is how we can execute\n",
    "operations like matrix multiplications really quickly.\n",
    "\n",
    "On TPUs, programs are executed in a combination of parallel and sequential\n",
    "fashion (depending on the architecture), so there are slightly different considerations.\n",
    "\n",
    "You can read more details at {ref}`pallas_grid`."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Block specs by example"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `grid` and `program_id` in mind, Pallas provides an abstraction that\n",
    "takes care of some common indexing patterns seen in a lot of kernels.\n",
    "To build intuition, let's try to implement a matrix multiplication.\n",
    "\n",
    "A simple strategy for implementing a matrix multiplication in Pallas is to\n",
    "implement it recursively.\n",
    "We know our underlying hardware has support for small matrix multiplications\n",
    "(using GPU and TPU tensorcores), so we just express a big matrix multiplication\n",
    "in terms of smaller ones.\n",
    "\n",
    "Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$.\n",
    "We first express $X$ and $Y$ as block matrices. $X$ will have \"row\" blocks\n",
    "and $Y$ will have \"column\" blocks.\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "X = \\begin{bmatrix}\n",
    "X_0 \\\\ X_1\n",
    "\\end{bmatrix}\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "Y = \\begin{bmatrix}\n",
    "Y_0 & Y_1\n",
    "\\end{bmatrix}\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "Z &=\n",
    "\\begin{bmatrix}\n",
    "X_0 \\\\ X_1\n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "Y_0 & Y_1\n",
    "\\end{bmatrix}\n",
    "\\\\\n",
    "&=\n",
    "\\begin{bmatrix}\n",
    "X_0 Y_0 & X_0 Y_1 \\\\\n",
    "X_1 Y_0 & X_1 Y_1\n",
    "\\end{bmatrix}\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "Our strategy is that because $Z$ is also a block matrix, we can assign each of\n",
    "the programs in our Pallas kernel one of the output blocks.\n",
    "Computing each output block corresponds to doing a smaller matrix multiply\n",
    "between a \"row\" block of $X$ and a \"column\" block of $Y$."
   ]
  },
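  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, illustrative NumPy check of the block decomposition above (the\n",
    "shapes here are chosen arbitrarily): stitching together the four small products\n",
    "$X_i Y_j$ reproduces $XY$.\n",
    "\n",
    "```python\n",
    "X = np.arange(16.0).reshape(4, 4)\n",
    "Y = np.arange(16.0).reshape(4, 4)\n",
    "X0, X1 = X[:2], X[2:]        # \"row\" blocks of X\n",
    "Y0, Y1 = Y[:, :2], Y[:, 2:]  # \"column\" blocks of Y\n",
    "Z = np.block([[X0 @ Y0, X0 @ Y1],\n",
    "              [X1 @ Y0, X1 @ Y1]])\n",
    "np.testing.assert_allclose(Z, X @ Y)\n",
    "```"
   ]
  },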
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block\n",
    "shape for each input and output, and an \"index map\" function that maps a\n",
    "set of program indices to a block index.\n",
    "\n",
    "<center>\n",
    "\n",
    "A visualization of a `BlockSpec`\n",
    "\n",
    "</center>\n",
    "\n",
    "For a concrete example, let's say we'd like to multiply two `(1024, 1024)`\n",
    "matrices `x` and `y` together to produce `z`, and would like to parallelize\n",
    "the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks, where\n",
    "each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.\n",
    "To express this, we'd first use a `(2, 2)` grid (one block for each program).\n",
    "\n",
    "For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this\n",
    "carves `x` up into \"row\" blocks.\n",
    "To see this, note how both program instances\n",
    "`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n",
    "For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`.\n",
    "Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.\n",
    "\n",
    "These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n",
    "\n",
    "For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.\n",
    "\n",
    "Under the hood, `pallas_call` will automatically carve up your inputs and\n",
    "outputs into `Ref`s for each block that will be passed into the kernel."
   ]
  },
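  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can evaluate these index maps by hand in plain Python (illustrative only)\n",
    "to see which block each of the four programs picks:\n",
    "\n",
    "```python\n",
    "x_index_map = lambda i, j: (i, 0)\n",
    "y_index_map = lambda i, j: (0, j)\n",
    "z_index_map = lambda i, j: (i, j)\n",
    "for i in range(2):\n",
    "  for j in range(2):\n",
    "    print((i, j), '->', x_index_map(i, j), y_index_map(i, j), z_index_map(i, j))\n",
    "```"
   ]
  },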
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_kernel(x_ref, y_ref, z_ref):\n",
    "  z_ref[...] = x_ref[...] @ y_ref[...]\n",
    "\n",
    "def matmul(x: jax.Array, y: jax.Array):\n",
    "  return pl.pallas_call(\n",
    "    matmul_kernel,\n",
    "    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n",
    "    grid=(2, 2),\n",
    "    in_specs=[\n",
    "      pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n",
    "      pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n",
    "    ],\n",
    "    out_specs=pl.BlockSpec(\n",
    "      lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
    "    )\n",
    "  )(x, y)\n",
    "k1, k2 = jax.random.split(jax.random.key(0))\n",
    "x = jax.random.normal(k1, (1024, 1024))\n",
    "y = jax.random.normal(k2, (1024, 1024))\n",
    "z = matmul(x, y)\n",
    "np.testing.assert_allclose(z, x @ y)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that this is a very naive implementation of a matrix multiplication, but\n",
    "consider it a starting point for various types of optimizations.\n",
    "Let's add an additional feature to our matrix multiply: a fused activation.\n",
    "It's actually really easy! Just pass the activation function into the kernel as\n",
    "an extra parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_kernel(x_ref, y_ref, z_ref, *, activation):\n",
    "  z_ref[...] = activation(x_ref[...] @ y_ref[...])\n",
    "\n",
    "def matmul(x: jax.Array, y: jax.Array, *, activation):\n",
    "  return pl.pallas_call(\n",
    "    partial(matmul_kernel, activation=activation),\n",
    "    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n",
    "    grid=(2, 2),\n",
    "    in_specs=[\n",
    "      pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n",
    "      pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n",
    "    ],\n",
    "    out_specs=pl.BlockSpec(\n",
    "      lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
    "    ),\n",
    "  )(x, y)\n",
    "k1, k2 = jax.random.split(jax.random.key(0))\n",
    "x = jax.random.normal(k1, (1024, 1024))\n",
    "y = jax.random.normal(k2, (1024, 1024))\n",
    "z = matmul(x, y, activation=jax.nn.relu)\n",
    "np.testing.assert_allclose(z, jax.nn.relu(x @ y))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`!\n",
    "To turn this matrix multiplication into a batched version, we just need to `vmap` it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "k1, k2 = jax.random.split(jax.random.key(0))\n",
    "x = jax.random.normal(k1, (4, 1024, 1024))\n",
    "y = jax.random.normal(k2, (4, 1024, 1024))\n",
    "z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)\n",
    "np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}