[pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, which I have now replaced with a reference to the separate grid and BlockSpec documentation. Grids and BlockSpecs are also documented in quickstart.md, which I mostly left alone because it was good enough for a simple example. I have also attempted to add a few docstrings.
commit bfdf8f4bd3 (parent 945b1c3b8a)
docs/jax.experimental.pallas.rst (new file, 23 lines)
@@ -0,0 +1,23 @@
``jax.experimental.pallas`` module
==================================

.. automodule:: jax.experimental.pallas

Classes
-------

.. autosummary::
   :toctree: _autosummary

   BlockSpec

Functions
---------

.. autosummary::
   :toctree: _autosummary

   pallas_call
   program_id
   num_programs

@@ -28,6 +28,7 @@ Experimental Modules
    jax.experimental.mesh_utils
    jax.experimental.serialize_executable
    jax.experimental.shard_map
    jax.experimental.pallas

Experimental APIs
-----------------

@@ -286,9 +286,10 @@ The signature of `pallas_call` is as follows:
```python
def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    out_shapes: Sequence[jax.ShapeDtypeStruct],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...
```

@@ -303,9 +304,9 @@ information about how the kernel will be scheduled on the accelerator.
The (rough) semantics for `pallas_call` are as follows:

```python
def pallas_call(kernel, in_specs, out_specs, out_shapes, grid):
def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shape)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices): # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
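
The hunk above is truncated by the diff context. Purely as an illustration of these rough semantics — a toy model with invented names (`toy_pallas_call`, `block_slices`), not the actual Pallas implementation — a complete mock along the same lines might look like:

```python
import itertools

import jax
import numpy as np


def block_slices(index_map, block_shape, indices):
  # Map grid indices to per-axis slices: block_index * block_size,
  # as in the Blocked indexing mode documented below.
  block_idx = index_map(*indices)
  return tuple(slice(b * s, (b + 1) * s) for b, s in zip(block_idx, block_shape))


def toy_pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  # In this toy model each spec is an (index_map, block_shape) pair.
  def execute(*args):
    outputs = [np.zeros(s.shape, s.dtype) for s in out_shape]
    for indices in itertools.product(*map(range, grid)):  # could run in parallel
      in_blocks = [arg[block_slices(im, bs, indices)]
                   for arg, (im, bs) in zip(args, in_specs)]
      out_slices = [block_slices(im, bs, indices) for im, bs in out_specs]
      out_blocks = [out[sl].copy() for out, sl in zip(outputs, out_slices)]
      kernel(*in_blocks, *out_blocks)  # the kernel writes into out_blocks
      for out, sl, blk in zip(outputs, out_slices, out_blocks):
        out[sl] = blk
    return outputs
  return execute


def add_kernel(x, y, o):
  o[...] = x + y

spec = (lambda i: (i, 0), (2, 4))  # (index_map, block_shape)
f = toy_pallas_call(add_kernel, [jax.ShapeDtypeStruct((4, 4), np.float32)],
                    in_specs=[spec, spec], out_specs=[spec], grid=(2,))
a = np.arange(16, dtype=np.float32).reshape(4, 4)
print(np.allclose(f(a, a)[0], a + a))  # True
```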
docs/pallas/grid_blockspec.md (new file, 213 lines)
@@ -0,0 +1,213 @@
(pallas_grids_and_blockspecs)=

# Grids and BlockSpecs

(pallas_grid)=
### `grid`, a.k.a. kernels in a loop

When using {func}`jax.experimental.pallas.pallas_call` the kernel function
is executed multiple times on different inputs, as specified via the `grid` argument
to `pallas_call`. Conceptually:
```python
pl.pallas_call(some_kernel, grid=(n,))(...)
```
maps to
```python
for i in range(n):
  some_kernel(...)
```
Grids can be generalized to be multi-dimensional, corresponding to nested
loops. For example,

```python
pl.pallas_call(some_kernel, grid=(n, m))(...)
```
is equivalent to
```python
for i in range(n):
  for j in range(m):
    some_kernel(...)
```
This generalizes to any tuple of integers (a length `d` grid will correspond
to `d` nested loops).
The kernel is executed as many times
as `prod(grid)`. Each of these invocations is referred to as a "program".
To access which program (i.e. which element of the grid) the kernel is currently
executing, we use {func}`jax.experimental.pallas.program_id`.
For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
`program_id(axis=1)` returns `2`.
You can also use {func}`jax.experimental.pallas.num_programs` to get the
grid size for a given axis.

Here's an example kernel that uses a `grid` and `program_id`.

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.experimental import pallas as pl

>>> def iota_kernel(o_ref):
...   i = pl.program_id(0)
...   o_ref[i] = i

```

We now execute it using `pallas_call` with an additional `grid` argument.

```python
>>> def iota(size: int):
...   return pl.pallas_call(iota_kernel,
...                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
...                         grid=(size,), interpret=True)()
>>> iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

```

On GPUs, each program is executed in parallel on separate thread blocks.
Thus, we need to think about race conditions on writes to HBM.
A reasonable approach is to write our kernels in such a way that different
programs write to disjoint places in HBM to avoid these parallel writes.

On TPUs, programs are executed in a combination of parallel and sequential fashion
(depending on the architecture), so there are slightly different considerations.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).

(pallas_blockspec)=

### `BlockSpec`, a.k.a. how to chunk up inputs

```{note}
The documentation here applies when ``indexing_mode == Blocked``, which
is the default.
Documentation for ``indexing_mode == Unblocked`` is coming.
```

In conjunction with the `grid` argument, we need to provide Pallas
with the information on how to slice up the input for each invocation.
Specifically, we need to provide a mapping from *the iteration of the loop*
to *which block of our inputs and outputs is to be operated on*.
This is provided via {class}`jax.experimental.pallas.BlockSpec` objects.

Before we get into the details of `BlockSpec`s, you may want
to revisit the
[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).

`BlockSpec`s are provided to `pallas_call` via the
`in_specs` and `out_specs` arguments, one for each input and output respectively.

Informally, the `index_map` of the `BlockSpec` takes as arguments
the invocation indices (as many as the length of the `grid` tuple),
and returns **block indices** (one block index for each axis of
the overall array). Each block index is then multiplied by the
corresponding axis size from `block_shape`
to get the actual element index on the corresponding array axis.

```{note}
This documentation applies to the case when the block shape divides
the array shape.
The documentation for the other cases is pending.
```

More precisely, the slices for each axis of the input `x` of
shape `x_shape` are computed as in the function `slices_for_invocation`
below:

```python
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
...                           invocation_indices: tuple[int, ...]) -> tuple[slice, ...]:
...   assert len(invocation_indices) == len(grid)
...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
...   block_indices = x_spec.index_map(*invocation_indices)
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     assert block_size <= x_size  # Blocks must not be larger than the array
...     start_idx = block_idx * block_size
...     # For now, we document only the case when the entire iteration is in bounds
...     assert start_idx + block_size <= x_size
...     elem_indices.append(slice(start_idx, start_idx + block_size))
...   return elem_indices

```

For example:
```python
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
...                       grid = (10, 5),
...                       invocation_indices = (2, 3))
[slice(20, 30, None), slice(60, 80, None)]

>>> # Same array and block shapes as above, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)),
...                       grid = (10, 5, 4),
...                       invocation_indices = (2, 3, 0))
[slice(20, 30, None), slice(60, 80, None)]

```

The function `show_invocations` defined below uses Pallas to show the
invocation indices. The `iota_2D_kernel` will fill each output block
with a decimal number where the first digit represents the invocation
index over the first axis, and the second digit the invocation index
over the second axis:

```python
>>> def show_invocations(x_shape, block_shape, grid, out_index_map=lambda i, j: (i, j)):
...   def iota_2D_kernel(o_ref):
...     axes = 0
...     for axis in range(len(grid)):
...       axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis)
...     o_ref[...] = jnp.full(o_ref.shape, axes)
...   res = pl.pallas_call(iota_2D_kernel,
...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
...                        grid=grid,
...                        in_specs=[],
...                        out_specs=pl.BlockSpec(out_index_map, block_shape),
...                        interpret=True)()
...   print(res)

```

For example:
```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

```

When multiple invocations write to the same elements of the output
array, the result is platform dependent.

In the example below, we have a 3D grid with the last grid dimension
not used in the block selection (`out_index_map=lambda i, j, k: (i, j)`).
Hence, we iterate over the same output block 10 times.
The output shown below was generated on CPU using `interpret=True`
mode, which at the moment executes the invocations sequentially.
On TPUs, programs are executed in a combination of parallel and sequential fashion,
and this function generates the output shown.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).

```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
...                  out_index_map=lambda i, j, k: (i, j))
[[  9   9   9  19  19  19]
 [  9   9   9  19  19  19]
 [109 109 109 119 119 119]
 [109 109 109 119 119 119]
 [209 209 209 219 219 219]
 [209 209 209 219 219 219]
 [309 309 309 319 319 319]
 [309 309 309 319 319 319]]

```
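
When a grid axis deliberately revisits the same output block (for example, to accumulate a reduction), the usual pattern is to initialize the block on the first visit and accumulate on later visits. A minimal sketch of that pattern, assuming the `pl.when` helper exported by `jax.experimental.pallas` (this example is not part of the committed file):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def sum_kernel(x_ref, o_ref):
  # First visit to this output block: initialize it.
  @pl.when(pl.program_id(0) == 0)
  def _():
    o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)
  # Every visit: accumulate the current input block.
  o_ref[...] += x_ref[...]

def chunked_sum(x):
  # x has shape (8, 40); we sum four (8, 10) column chunks into one output block.
  return pl.pallas_call(
      sum_kernel,
      out_shape=jax.ShapeDtypeStruct((8, 10), jnp.float32),
      grid=(4,),
      in_specs=[pl.BlockSpec(lambda k: (0, k), (8, 10))],
      out_specs=pl.BlockSpec(lambda k: (0, 0), (8, 10)),
      interpret=True)(x)

x = jnp.arange(8 * 40, dtype=jnp.float32).reshape(8, 40)
print(jnp.allclose(chunked_sum(x), x.reshape(8, 4, 10).sum(axis=1)))  # True
```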
@@ -4,13 +4,15 @@ Pallas: a JAX kernel language
=============================
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.
This section contains tutorials, guides and examples for using Pallas.
See also the :class:`jax.experimental.pallas` module API documentation.

.. toctree::
   :caption: Guides
   :maxdepth: 2

   design
   quickstart
   design
   grid_blockspec

.. toctree::
   :caption: Platform Features

@@ -72,7 +72,7 @@
"\n",
"Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
"it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
"Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
"Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
"but we are given an `o_ref`, which corresponds to the desired output.\n",
"\n",
"**Reading from `Ref`s**\n",

@@ -194,7 +194,7 @@
"live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
"that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
"\n",
"### Grids\n",
"### Grids by example\n",
"\n",
"To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
"`BlockSpec`s to `pallas_call`.\n",

@@ -259,10 +259,10 @@
}
],
"source": [
"def iota(len: int):\n",
"def iota(size: int):\n",
"  return pl.pallas_call(iota_kernel,\n",
"                        out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),\n",
"                        grid=(len,))()\n",
"                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
"                        grid=(size,))()\n",
"iota(8)"
]
},

@@ -279,7 +279,9 @@
"operations like matrix multiplications really quickly.\n",
"\n",
"On TPUs, programs are executed in a combination of parallel and sequential\n",
"(depending on the architecture) so there are slightly different considerations."
"(depending on the architecture) so there are slightly different considerations.\n",
"\n",
"You can read more details at {ref}`pallas_grid`."
]
},
{
@@ -287,7 +289,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Block specs"
"### Block specs by example"
]
},
{
@@ -385,6 +387,8 @@
"\n",
"These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n",
"\n",
"For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.\n",
"\n",
"Underneath the hood, `pallas_call` will automatically carve up your inputs and\n",
"outputs into `Ref`s for each block that will be passed into the kernel."
]

@@ -53,7 +53,7 @@ def add_vectors_kernel(x_ref, y_ref, o_ref):

Let's dissect this function a bit. Unlike most JAX functions you've probably written,
it does not take in `jax.Array`s as inputs and doesn't return any values.
Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
but we are given an `o_ref`, which corresponds to the desired output.

**Reading from `Ref`s**

@@ -133,7 +133,7 @@ Part of writing Pallas kernels is thinking about how to take big arrays that
live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations
that operate on "blocks" of those arrays that can fit in SRAM.

### Grids
### Grids by example

To automatically "carve" up the inputs and outputs, you provide a `grid` and
`BlockSpec`s to `pallas_call`.

@@ -170,10 +170,10 @@ def iota_kernel(o_ref):
We now execute it using `pallas_call` with an additional `grid` argument.

```{code-cell} ipython3
def iota(len: int):
def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),
                        grid=(len,))()
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()
iota(8)
```

@@ -187,9 +187,11 @@ operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.

You can read more details at {ref}`pallas_grid`.

+++

### Block specs
### Block specs by example

+++

@@ -279,6 +281,8 @@ Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.

These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.

For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.

Underneath the hood, `pallas_call` will automatically carve up your inputs and
outputs into `Ref`s for each block that will be passed into the kernel.

@@ -65,7 +65,8 @@ Noteworthy properties and restrictions
``BlockSpec``\s and grid iteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``BlockSpec``\s generally behave as expected in Pallas --- every invocation of
``BlockSpec``\s (see :ref:`pallas_blockspec`) generally behave as expected
in Pallas --- every invocation of
the kernel body gets access to slices of the inputs and is meant to initialize a slice
of the output.

@@ -6,7 +6,7 @@
"id": "teoJ_fUwlu0l"
},
"source": [
"# Pipelining and `BlockSpec`s\n",
"# Pipelining\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
},

@@ -262,83 +262,23 @@
"It seems like a complex sequence of asynchronous data operations and\n",
"executing kernels that would be a pain to implement manually.\n",
"Fear not! Pallas offers an API for expressing pipelines without too much\n",
"boilerplate, namely through `grid`s and `BlockSpec`s."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-LQKu8HwED7"
},
"source": [
"### `grid`, a.k.a. kernels in a loop\n",
"boilerplate, namely through `grid`s and `BlockSpec`s.\n",
"\n",
"See how in the above pipelined example, we are executing the same logic\n",
"multiple times: steps 3-5 and 8-10 both execute the same operations,\n",
"only on different inputs.\n",
"The generalized version of this is a loop in which the same kernel is\n",
"executed multiple times.\n",
"`pallas_call` provides an option to do exactly that.\n",
"The {func}`jax.experimental.pallas.pallas_call` provides a way to\n",
"execute a kernel multiple times, by using the `grid` argument.\n",
"See {ref}`pallas_grid`.\n",
"\n",
"The number of iterations in the loop is specified via the `grid` argument\n",
"to `pallas_call`. Conceptually:\n",
"```python\n",
"pl.pallas_call(some_kernel, grid=n)(...)\n",
"```\n",
"maps to\n",
"```python\n",
"for i in range(n):\n",
"  # do HBM -> VMEM copies\n",
"  some_kernel(...)\n",
"  # do VMEM -> HBM copies\n",
"```\n",
"Grids can be generalized to be multi-dimensional, corresponding to nested\n",
"loops. For example,\n",
"We also use {class}`jax.experimental.pallas.BlockSpec` to specify\n",
"how to construct the input of each kernel invocation.\n",
"See {ref}`pallas_blockspec`.\n",
"\n",
"```python\n",
"pl.pallas_call(some_kernel, grid=(n, m))(...)\n",
"```\n",
"is equivalent to\n",
"```python\n",
"for i in range(n):\n",
"  for j in range(m):\n",
"    # do HBM -> VMEM copies\n",
"    some_kernel(...)\n",
"    # do VMEM -> HBM copies\n",
"```\n",
"This generalizes to any tuple of integers (a length `d` grid will correspond\n",
"to `d` nested loops)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hRLr5JeyyEwM"
},
"source": [
"### `BlockSpec`, a.k.a. how to chunk up inputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "miWgPkytyIIa"
},
"source": [
"The next piece of information we need to provide Pallas in order to\n",
"automatically pipeline our computation is information on how to chunk it up.\n",
"Specifically, we need to provide a mapping between *the iteration of the loop*\n",
"to *which block of our inputs and outputs to be operated on*.\n",
"A `BlockSpec` is exactly these two pieces of information.\n",
"\n",
"First we pick a `block_shape` for our inputs.\n",
"In the pipelining example above, we had `(512, 512)`-shaped arrays and\n",
"split them along the leading dimension into two `(256, 512)`-shaped arrays.\n",
"In this pipeline, our `block_shape` would be `(256, 512)`.\n",
"\n",
"We then provide an `index_map` function that maps the iteration space to the\n",
"blocks.\n",
"Specifically, in the aforementioned pipeline, on the 1st iteration we'd\n",
"In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.\n",
"On the 1st iteration we'd\n",
"like to select `x1` and on the second iteration we'd like to use `x2`.\n",
"This can be expressed with the following `index_map`:\n",
"\n",

@@ -13,7 +13,7 @@ kernelspec:

+++ {"id": "teoJ_fUwlu0l"}

# Pipelining and `BlockSpec`s
# Pipelining

<!--* freshness: { reviewed: '2024-04-08' } *-->

@@ -204,66 +204,21 @@ executing kernels that would be a pain to implement manually.
Fear not! Pallas offers an API for expressing pipelines without too much
boilerplate, namely through `grid`s and `BlockSpec`s.

+++ {"id": "x-LQKu8HwED7"}

### `grid`, a.k.a. kernels in a loop

See how in the above pipelined example, we are executing the same logic
multiple times: steps 3-5 and 8-10 both execute the same operations,
only on different inputs.
The generalized version of this is a loop in which the same kernel is
executed multiple times.
`pallas_call` provides an option to do exactly that.
The {func}`jax.experimental.pallas.pallas_call` provides a way to
execute a kernel multiple times, by using the `grid` argument.
See {ref}`pallas_grid`.

The number of iterations in the loop is specified via the `grid` argument
to `pallas_call`. Conceptually:
```python
pl.pallas_call(some_kernel, grid=n)(...)
```
maps to
```python
for i in range(n):
  # do HBM -> VMEM copies
  some_kernel(...)
  # do VMEM -> HBM copies
```
Grids can be generalized to be multi-dimensional, corresponding to nested
loops. For example,
We also use {class}`jax.experimental.pallas.BlockSpec` to specify
how to construct the input of each kernel invocation.
See {ref}`pallas_blockspec`.

```python
pl.pallas_call(some_kernel, grid=(n, m))(...)
```
is equivalent to
```python
for i in range(n):
  for j in range(m):
    # do HBM -> VMEM copies
    some_kernel(...)
    # do VMEM -> HBM copies
```
This generalizes to any tuple of integers (a length `d` grid will correspond
to `d` nested loops).

+++ {"id": "hRLr5JeyyEwM"}

### `BlockSpec`, a.k.a. how to chunk up inputs

+++ {"id": "miWgPkytyIIa"}

The next piece of information we need to provide Pallas in order to
automatically pipeline our computation is information on how to chunk it up.
Specifically, we need to provide a mapping between *the iteration of the loop*
to *which block of our inputs and outputs to be operated on*.
A `BlockSpec` is exactly these two pieces of information.

First we pick a `block_shape` for our inputs.
In the pipelining example above, we had `(512, 512)`-shaped arrays and
split them along the leading dimension into two `(256, 512)`-shaped arrays.
In this pipeline, our `block_shape` would be `(256, 512)`.

We then provide an `index_map` function that maps the iteration space to the
blocks.
Specifically, in the aforementioned pipeline, on the 1st iteration we'd
In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.
On the 1st iteration we'd
like to select `x1` and on the second iteration we'd like to use `x2`.
This can be expressed with the following `index_map`:

@@ -171,6 +171,10 @@ IndexingMode = Union[Blocked, Unblocked]

@dataclasses.dataclass(unsafe_hash=True)
class BlockSpec:
  """Specifies how an array should be sliced for each iteration of a kernel.

  See :ref:`pallas_blockspec` for more details.
  """
  index_map: Callable[..., Any] | None = None
  block_shape: tuple[int | None, ...] | None = None
  memory_space: Any | None = None
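
For illustration (not part of the commit), a `BlockSpec` built from these fields, using the positional `(index_map, block_shape)` order that the documentation examples above rely on:

```python
from jax.experimental import pallas as pl

# Invocation i gets block index (i, 0): rows [i * 256, (i + 1) * 256), all 512 columns.
row_block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
```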
@@ -228,7 +228,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes,

  # Pad values to evenly divide into block dimensions.
  # This allows interpret mode to catch errors on OOB memory accesses
  # by poisoning values with NaN. It also fixes an inconstency with
  # by poisoning values with NaN. It also fixes an inconsistency with
  # lax.dynamic_slice where if the slice goes out of bounds, it will instead
  # move the start_index backwards so the slice will fit in memory.
  carry = map(_pad_values_to_block_dimension, carry, block_shapes)

@@ -1009,6 +1009,45 @@ def pallas_call(
    name: str | None = None,
    compiler_params: dict[str, Any] | None = None,
) -> Callable[..., Any]:
  """Invokes a Pallas kernel on some inputs.

  See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.

  Args:
    f: the kernel function, that receives a Ref for each input and output.
      The shape of the Refs are given by the ``block_shape`` in the
      corresponding ``in_specs`` and ``out_specs``.
    out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
      and dtypes of the outputs.
    grid_spec: TO BE DOCUMENTED.
    debug: if True, Pallas prints various intermediate forms of the kernel
      as it is being processed.
    grid: the iteration space, as a tuple of integers. The kernel is executed
      as many times as ``prod(grid)``. The default value ``None`` is equivalent
      to ``()``.
      See details at :ref:`pallas_grid`.
    in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
      a structure matching that of the positional arguments.
      See details at :ref:`pallas_blockspec`.
    out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
      a structure matching that of the outputs.
      See details at :ref:`pallas_blockspec`.
      The default value for `out_specs` specifies the whole array,
      e.g., as `pl.BlockSpec(lambda *indices: indices, x.shape)`.
    input_output_aliases: a dictionary mapping the index of some inputs to
      the index of the output that aliases them.
    interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
      grid whose body is the kernel lowered as a JAX function. This does not
      require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
      This is useful for debugging.
    name: TO BE DOCUMENTED.
    compiler_params: TO BE DOCUMENTED.

  Returns:
    A function that can be called on a number of positional array arguments to
    invoke the Pallas kernel.
  """
  name = _extract_function_name(f, name)
  if compiler_params is None:
    compiler_params = {}
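
As a usage sketch for the arguments documented above (an illustration with made-up names such as `add_kernel` and `blocked_add`, not code from this commit):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # Each invocation sees one (256, 512) block of x, y and o.
  o_ref[...] = x_ref[...] + y_ref[...]

def blocked_add(x, y):
  spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=(2,),
      in_specs=[spec, spec],
      out_specs=spec,
      interpret=True,  # runs on CPU; drop this to compile for GPU/TPU
  )(x, y)

x = jnp.ones((512, 512), jnp.float32)
y = jnp.arange(512 * 512, dtype=jnp.float32).reshape(512, 512)
print(jnp.allclose(blocked_add(x, y), x + y))  # True
```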
@@ -49,7 +49,15 @@ zip, unsafe_zip = util.safe_zip, zip
program_id_p = jax_core.Primitive("program_id")

def program_id(axis: int) -> jax.Array:
  """Returns the kernel execution position along the given axis of the grid."""
  """Returns the kernel execution position along the given axis of the grid.

  For example, with a 2D `grid` in the kernel execution corresponding to the
  grid coordinates `(1, 2)`,
  `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.

  Args:
    axis: the axis of the grid along which to count the program.
  """
  return program_id_p.bind(axis=axis)

def program_id_bind(*, axis: int):
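
A small sketch of how `program_id` combines with `num_programs` inside a kernel (an illustration, not from this commit; run in interpret mode so it also works on CPU):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def rev_iota_kernel(o_ref):
  # program_id(0) is this invocation's position in the grid;
  # num_programs(0) is the grid size along axis 0.
  i = pl.program_id(0)
  o_ref[i] = pl.num_programs(0) - 1 - i

rev = pl.pallas_call(rev_iota_kernel,
                     out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
                     grid=(8,), interpret=True)()
print(rev)  # expected: [7 6 5 4 3 2 1 0]
```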
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for pallas, a JAX extension for custom kernels."""
"""Module for Pallas, a JAX extension for custom kernels.

See the Pallas documentation at https://jax.readthedocs.io/en/latest/pallas.html.
"""

from jax._src import pallas
from jax._src.pallas.core import BlockSpec