From bfdf8f4bd3397b1316c77531b675149b451a879d Mon Sep 17 00:00:00 2001
From: George Necula
Date: Thu, 27 Jun 2024 11:07:26 +0300
Subject: [PATCH] [pallas] Added more documentation for grid and BlockSpec.

The starting point was the text in pipelining.md, which I have now replaced
with a reference to the separate grid and BlockSpec documentation.

The grids and BlockSpecs are also documented in quickstart.md, which I
mostly left alone because it was good enough for a simple example.

I have also attempted to add a few docstrings.
---
 docs/jax.experimental.pallas.rst    |  23 +++
 docs/jax.experimental.rst           |   1 +
 docs/pallas/design.md               |   7 +-
 docs/pallas/grid_blockspec.md       | 213 ++++++++++++++++++++++++++++
 docs/pallas/index.rst               |   4 +-
 docs/pallas/quickstart.ipynb        |  18 ++-
 docs/pallas/quickstart.md           |  16 ++-
 docs/pallas/tpu/details.rst         |   3 +-
 docs/pallas/tpu/pipelining.ipynb    |  80 ++---------
 docs/pallas/tpu/pipelining.md       |  63 ++------
 jax/_src/pallas/core.py             |   4 +
 jax/_src/pallas/pallas_call.py      |  41 +++++-
 jax/_src/pallas/primitives.py       |  10 +-
 jax/experimental/pallas/__init__.py |   5 +-
 14 files changed, 343 insertions(+), 145 deletions(-)
 create mode 100644 docs/jax.experimental.pallas.rst
 create mode 100644 docs/pallas/grid_blockspec.md

diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst
new file mode 100644
index 000000000..f10bb2524
--- /dev/null
+++ b/docs/jax.experimental.pallas.rst
@@ -0,0 +1,23 @@
+``jax.experimental.pallas`` module
+==================================
+
+.. automodule:: jax.experimental.pallas
+
+Classes
+-------
+
+.. autosummary::
+   :toctree: _autosummary
+
+   BlockSpec
+
+Functions
+---------
+
+.. autosummary::
+   :toctree: _autosummary
+
+   pallas_call
+   program_id
+   num_programs
+
diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst
index a4d1f664a..3052e391e 100644
--- a/docs/jax.experimental.rst
+++ b/docs/jax.experimental.rst
@@ -28,6 +28,7 @@ Experimental Modules
     jax.experimental.mesh_utils
     jax.experimental.serialize_executable
     jax.experimental.shard_map
+    jax.experimental.pallas
 
 Experimental APIs
 -----------------
diff --git a/docs/pallas/design.md b/docs/pallas/design.md
index caa1f1eb2..53a7cfb7e 100644
--- a/docs/pallas/design.md
+++ b/docs/pallas/design.md
@@ -286,9 +286,10 @@ The signature of `pallas_call` is as follows:
 ```python
 def pallas_call(
     kernel: Callable,
+    out_shape: Sequence[jax.ShapeDtypeStruct],
+    *,
     in_specs: Sequence[Spec],
     out_specs: Sequence[Spec],
-    out_shapes: Sequence[jax.ShapeDtypeStruct],
     grid: Optional[Tuple[int, ...]] = None) -> Callable:
   ...
 ```
@@ -303,9 +304,9 @@ information about how the kernel will be scheduled on the accelerator.
 The (rough) semantics for `pallas_call` are as follows:
 
 ```python
-def pallas_call(kernel, in_specs, out_specs, out_shapes, grid):
+def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
   def execute(*args):
-    outputs = map(empty_ref, out_shapes)
+    outputs = map(empty_ref, out_shape)
     grid_indices = map(range, grid)
     for indices in itertools.product(*grid_indices):  # Could run in parallel!
       local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md
new file mode 100644
index 000000000..0fbc8e602
--- /dev/null
+++ b/docs/pallas/grid_blockspec.md
@@ -0,0 +1,213 @@
+(pallas_grids_and_blockspecs)=
+
+# Grids and BlockSpecs
+
+(pallas_grid)=
+### `grid`, a.k.a. kernels in a loop
+
+When using {func}`jax.experimental.pallas.pallas_call`, the kernel function
+is executed multiple times on different inputs, as specified via the `grid`
+argument to `pallas_call`. Conceptually:
+```python
+pl.pallas_call(some_kernel, grid=(n,))(...)
+```
+maps to
+```python
+for i in range(n):
+  some_kernel(...)
+```
+Grids can be generalized to be multi-dimensional, corresponding to nested
+loops. For example,
+
+```python
+pl.pallas_call(some_kernel, grid=(n, m))(...)
+```
+is equivalent to
+```python
+for i in range(n):
+  for j in range(m):
+    some_kernel(...)
+```
+This generalizes to any tuple of integers (a length `d` grid will correspond
+to `d` nested loops).
+The kernel is executed as many times as `prod(grid)`.
+Each of these invocations is referred to as a "program".
+To access which program (i.e., which element of the grid) the kernel is
+currently executing, we use {func}`jax.experimental.pallas.program_id`.
+For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
+`program_id(axis=1)` returns `2`.
+You can also use {func}`jax.experimental.pallas.num_programs` to get the
+grid size for a given axis.
+
+Here's an example kernel that uses a `grid` and `program_id`.
+
+```python
+>>> import jax
+>>> import jax.numpy as jnp
+>>> import numpy as np
+>>> from jax.experimental import pallas as pl
+
+>>> def iota_kernel(o_ref):
+...   i = pl.program_id(0)
+...   o_ref[i] = i
+
+```
+
+We now execute it using `pallas_call` with an additional `grid` argument.
+
+```python
+>>> def iota(size: int):
+...   return pl.pallas_call(iota_kernel,
+...                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
+...                         grid=(size,), interpret=True)()
+>>> iota(8)
+Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
+
+```
+
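+For instance, the following small variation on `iota` is a minimal sketch of
+how {func}`jax.experimental.pallas.num_programs` can be used inside a kernel;
+the `reverse_iota` name below is purely illustrative. Each program writes its
+index counted from the end of the grid:
+
+```python
+>>> def reverse_iota_kernel(o_ref):
+...   i = pl.program_id(0)
+...   n = pl.num_programs(0)
+...   o_ref[i] = n - 1 - i
+
+>>> def reverse_iota(size: int):
+...   return pl.pallas_call(reverse_iota_kernel,
+...                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
+...                         grid=(size,), interpret=True)()
+>>> reverse_iota(8)
+Array([7, 6, 5, 4, 3, 2, 1, 0], dtype=int32)
+
+```
+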
+On GPUs, each program is executed in parallel on a separate thread block.
+Thus, we need to think about race conditions on writes to HBM.
+A reasonable approach is to write our kernels in such a way that different
+programs write to disjoint places in HBM to avoid these parallel writes.
+
+On TPUs, programs are executed in a combination of parallel and sequential
+fashion (depending on the architecture), so there are slightly different
+considerations.
+See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
+
+(pallas_blockspec)=
+
+### `BlockSpec`, a.k.a. how to chunk up inputs
+
+```{note}
+The documentation here applies to ``indexing_mode == Blocked``, which is
+the default.
+Documentation for ``indexing_mode == Unblocked`` is forthcoming.
+```
+
+In conjunction with the `grid` argument, we need to provide Pallas with
+information on how to slice up the inputs for each invocation.
+Specifically, we need to provide a mapping from *the iteration of the loop*
+to *the blocks of our inputs and outputs to be operated on*.
+This is provided via {class}`jax.experimental.pallas.BlockSpec` objects.
+
+Before we get into the details of `BlockSpec`s, you may want
+to revisit the
+[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).
+
+`BlockSpec`s are provided to `pallas_call` via the
+`in_specs` and `out_specs` arguments, one for each input and output
+respectively.
+
+Informally, the `index_map` of the `BlockSpec` takes as arguments
+the invocation indices (as many as the length of the `grid` tuple),
+and returns **block indices** (one block index for each axis of
+the overall array). Each block index is then multiplied by the
+corresponding axis size from `block_shape`
+to get the starting element index on the corresponding array axis.
+
+```{note}
+This documentation applies to the case when the block shape divides
+the array shape evenly.
+Documentation for the other cases is pending.
+```
+
+More precisely, the slices for each axis of the input `x` of
+shape `x_shape` are computed as in the function `slices_for_invocation`
+below:
+
+```python
+>>> def slices_for_invocation(x_shape: tuple[int, ...],
+...                           x_spec: pl.BlockSpec,
+...                           grid: tuple[int, ...],
+...                           invocation_indices: tuple[int, ...]) -> list[slice]:
+...   assert len(invocation_indices) == len(grid)
+...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
+...   block_indices = x_spec.index_map(*invocation_indices)
+...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
+...   elem_indices = []
+...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
+...     assert block_size <= x_size  # Blocks must not be larger than the array
+...     start_idx = block_idx * block_size
+...     # For now, we document only the case when the entire block is in bounds
+...     assert start_idx + block_size <= x_size
+...     elem_indices.append(slice(start_idx, start_idx + block_size))
+...   return elem_indices
+
+```
+
+For example:
+```python
+>>> slices_for_invocation(x_shape=(100, 100),
+...                       x_spec=pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
+...                       grid=(10, 5),
+...                       invocation_indices=(2, 3))
+[slice(20, 30, None), slice(60, 80, None)]

+>>> # Same shapes for the array and the blocks, but we iterate over each block 4 times
+>>> slices_for_invocation(x_shape=(100, 100),
+...                       x_spec=pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)),
+...                       grid=(10, 5, 4),
+...                       invocation_indices=(2, 3, 0))
+[slice(20, 30, None), slice(60, 80, None)]
+
+```
+
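+To see how such a `BlockSpec` is used end to end, here is a minimal sketch
+(the `double_kernel` name is purely illustrative) that doubles an `(8, 6)`
+array in `(4, 3)` blocks over a `(2, 2)` grid, using `interpret=True` so it
+can run on CPU:
+
+```python
+>>> def double_kernel(x_ref, o_ref):
+...   o_ref[...] = x_ref[...] * 2
+
+>>> x = jnp.arange(48, dtype=jnp.int32).reshape(8, 6)
+>>> out = pl.pallas_call(double_kernel,
+...                      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+...                      grid=(2, 2),
+...                      in_specs=[pl.BlockSpec(lambda i, j: (i, j), (4, 3))],
+...                      out_specs=pl.BlockSpec(lambda i, j: (i, j), (4, 3)),
+...                      interpret=True)(x)
+>>> bool(jnp.array_equal(out, x * 2))
+True
+
+```
+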
+The function `show_invocations` defined below uses Pallas to show the
+invocation indices. The `iota_2D_kernel` fills each output block with a
+decimal number in which each digit encodes the invocation index along one
+grid axis (the first digit corresponds to the first axis, the second digit
+to the second axis, and so on):
+
+```python
+>>> def show_invocations(x_shape, block_shape, grid, out_index_map=lambda i, j: (i, j)):
+...   def iota_2D_kernel(o_ref):
+...     axes = 0
+...     for axis in range(len(grid)):
+...       axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis)
+...     o_ref[...] = jnp.full(o_ref.shape, axes)
+...   res = pl.pallas_call(iota_2D_kernel,
+...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
+...                        grid=grid,
+...                        in_specs=[],
+...                        out_specs=pl.BlockSpec(out_index_map, block_shape),
+...                        interpret=True)()
+...   print(res)
+
+```
+
+For example:
+```python
+>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2))
+[[ 0  0  0  1  1  1]
+ [ 0  0  0  1  1  1]
+ [10 10 10 11 11 11]
+ [10 10 10 11 11 11]
+ [20 20 20 21 21 21]
+ [20 20 20 21 21 21]
+ [30 30 30 31 31 31]
+ [30 30 30 31 31 31]]
+
+```
+
+When multiple invocations write to the same elements of the output
+array, the result is platform dependent.
+
+In the example below, we have a 3D grid with the last grid dimension
+not used in the block selection (`out_index_map=lambda i, j, k: (i, j)`).
+Hence, we iterate over the same output block 10 times.
+The output shown below was generated on CPU using `interpret=True`
+mode, which at the moment executes the invocations sequentially.
+On TPUs, programs are executed in a combination of parallel and sequential, +and this function generates the output shown. +See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions). + +```python +>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10), +... out_index_map=lambda i, j, k: (i, j)) +[[ 9 9 9 19 19 19] + [ 9 9 9 19 19 19] + [109 109 109 119 119 119] + [109 109 109 119 119 119] + [209 209 209 219 219 219] + [209 209 209 219 219 219] + [309 309 309 319 319 319] + [309 309 309 319 319 319]] + +``` diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index bd086bd47..9fbb560d1 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -4,13 +4,15 @@ Pallas: a JAX kernel language ============================= Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. This section contains tutorials, guides and examples for using Pallas. +See also the :class:`jax.experimental.pallas` module API documentation. .. toctree:: :caption: Guides :maxdepth: 2 - design quickstart + design + grid_blockspec .. toctree:: :caption: Platform Features diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 47d1a1409..ea3209a57 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -72,7 +72,7 @@ "\n", "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n", "it does not take in `jax.Array`s as inputs and doesn't return any values.\n", - "Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n", + "Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n", "but we are given an `o_ref`, which corresponds to the desired output.\n", "\n", "**Reading from `Ref`s**\n", @@ -194,7 +194,7 @@ "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n", "that operate on \"blocks\" of those arrays that can fit in SRAM.\n", "\n", - "### Grids\n", + "### Grids by example\n", "\n", "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n", "`BlockSpec`s to `pallas_call`.\n", @@ -259,10 +259,10 @@ } ], "source": [ - "def iota(len: int):\n", + "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", - " out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),\n", - " grid=(len,))()\n", + " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", + " grid=(size,))()\n", "iota(8)" ] }, @@ -279,7 +279,9 @@ "operations like matrix multiplications really quickly.\n", "\n", "On TPUs, programs are executed in a combination of parallel and sequential\n", - "(depending on the architecture) so there are slightly different considerations." + "(depending on the architecture) so there are slightly different considerations.\n", + "\n", + "You can read more details at {ref}`pallas_grid`." ] }, { @@ -287,7 +289,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Block specs" + "### Block specs by example" ] }, { @@ -385,6 +387,8 @@ "\n", "These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n", "\n", + "For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.\n", + "\n", "Underneath the hood, `pallas_call` will automatically carve up your inputs and\n", "outputs into `Ref`s for each block that will be passed into the kernel." 
] diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index d42459f89..05685d0f6 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -53,7 +53,7 @@ def add_vectors_kernel(x_ref, y_ref, o_ref): Let's dissect this function a bit. Unlike most JAX functions you've probably written, it does not take in `jax.Array`s as inputs and doesn't return any values. -Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs +Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs but we are given an `o_ref`, which corresponds to the desired output. **Reading from `Ref`s** @@ -133,7 +133,7 @@ Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on "blocks" of those arrays that can fit in SRAM. -### Grids +### Grids by example To automatically "carve" up the inputs and outputs, you provide a `grid` and `BlockSpec`s to `pallas_call`. @@ -170,10 +170,10 @@ def iota_kernel(o_ref): We now execute it using `pallas_call` with an additional `grid` argument. ```{code-cell} ipython3 -def iota(len: int): +def iota(size: int): return pl.pallas_call(iota_kernel, - out_shape=jax.ShapeDtypeStruct((len,), jnp.int32), - grid=(len,))() + out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), + grid=(size,))() iota(8) ``` @@ -187,9 +187,11 @@ operations like matrix multiplications really quickly. On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations. +You can read more details at {ref}`pallas_grid`. + +++ -### Block specs +### Block specs by example +++ @@ -279,6 +281,8 @@ Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`. These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`. +For more detail on `BlockSpec`s see {ref}`pallas_blockspec`. + Underneath the hood, `pallas_call` will automatically carve up your inputs and outputs into `Ref`s for each block that will be passed into the kernel. diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 58ecebad4..718ade0c7 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -65,7 +65,8 @@ Noteworthy properties and restrictions ``BlockSpec``\s and grid iteration ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -``BlockSpec``\s generally behave as expected in Pallas --- every invocation of +``BlockSpec``\s (see :ref:`pallas_blockspec`) generally behave as expected +in Pallas --- every invocation of the kernel body gets access to slices of the inputs and is meant to initialize a slice of the output. diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 06bbe9135..ed37f88a0 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -6,7 +6,7 @@ "id": "teoJ_fUwlu0l" }, "source": [ - "# Pipelining and `BlockSpec`s\n", + "# Pipelining\n", "\n", "" ] @@ -262,83 +262,23 @@ "It seems like a complex sequence of asynchronous data operations and\n", "executing kernels that would be a pain to implement manually.\n", "Fear not! Pallas offers an API for expressing pipelines without too much\n", - "boilerplate, namely through `grid`s and `BlockSpec`s." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x-LQKu8HwED7" - }, - "source": [ - "### `grid`, a.k.a. 
kernels in a loop\n", + "boilerplate, namely through `grid`s and `BlockSpec`s.\n", "\n", "See how in the above pipelined example, we are executing the same logic\n", "multiple times: steps 3-5 and 8-10 both execute the same operations,\n", "only on different inputs.\n", - "The generalized version of this is a loop in which the same kernel is\n", - "executed multiple times.\n", - "`pallas_call` provides an option to do exactly that.\n", + "The {func}`jax.experimental.pallas.pallas_call` provides a way to\n", + "execute a kernel multiple times, by using the `grid` argument.\n", + "See {ref}`pallas_grid`.\n", "\n", - "The number of iterations in the loop is specified via the `grid` argument\n", - "to `pallas_call`. Conceptually:\n", - "```python\n", - "pl.pallas_call(some_kernel, grid=n)(...)\n", - "```\n", - "maps to\n", - "```python\n", - "for i in range(n):\n", - " # do HBM -> VMEM copies\n", - " some_kernel(...)\n", - " # do VMEM -> HBM copies\n", - "```\n", - "Grids can be generalized to be multi-dimensional, corresponding to nested\n", - "loops. For example,\n", + "We also use {class}`jax.experimental.pallas.BlockSpec` to specify\n", + "how to construct the input of each kernel invocation.\n", + "See {ref}`pallas_blockspec`.\n", "\n", - "```python\n", - "pl.pallas_call(some_kernel, grid=(n, m))(...)\n", - "```\n", - "is equivalent to\n", - "```python\n", - "for i in range(n):\n", - " for j in range(m):\n", - " # do HBM -> VMEM copies\n", - " some_kernel(...)\n", - " # do VMEM -> HBM copies\n", - "```\n", - "This generalizes to any tuple of integers (a length `d` grid will correspond\n", - "to `d` nested loops)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hRLr5JeyyEwM" - }, - "source": [ - "### `BlockSpec`, a.k.a. how to chunk up inputs" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "miWgPkytyIIa" - }, - "source": [ - "The next piece of information we need to provide Pallas in order to\n", - "automatically pipeline our computation is information on how to chunk it up.\n", - "Specifically, we need to provide a mapping between *the iteration of the loop*\n", - "to *which block of our inputs and outputs to be operated on*.\n", - "A `BlockSpec` is exactly these two pieces of information.\n", - "\n", - "First we pick a `block_shape` for our inputs.\n", "In the pipelining example above, we had `(512, 512)`-shaped arrays and\n", "split them along the leading dimension into two `(256, 512)`-shaped arrays.\n", - "In this pipeline, our `block_shape` would be `(256, 512)`.\n", - "\n", - "We then provide an `index_map` function that maps the iteration space to the\n", - "blocks.\n", - "Specifically, in the aforementioned pipeline, on the 1st iteration we'd\n", + "In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.\n", + "On the 1st iteration we'd\n", "like to select `x1` and on the second iteration we'd like to use `x2`.\n", "This can be expressed with the following `index_map`:\n", "\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 77d029229..cab4d6b28 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "teoJ_fUwlu0l"} -# Pipelining and `BlockSpec`s +# Pipelining @@ -204,66 +204,21 @@ executing kernels that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through `grid`s and `BlockSpec`s. -+++ {"id": "x-LQKu8HwED7"} - -### `grid`, a.k.a. 
kernels in a loop - See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. -The generalized version of this is a loop in which the same kernel is -executed multiple times. -`pallas_call` provides an option to do exactly that. +The {func}`jax.experimental.pallas.pallas_call` provides a way to +execute a kernel multiple times, by using the `grid` argument. +See {ref}`pallas_grid`. -The number of iterations in the loop is specified via the `grid` argument -to `pallas_call`. Conceptually: -```python -pl.pallas_call(some_kernel, grid=n)(...) -``` -maps to -```python -for i in range(n): - # do HBM -> VMEM copies - some_kernel(...) - # do VMEM -> HBM copies -``` -Grids can be generalized to be multi-dimensional, corresponding to nested -loops. For example, +We also use {class}`jax.experimental.pallas.BlockSpec` to specify +how to construct the input of each kernel invocation. +See {ref}`pallas_blockspec`. -```python -pl.pallas_call(some_kernel, grid=(n, m))(...) -``` -is equivalent to -```python -for i in range(n): - for j in range(m): - # do HBM -> VMEM copies - some_kernel(...) - # do VMEM -> HBM copies -``` -This generalizes to any tuple of integers (a length `d` grid will correspond -to `d` nested loops). - -+++ {"id": "hRLr5JeyyEwM"} - -### `BlockSpec`, a.k.a. how to chunk up inputs - -+++ {"id": "miWgPkytyIIa"} - -The next piece of information we need to provide Pallas in order to -automatically pipeline our computation is information on how to chunk it up. -Specifically, we need to provide a mapping between *the iteration of the loop* -to *which block of our inputs and outputs to be operated on*. -A `BlockSpec` is exactly these two pieces of information. - -First we pick a `block_shape` for our inputs. In the pipelining example above, we had `(512, 512)`-shaped arrays and split them along the leading dimension into two `(256, 512)`-shaped arrays. -In this pipeline, our `block_shape` would be `(256, 512)`. - -We then provide an `index_map` function that maps the iteration space to the -blocks. -Specifically, in the aforementioned pipeline, on the 1st iteration we'd +In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`. +On the 1st iteration we'd like to select `x1` and on the second iteration we'd like to use `x2`. This can be expressed with the following `index_map`: diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 38866b082..006ff50be 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -171,6 +171,10 @@ IndexingMode = Union[Blocked, Unblocked] @dataclasses.dataclass(unsafe_hash=True) class BlockSpec: + """Specifies how an array should be sliced for each iteration of a kernel. + + See :ref:`pallas_blockspec` for more details. + """ index_map: Callable[..., Any] | None = None block_shape: tuple[int | None, ...] | None = None memory_space: Any | None = None diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 0748e78a2..ea61f1c7f 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -228,7 +228,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, # Pad values to evenly divide into block dimensions. # This allows interpret mode to catch errors on OOB memory accesses - # by poisoning values with NaN. It also fixes an inconstency with + # by poisoning values with NaN. 
It also fixes an inconsistency with
   # lax.dynamic_slice where if the slice goes out of bounds, it will instead
   # move the start_index backwards so the slice will fit in memory.
   carry = map(_pad_values_to_block_dimension, carry, block_shapes)
@@ -1009,6 +1009,45 @@ def pallas_call(
     name: str | None = None,
     compiler_params: dict[str, Any] | None = None,
 ) -> Callable[..., Any]:
+  """Invokes a Pallas kernel on some inputs.
+
+  See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.
+
+  Args:
+    f: the kernel function, which receives a Ref for each input and output.
+      The shapes of the Refs are given by the ``block_shape`` in the
+      corresponding ``in_specs`` and ``out_specs``.
+    out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shapes
+      and dtypes of the outputs.
+    grid_spec: TO BE DOCUMENTED.
+    debug: if True, Pallas prints various intermediate forms of the kernel
+      as it is being processed.
+    grid: the iteration space, as a tuple of integers. The kernel is executed
+      as many times as ``prod(grid)``. The default value ``None`` is equivalent
+      to ``()``.
+      See details at :ref:`pallas_grid`.
+    in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
+      a structure matching that of the positional arguments.
+      See details at :ref:`pallas_blockspec`.
+    out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
+      a structure matching that of the outputs.
+      See details at :ref:`pallas_blockspec`.
+      The default value for ``out_specs`` specifies the whole array,
+      e.g., as ``pl.BlockSpec(lambda *indices: indices, x.shape)``.
+    input_output_aliases: a dictionary mapping the index of some inputs to
+      the index of the output that aliases them.
+    interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
+      grid whose body is the kernel lowered as a JAX function. This does not
+      require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
+      This is useful for debugging.
+    name: TO BE DOCUMENTED.
+    compiler_params: TO BE DOCUMENTED.
+
+  Returns:
+    A function that can be called on a number of positional array arguments to
+    invoke the Pallas kernel.
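+
+  Example:
+    A minimal sketch of a kernel that adds 1 to a vector, run in interpreter
+    mode so it does not require an accelerator (the kernel name is purely
+    illustrative):
+
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> from jax.experimental import pallas as pl
+    >>> def add_one_kernel(x_ref, o_ref):
+    ...   o_ref[...] = x_ref[...] + 1
+    >>> x = jnp.arange(8, dtype=jnp.int32)
+    >>> pl.pallas_call(add_one_kernel,
+    ...                out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+    ...                interpret=True)(x)
+    Array([1, 2, 3, 4, 5, 6, 7, 8], dtype=int32)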
+""" from jax._src import pallas from jax._src.pallas.core import BlockSpec