{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Pallas Quickstart\n", "\n", "Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.\n", "Pallas lets you use the same JAX functions and APIs, but at a\n", "*lower* level of abstraction.\n", "\n", "Specifically, Pallas requires users to think about memory access and how to\n", "divide up computations across multiple compute units in a hardware accelerator.\n", "On GPUs, Pallas lowers to Triton, and on TPUs, it lowers to Mosaic.\n", "\n", "Let's dive into some examples.\n", "\n", "> Note: Pallas is still an experimental API, and changes to it may break your code!" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Hello world in Pallas" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "\n", "import jax\n", "from jax.experimental import pallas as pl\n", "import jax.numpy as jnp\n", "import numpy as np" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We'll first write the \"hello world\" in Pallas, a kernel that adds two vectors." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def add_vectors_kernel(x_ref, y_ref, o_ref):\n", " x, y = x_ref[...], y_ref[...]\n", " o_ref[...] = x + y" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "**`Ref` types**\n", "\n", "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n", "it does not take `jax.Array`s as inputs and doesn't return any values.\n", "Instead, it takes *`Ref`* objects as inputs. 
Note that we also don't have any outputs,\n", "but we are given an `o_ref`, which corresponds to the desired output.\n", "\n", "**Reading from `Ref`s**\n", "\n", "In the body, we first read from `x_ref` and `y_ref`, indicated by the `[...]`\n", "(the ellipsis means we are reading the whole `Ref`;\n", "alternatively, we could also have used `x_ref[:]`).\n", "Reading from a `Ref` like this returns a `jax.Array`.\n", "\n", "**Writing to `Ref`s**\n", "\n", "We then write `x + y` to `o_ref`.\n", "Mutation has not historically been supported in JAX: `jax.Array`s are immutable!\n", "`Ref`s are new (experimental) types that allow mutation under certain circumstances.\n", "We can interpret writing to a `Ref` as mutating its underlying buffer." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "So we've written what we call a \"kernel\", which we define as a program that will\n", "run as an atomic unit of execution on an accelerator,\n", "without any interaction with the host.\n", "How do we invoke it from a JAX computation?\n", "We use the `pallas_call` higher-order function." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@jax.jit\n", "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n", " return pl.pallas_call(add_vectors_kernel,\n", " out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", " )(x, y)\n", "add_vectors(jnp.arange(8), jnp.arange(8))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "`pallas_call` lifts the Pallas kernel function into an operation that can be called\n", "as part of a larger JAX program. 
To do so, however, it needs a few more details.\n", "Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list\n", "thereof).\n", "`out_shape` determines the shape/dtype of `o_ref` in our `add_vectors_kernel`.\n", "\n", "`pallas_call` returns a function that takes and returns `jax.Array`s." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "**What's actually happening here?**\n", "\n", "Thus far we've described how to think about Pallas kernels, but what we've actually\n", "done is write a function that executes very close to the compute units.\n", "\n", "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM), and when\n", "we do `x_ref[...]`, we are copying the value from HBM into static RAM (SRAM)\n", "(generally speaking, this is a costly operation!).\n", "We then use GPU vector compute to execute the addition, then copy the resulting value\n", "in SRAM back to HBM.\n", "\n", "On TPU, we do something slightly different. Before the kernel is ever executed,\n", "we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in\n", "SRAM, and when we do `x_ref[...]`, we are copying the value from SRAM into a register.\n", "We then use TPU vector compute to execute the addition, then copy the resulting\n", "value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.\n", "\n", "We are in the process of writing backend-specific Pallas guides. Coming soon!" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Pallas programming model" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In our \"hello world\" example, we wrote a very simple kernel.\n", "It takes advantage of the fact that our 8-element arrays can comfortably fit inside\n", "the SRAM of hardware accelerators.\n", "In most real-world applications, this will not be the case!" 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Part of writing Pallas kernels is thinking about how to take big arrays that\n", "live in high-bandwidth memory (HBM, a type of DRAM) and express computations\n", "that operate on \"blocks\" of those arrays that can fit in SRAM.\n", "\n", "### Grids by example\n", "\n", "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n", "`BlockSpec`s to `pallas_call`.\n", "\n", "A `grid` is a tuple of integers (e.g., `()`, `(2, 3, 4)`, or `(8,)`) that specifies\n", "an iteration space.\n", "For example, a grid `(4, 5)` would have 20 elements:\n", "`(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`.\n", "We run the kernel function once for each element, a style of single-program,\n", "multiple-data (SPMD) programming.\n", "\n", "