# Pallas Design

<!--* freshness: { reviewed: '2024-04-15' } *-->

In this document, we explain the initial Pallas design.
This is a snapshot of some of the earlier design decisions made,
and Pallas's specific APIs might have changed since then.

## Introduction

JAX is being used for a diverse set of workloads, from large scale machine
learning to scientific computing.
JAX’s success story is as much a success story for XLA,
the primary compiler that JAX targets – XLA compiles JAX
programs for accelerators and has enabled JAX to scale to the largest ML
models.
JAX describes logical computations in XLA’s representation, HLO.
HLO describes how computations happen logically but not physically.
Given a logical HLO computation, XLA decides how that computation is to be
executed physically.
For a wide variety of ML applications, XLA does a good
job of compiling user programs, but inevitably some users hit XLA's
limitations.
In these cases, we need to provide an “escape hatch” to allow
experts to write hand-tuned kernels that outperform XLA at that
point in time.
Furthermore, advances in ML systems research take some time to be
incorporated into XLA, and users often want to run ahead with them.
Over time, the compiler can incorporate the optimizations that were proven
out experimentally through hand-tuned kernels.

XLA does offer the `CustomCall` mechanism as an escape hatch, but it
requires users to write C++ and on GPU it requires users to learn the
CUDA programming model.
The CUDA programming model is arguably too low-level for many machine
learning GPU kernels, like matrix multiplication,
and even expert users will have trouble using CUDA to implement efficient
matrix multiplication or multi-headed attention.
Not only this, JAX users are usually familiar with Python and NumPy-style
array programming which doesn’t involve writing any C++ or thinking about
GPU parallelism.
All popular machine learning frameworks share this
idea: manipulating (usually) arrays with high level operations
like `matmul` or `convolution`.
Unfortunately, this means implementing a custom operation via `CustomCall`
is a big investment, involving potentially learning C++ and/or GPU
programming.

[Triton](https://triton-lang.org/main/index.html), a GPU compiler built
and maintained by OpenAI, has taken the ML compiler world by storm.
Triton offers the best of both worlds: an array-based programming model
for GPU kernels. Triton is the primary code generation route
for `torch.compile` in PyTorch 2.0, via the Torch Inductor library.
Triton actively hides some aspects of GPU programming in the name of a
more accessible programming model that can be used from Python and to
generate optimized code from a higher-level representation.
While GPUs are more flexible than what Triton offers, in the ML domain,
Triton seems to be expressive enough for many applications.

In this document, we describe Pallas, an extension to JAX that enables
kernel programming for both GPUs and TPUs using a Triton-like model.
A JAX-based kernel language offers several advantages:
* Although Triton exposes a TPU-like programming model to users,
  i.e. writing programs for tiles of arrays in L1-cache, it is specialized
  enough to GPU that we cannot directly compile Triton for TPU.
  For example, Triton offers atomic operations specifically meant to
  handle parallel writes that don’t necessarily make sense on TPU.
  A higher level front end can abstract away details of the platform
  while surfacing just that tile-based programming model.
  The kernels will thus be portable across different hardware platforms.
* JAX as a tracing-based frontend for numerical computing is both
  mature and well-used.
  By embedding the kernel programming language in JAX itself,
  we can re-use JAX’s tracing infrastructure and provide a
  NumPy-like frontend that’s already familiar to users.
* JAX transformations are key to its success, allowing users to
  express simple programs but transform them to achieve complex
  functionality.
  We can leverage the same transformations (vmap, jvp, etc.) to
  transform user-written kernels.

The open question is: is JAX a good fit for a kernel language at all?
We think so.
Triton demonstrates that an array programming language can be
practical for writing GPU kernels and JAX is just that.
JAX has also proven to be a flexible front-end for compilers and
for program transformations.

We describe Pallas as follows: we first describe the ways in which
we extend JAX to support writing custom kernels.
We then show how we can lower Pallas to both Triton and Mosaic.
We conclude by describing existing and potential ways to transform
Pallas kernels via JAX transformations.

<center>

![Pallas lowering path](../_static/pallas/pallas_flow.png)
Visualization of Pallas lowering paths

</center>

## Pallas: Extending JAX for kernels

The key point we’d like to make is that Pallas is just JAX, with some
extensions:
1. Users now use reference types called `Ref`s in their JAX code.
   This gives users more precise control over memory access, and
   layout in JAX will more closely resemble physical layout.
2. Users write their JAX programs using a subset of JAX primitives,
   along with a set of Pallas-specific primitives.
3. Users embed their Pallas kernels in an outer JAX program via a
   special `pallas_call` higher-order function that executes the
   kernel in a map. It is analogous to `pmap` or `shard_map`,
   except with references to shared memory.

We’ll go over these three extensions one at a time, by example.

Note that these APIs are still experimental and subject to change.

### Reference types

Let’s look at an example Pallas program for adding two vectors:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y

x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
```

Unlike a regular JAX program, `add_kernel` does not receive immutable
array arguments.
Instead, it’s provided with references that can be read from and
updated in-place using NumPy-like syntax.
`Ref`s are not a Pallas-specific concept – they were introduced to
JAX to represent stateful computations.
However, we can leverage them when writing kernels that operate on
mutable memory too.

Pallas kernels not only receive `Ref`s corresponding to the inputs
to the kernel, but also receive `Ref`s for the outputs
(specified in `pallas_call` via `out_shape`).
`Ref`s are special types that cannot be passed into the usual set of
JAX primitives without being read from first.
When you read from a `Ref` you get a JAX `Array` type out, and you
must write an `Array` into a `Ref`.

#### Reading from/writing into Refs

Reading from a `Ref` corresponds to loading an array into the
lowest level of the memory hierarchy (L1-cache on GPU and vector
registers on TPU). Writing into a `Ref` is analogous.

```python
def f(x_ref, o_ref):
  # Using vanilla Python indexing
  x = x_ref[0, 2:5, :]
  # Or via NumPy advanced int indexing
  o_ref[jnp.arange(3), :] = x

# Note that in order to use NumPy advanced int indexing, you need to
# broadcast the indices against each other into the desired
# multidimensional shape:
def f(x_ref):
  # Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
  x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]
```

Writing to `Ref`s can be done via analogous `__setitem__`-style
indexing.

Other forms of indexing (for example, dynamic slicing) can be done
via `pallas.load` and `pallas.store`, new JAX primitives designed to
make loading from/storing into memory easier.
We’ll discuss these new primitives later.

### Extending JAX with new Pallas primitives

Because JAX was designed with HLO in mind, the set of JAX primitives
closely mirrors the set of HLO operations.
Targeting a new compiler (e.g. Triton or Mosaic) means we might need
to supplement JAX’s primitives with new ones specific to the new
compiler.
At the same time, we may not be able to lower all JAX primitives,
so we need to restrict Pallas kernels to a subset.

Because Pallas was initially designed with Triton in mind,
we offer a set of new primitives targeting the Triton programming model.
As we’ll show later, we can lower these primitives to Mosaic as well.

#### `pallas.load` and `pallas.store`

`pallas.load` and `pallas.store` are primitives that allow loading
from memory and storing into memory.
Unlike `__getitem__` and `__setitem__`, they are more flexible at the
cost of being more verbose.
Specifically, you can use the `pallas.dynamic_slice` (`pallas.ds` for
short) construct (which should maybe be upstreamed into JAX to be
used with `Ref` `__getitem__` and `__setitem__`).

```python
def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
  # Using integer indexing automatically broadcasts
  x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
  # You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
  pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)
```

`pallas.load` and `pallas.store` also support masking via the `mask`
argument.

```python
def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  idx = jnp.arange(8)
  mask = idx < 5
  x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))
```

Masking is important when doing out-of-bounds loads/stores.
The operational semantics of masking can be compiler-determined
(if we understand the documentation properly, Triton avoids the read
from/write to memory if it’s masked).

#### `pallas.program_id` and `pallas.num_programs`

As we’ll soon see, we’ll be executing the same Pallas kernels many
times (either in parallel or in a pipeline depending on the backend).
These new primitives tell us “where” we are in the execution of the
kernel.

`pallas.program_id` takes in an axis argument, which tells us which
index in an axis of a multidimensional grid this kernel is currently
executing in (analogous to `threadIdx` from CUDA programming or
`lax.axis_index` in `jax.pmap`).
Note that we are currently borrowing the “program” terminology from
Triton and in the future we might want to change it to something more
familiar to JAX users.

```python
def f(x_ref, o_ref):
  i = pl.program_id(axis=0)  # execution index in the first axis of the grid
  o_ref[i] = jnp.exp(x_ref[i])
```

`pallas.num_programs` also takes in an axis and returns the grid size
for that axis.

Note that while `program_id` and `num_programs` are Triton-specific
terminology, they are easily generalized to make sense on TPU as well.
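
As a small, hypothetical sketch (the kernel and shapes below are ours, not
part of the original design), a kernel can combine the two primitives to
record its position in the grid:

```python
def reversed_id_kernel(o_ref):
  i = pl.program_id(axis=0)
  n = pl.num_programs(axis=0)
  # Each of the `n` grid steps writes one element: its "reversed" grid position.
  o_ref[i] = n - 1 - i

out = pl.pallas_call(
    reversed_id_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    grid=(8,))()
# out == [7, 6, 5, 4, 3, 2, 1, 0]
```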
#### Using a subset of JAX primitives in Pallas

Because we’re writing kernels, not high-level HLO programs, some JAX
primitives may not be representable efficiently in our underlying
substrate.
However, we know we can support most elementwise operations,
simple dot products, and JAX control flow.
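
As an illustrative sketch of the kind of kernel we expect to lower (the
kernel below is our own example, assuming 2-D input blocks):

```python
def supported_ops_kernel(x_ref, y_ref, o_ref):
  # Elementwise operations and simple dot products are expected to lower.
  z = jnp.maximum(x_ref[...], 0.0) @ y_ref[...]
  # So is structured control flow such as `lax.fori_loop`.
  o_ref[...] = jax.lax.fori_loop(0, 4, lambda i, acc: acc + z, jnp.zeros_like(z))
```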

While we haven’t yet mapped out exactly all the JAX primitives that
we can support in Pallas kernels, we can certainly identify some that
are not easy to lower or are unlikely to be useful:
* `conv_general` - convolution usually isn’t offered as a primitive in
  the underlying hardware.
* `gather/scatter` - the underlying compiler may not support
  noncontiguous memory reads and writes.

### Executing Pallas kernels with `pallas_call`

Now that we’ve written our Pallas kernels (a.k.a. JAX with `Ref`s and
the extra Pallas primitives), how do we execute them on a GPU or TPU?
We use `pallas_call`, a higher order function (akin to `jax.jit` and
`jax.pmap`) that executes the kernel.

The signature of `pallas_call` is as follows:

```python
def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...
```

When we provide a kernel to `pallas_call` we provide additional
information. The first is `out_shape`, which tells the kernel what the
outputs look like (`pallas_call` will pass a `Ref` corresponding to
these into the kernel to be written to).
The rest of the arguments (`in_specs`, `out_specs`, and `grid`)
describe how the kernel will be scheduled on the accelerator.

The (rough) semantics for `pallas_call` are as follows:

```python
def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shape)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices):  # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
                      zip(args, in_specs)]
      local_outputs = [out_spec.transform(arg, indices) for arg, out_spec in
                       zip(outputs, out_specs)]
      kernel(*local_inputs, *local_outputs)  # writes to outputs
    return outputs
  return execute
```

Specifically, `pallas_call` will “loop” over the grid iteration space,
applying a transformation to the inputs and outputs specified via
the `in_specs` and `out_specs`.
In each iteration, the kernel will be called on the transformed
inputs and outputs. Note that the “loop” over the iteration space
could be executed in parallel (e.g. on GPU).
`pallas_call` also provides no guarantees on the order of loop
iterations over the iteration space, just that every member of the
iteration space will be looped over.
Compilers like Triton and Mosaic will have more specific operational
semantics associated with the grid.

#### Transformation functions

The `in_specs` and `out_specs` arguments to `pallas_call` allow
inputs and outputs to be transformed in some way.
The two options that Pallas offers right now are an identity
transformation (where inputs and outputs are left unchanged),
and `BlockSpec`s, which take fixed-size slices of `Ref`s determined
by the loop index.

A `BlockSpec` takes an `index_map` function and a `block_shape`.
Logically, it takes an array and slices it along each axis into
`block_shape`-sized blocks.
The `index_map` function takes loop indices (from the grid index set)
and maps them to block indices.
The transform function converts `Ref`s into logical views of the
`Ref` at the corresponding block.
When we specify `None` for an entry in `block_shape`,
that corresponds to “mapping” over that dimension,
removing it from the block within the kernel.

```python
class BlockSpec:
  index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
  block_shape: Tuple[Optional[int], ...]

  def transform(self, ref, *loop_indices):
    block_indices = self.index_map(*loop_indices)
    # Returns a view of `ref` starting at `block_indices` of shape self.block_shape
    ...
```
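
For example, here is a hypothetical sketch (our own, using the `pallas_call`
machinery described above) of how a `None` entry in `block_shape` maps over
the leading dimension of a `(4, 8)` input to compute per-row sums:

```python
def row_sum_kernel(x_ref, o_ref):
  # x_ref is an (8,)-shaped block: the input's leading dimension was
  # "mapped over" via the `None` entry in block_shape and removed here.
  o_ref[0] = jnp.sum(x_ref[:])

x = jnp.ones((4, 8), jnp.float32)
row_sums = pl.pallas_call(
    row_sum_kernel,
    out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
    in_specs=[pl.BlockSpec((None, 8), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((1,), lambda i: i),
    grid=(4,))(x)
```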

We could also imagine other `Spec`s that are used with `pallas_call`,
for example a `Spec` that corresponds to overlapping windows to, say,
implement convolutions.

### Immediate benefits of Pallas as a front-end
By offering a JAX front-end for kernel writing, we can immediately
reap some benefits.

#### More flexible front end

The first is that JAX users are already accustomed to the benefits
(and limitations) of programming with JAX and its tracing-based
transformations.
This means users can use closures and other familiar Python constructs
when writing Pallas kernels.
This is unlike the existing AST-parsing-based Triton front end or the
MLIR builders for Mosaic.
For example, this makes Pallas far more amenable to templating than
Triton.

See this example of how we can use higher-order functions in Python
to template a kernel.

```python
def make_kernel(eltwise_kernel):
  def add(x_ref, y_ref, o_ref):
    x = pl.load(x_ref, ())
    y = pl.load(y_ref, ())
    pl.store(o_ref, (), eltwise_kernel(x + y))
  return add

kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)

out_shape = jax.ShapeDtypeStruct((), jnp.float32)
pl.pallas_call(kernel1, out_shape=out_shape, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=out_shape, grid=1)(1., 1.)
```

#### Emulation mode

By representing kernels as programs with JAX primitives and some new
Pallas primitives, we can also lower Pallas programs to StableHLO
directly and compile/execute them with XLA.
Specifically, a `pallas_call` can be implemented as a `lax.scan` over
the grid.
This enables us to develop GPU or TPU kernels on any XLA-supported
platform (even CPU!) and debug them using JAX/XLA debugging tools
(like `jax.debug.print`).
We can also use the more reliable and better tested XLA numerics to
verify the correctness of the Triton and Mosaic compilers.
One could also imagine perturbing the `scan` ordering to simulate the
parallel reads and writes that happen on GPU.
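
As a sketch of what this looks like from the user's side (assuming an
`interpret` flag on `pallas_call` that selects this emulation path), we can
run the earlier `add_kernel` without a GPU or TPU:

```python
# Hypothetical usage: run the earlier add_kernel in emulation mode, e.g. on CPU.
add_emulated = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    interpret=True)
add_emulated(jnp.arange(8), jnp.arange(8, 16))
```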
### Examples

#### `add`

We modify our `add_kernel` example to operate over (2,)-sized blocks
using `BlockSpec`s.

```python
def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y

x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    in_specs=[
        pl.BlockSpec((2,), lambda i: i),
        pl.BlockSpec((2,), lambda i: i)
    ],
    out_specs=pl.BlockSpec((2,), lambda i: i),
    grid=(4,))
add(x, y)
```

#### Templated matmul

In this example, we compute tiles of the output by doing an unrolled
accumulation over blocks of rows and columns from our input arrays.
We inline an activation function into the body of the kernel using a
higher order function so we can emit a fused kernel.

```python
from functools import partial

def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
  acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
  for k in range(x_ref.shape[1] // block_k):
    x = x_ref[:, k*block_k:(k+1)*block_k]
    y = y_ref[k*block_k:(k+1)*block_k, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)

x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128

@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
  block_m, block_n, block_k = block_shape
  fused_matmul = pl.pallas_call(
      partial(matmul_kernel, block_k=block_k, activation=activation),
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
      in_specs=[
          pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
          pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
      ],
      out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
      grid=(4, 4),
  )
  return fused_matmul(x, y)

z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)
```

### Lowering Pallas

After users express their Pallas kernels, we lower them to different
representations depending on the target backend.
On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to
Mosaic.

#### Lowering Pallas to Triton for GPU

Lowering Pallas to Triton is easy because Pallas was designed with
Triton as a target language in mind.
The main difference between Pallas and Triton is that Triton doesn’t
have a notion of `BlockSpec`s and uses pointers when doing
memory loads and stores as opposed to indices.

Triton supports pointers as an array element type in its language
and in Triton you can load from and store to arrays of pointers.
In Pallas, given a `(4, 5)`-shaped `Ref` `x_ref`, an expression like
`x_ref[3, 2]` needs to be lowered to computing a Triton pointer at the
appropriate row-major offset into `x_ref` (that is, computing
`5 * 3 + 2 * 1`).
Similarly, when we lower slices to Triton, e.g. `x_ref[3, :]`, we need
to produce an array of pointers at offsets `5 * 3 + jnp.arange(5)`.
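
To spell out the pattern (illustrative only; `row_major_offset` is a
hypothetical helper, not part of Pallas or Triton):

```python
def row_major_offset(i, j, row_stride=5, col_stride=1):
  # Offset of element (i, j) in a row-major (4, 5)-shaped buffer.
  return i * row_stride + j * col_stride

row_major_offset(3, 2)              # x_ref[3, 2] -> base_ptr + 17
row_major_offset(3, jnp.arange(5))  # x_ref[3, :] -> base_ptr + (15 + arange(5))
```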

Other than that, lowering to Triton is fairly straightforward.
JAX dot products can be lowered to Triton dot products and JAX unary
primitives are lowered to their Triton equivalents.
Triton’s atomic operations are lowered via new Pallas atomic
primitives.

#### Lowering Pallas to Mosaic for TPU

Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be
compiled for TPU.
Pallas can be lowered to Mosaic via translating JAX primitives to
MLIR (mostly the `vector` and `arith` dialects).
The `BlockSpec`s can be converted into pipeline schedules
(i.e. the `transform_func`s in Mosaic).

### Transforming Pallas

A natural question is: how do JAX transformations interact with Pallas
kernels?
There are two main ways: transformations inside Pallas kernels and
transformations outside Pallas kernels.

Transformations inside Pallas kernels should actually “just work”,
so long as we are able to lower the transformed code.
For example, we could use `jax.grad(jnp.sin)(...)` inside of a Pallas
kernel because we can lower a `cos` to both Triton and Mosaic.
However, we might not be able to lower a `jax.vmap(lax.dynamic_slice)`
because it could turn into a gather that we cannot lower.
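
As a minimal, hypothetical sketch of the first case (our own example,
assuming a 1-D block):

```python
def sin_grad_kernel(x_ref, o_ref):
  x = x_ref[...]
  # Differentiating inside the kernel traces to a `cos`, which we can lower.
  o_ref[...] = jax.vmap(jax.grad(jnp.sin))(x)
```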

Transformations of Pallas kernels from the outer JAX program are
perhaps the more interesting case. How do we handle things like
`vmap(pallas_call)` and `grad(pallas_call)`?

#### `vmap-of-pallas_call`

`vmap` automatically vectorizes JAX programs. While kernel writers might
want precise control over how a batched kernel will behave differently
from its unbatched variant, we can offer a reasonable default `vmap`
rule for `pallas_call` while offering the `jax.custom_vmap`
customization mechanism. When `pallas_call` is `vmap`-ed, we augment
the `pallas_call` to have an extra grid dimension corresponding to the
new batch dimension and transform the `BlockSpec`s to handle indexing
along that dimension.
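
For example (a hypothetical sketch reusing the blocked `add` `pallas_call`
from the examples above):

```python
# Batch the (8,)-sized blocked add over a new leading dimension of size 4.
xs = jnp.arange(32, dtype=jnp.int32).reshape(4, 8)
ys = jnp.arange(32, 64, dtype=jnp.int32).reshape(4, 8)
jax.vmap(add)(xs, ys)  # conceptually, a pallas_call over a (4, 4) grid rather than (4,)
```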
#### `grad-of-pallas_call`

`grad` of `pallas_call` enables automatic differentiation of kernels.
`jax.grad` breaks down into applications of three distinct transforms:
`jvp`, `partial_eval` and `transpose`.
In principle, we can re-use most of JAX’s infrastructure when
implementing these rules for `pallas_call` (since it behaves much like
existing JAX higher order primitives).

However, automatic differentiation of kernels can result in a
performance hit due to how memory access is transposed.
If we write a GPU kernel with overlapping-and-parallel reads and
disjoint-but-parallel writes, we automatically transpose it into a
kernel that has overlapping-but-parallel writes (which are slow when
done atomically) and disjoint-and-parallel reads.
To emit a kernel that better uses parallelism with shared memory,
we would need to reorder loops and change how the kernel is vectorized.
Unfortunately, we do not have a program representation amenable to
that in Pallas.
A potential direction for automatically differentiating kernels
efficiently is to explore a different representation, perhaps one
like that in Dex.
We could also look at how Enzyme approaches this problem.
However, AD of Pallas kernels may still be useful for a class of
kernels that do transpose efficiently (for example elementwise
kernels).

In general, though, `jax.custom_vjp` is a viable escape hatch to
express Pallas kernels that work with `jax.grad`.
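
A rough sketch of that escape hatch (our own hypothetical example, not part
of the original design) pairs a hand-written forward kernel with a
hand-written backward kernel:

```python
def sin_kernel(x_ref, o_ref):
  o_ref[...] = jnp.sin(x_ref[...])

def sin_bwd_kernel(x_ref, g_ref, o_ref):
  o_ref[...] = jnp.cos(x_ref[...]) * g_ref[...]

@jax.custom_vjp
def pallas_sin(x):
  return pl.pallas_call(
      sin_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

def pallas_sin_fwd(x):
  return pallas_sin(x), x  # save x as the residual for the backward pass

def pallas_sin_bwd(x, g):
  ct = pl.pallas_call(
      sin_bwd_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, g)
  return (ct,)

pallas_sin.defvjp(pallas_sin_fwd, pallas_sin_bwd)

jax.grad(lambda x: pallas_sin(x).sum())(jnp.ones(8, jnp.float32))
```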
#### Other transformations

We could imagine other JAX transformations applying to Pallas kernels
that we haven’t explicitly explored yet.
For example, `checkify` is a JAX transformation that does functional
error handling.
We could imagine using `checkify` with `pallas_call` to allow plumbing
out error codes from GPU kernels that indicate whether out-of-bounds
accesses or NaNs were produced.

Another potential transformation to integrate with is
`custom_partitioning`, to enable automatically partitionable kernels to
be used with `pjit`.