[pallas] Added more documentation for grid and BlockSpec.

The starting point was the text in pipelining.md, which I have now
replaced with a reference to the separate grid and BlockSpec
documentation.

Grids and BlockSpecs are also documented in quickstart.md,
which I mostly left alone because it works well enough for a
simple example.

I have also attempted to add a few docstrings.
George Necula 2024-06-27 11:07:26 +03:00
parent 945b1c3b8a
commit bfdf8f4bd3
14 changed files with 343 additions and 145 deletions

View File

@ -0,0 +1,23 @@
``jax.experimental.pallas`` module
==================================

.. automodule:: jax.experimental.pallas

Classes
-------

.. autosummary::
  :toctree: _autosummary

  BlockSpec

Functions
---------

.. autosummary::
  :toctree: _autosummary

  pallas_call
  program_id
  num_programs

View File

@ -28,6 +28,7 @@ Experimental Modules
jax.experimental.mesh_utils
jax.experimental.serialize_executable
jax.experimental.shard_map
jax.experimental.pallas
Experimental APIs
-----------------

View File

@ -286,9 +286,10 @@ The signature of `pallas_call` is as follows:
```python
def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    out_shapes: Sequence[jax.ShapeDtypeStruct],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...
```
@ -303,9 +304,9 @@ information about how the kernel will be scheduled on the accelerator.
The (rough) semantics for `pallas_call` are as follows:
```python
def pallas_call(kernel, in_specs, out_specs, out_shapes, grid):
def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shapes)
    outputs = map(empty_ref, out_shape)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices):  # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in

View File

@ -0,0 +1,213 @@
(pallas_grids_and_blockspecs)=
# Grids and BlockSpecs
(pallas_grid)=
### `grid`, a.k.a. kernels in a loop
When using {func}`jax.experimental.pallas.pallas_call`, the kernel function
is executed multiple times on different inputs, as specified via the `grid`
argument to `pallas_call`. Conceptually:
```python
pl.pallas_call(some_kernel, grid=(n,))(...)
```
maps to
```python
for i in range(n):
  some_kernel(...)
```
Grids can be generalized to be multi-dimensional, corresponding to nested
loops. For example,
```python
pl.pallas_call(some_kernel, grid=(n, m))(...)
```
is equivalent to
```python
for i in range(n):
  for j in range(m):
    some_kernel(...)
```
This generalizes to any tuple of integers (a length `d` grid will correspond
to `d` nested loops).
The kernel is executed as many times
as `prod(grid)`. Each of these invocations is referred to as a "program".
To find out which program (i.e., which element of the grid) the kernel is currently
executing, we use {func}`jax.experimental.pallas.program_id`.
For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
`program_id(axis=1)` returns `2`.
You can also use {func}`jax.experimental.pallas.num_programs` to get the
grid size for a given axis.
Here's an example kernel that uses a `grid` and `program_id`.
```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.experimental import pallas as pl

>>> def iota_kernel(o_ref):
...   i = pl.program_id(0)
...   o_ref[i] = i
```
We now execute it using `pallas_call` with an additional `grid` argument.
```python
>>> def iota(size: int):
...   return pl.pallas_call(iota_kernel,
...                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
...                         grid=(size,), interpret=True)()
>>> iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
```
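For illustration only (a sketch added here, not one of the original examples), `num_programs` can be combined with `program_id`; the hypothetical `reverse_iota_kernel` below fills the output in reverse order:
```python
>>> def reverse_iota_kernel(o_ref):  # illustrative kernel, not from the original docs
...   i = pl.program_id(0)
...   n = pl.num_programs(0)   # size of the grid along axis 0
...   o_ref[i] = n - 1 - i

>>> pl.pallas_call(reverse_iota_kernel,
...                out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
...                grid=(8,), interpret=True)()
Array([7, 6, 5, 4, 3, 2, 1, 0], dtype=int32)
```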
On GPUs, each program is executed in parallel on its own thread block.
Thus, we need to think about race conditions for parallel writes to HBM.
A reasonable approach is to write our kernels in such a way that different
programs write to disjoint places in HBM; the `iota` example above already
has this property, since each program writes a distinct output element.
On TPUs, programs are executed in a combination of parallel and sequential
order (depending on the architecture), so there are slightly different considerations.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
(pallas_blockspec)=
### `BlockSpec`, a.k.a. how to chunk up inputs
```{note}
The documentation here applies to ``indexing_mode == Blocked``, which
is the default.
Documentation for ``indexing_mode == Unblocked`` is forthcoming.
```
In conjunction with the `grid` argument, we need to provide Pallas
the information on how to slice up the inputs and outputs for each invocation.
Specifically, we need to provide a mapping from *the iteration of the loop*
to *the block of our inputs and outputs to be operated on*.
This is provided via {class}`jax.experimental.pallas.BlockSpec` objects.
Before we get into the details of `BlockSpec`s, you may want
to revisit the
[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).
`BlockSpec`s are provided to `pallas_call` via the
`in_specs` and `out_specs` arguments, one for each input and output respectively.
Informally, the `index_map` of the `BlockSpec` takes as arguments
the invocation indices (as many as the length of the `grid` tuple),
and returns **block indices** (one block index for each axis of
the overall array). Each block index is then multiplied by the
corresponding axis size from `block_shape`
to obtain the starting element index of the block on that array axis.
For instance, with `block_shape == (10, 20)` and block indices `(2, 3)`,
the selected block starts at element indices `(20, 60)`.
```{note}
This documentation applies to the case when the block shape evenly divides
the array shape.
Documentation for the other cases is pending.
```
More precisely, the slices for each axis of the input `x` of
shape `x_shape` are computed as in the function `slices_for_invocation`
below:
```python
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
...                           invocation_indices: tuple[int, ...]) -> list[slice]:
...   assert len(invocation_indices) == len(grid)
...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
...   block_indices = x_spec.index_map(*invocation_indices)
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     assert block_size <= x_size  # Blocks must not be larger than the array
...     start_idx = block_idx * block_size
...     # For now, we document only the case when the entire block is in bounds
...     assert start_idx + block_size <= x_size
...     elem_indices.append(slice(start_idx, start_idx + block_size))
...   return elem_indices
```
For example:
```python
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
...                       grid=(10, 5),
...                       invocation_indices=(2, 3))
[slice(20, 30, None), slice(60, 80, None)]

>>> # Same array and block shapes, but the extra grid axis visits each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)),
...                       grid=(10, 5, 4),
...                       invocation_indices=(2, 3, 0))
[slice(20, 30, None), slice(60, 80, None)]
```
The function `show_invocations` defined below uses Pallas to show the
invocation indices. The `iota_2D_kernel` fills each output block
with a decimal number in which each digit encodes one invocation index:
the leading digit is the invocation index along the first grid axis,
the next digit is the index along the second axis, and so on:
```python
>>> def show_invocations(x_shape, block_shape, grid, out_index_map=lambda i, j: (i, j)):
...   def iota_2D_kernel(o_ref):
...     axes = 0
...     for axis in range(len(grid)):
...       axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis)
...     o_ref[...] = jnp.full(o_ref.shape, axes)
...   res = pl.pallas_call(iota_2D_kernel,
...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
...                        grid=grid,
...                        in_specs=[],
...                        out_specs=pl.BlockSpec(out_index_map, block_shape),
...                        interpret=True)()
...   print(res)
```
For example:
```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]
```
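The `index_map` need not be the identity. As a further illustration (a sketch added here, reusing the `show_invocations` helper defined above, with shapes chosen for this write-up), a transposed `out_index_map` places the block written by invocation `(i, j)` at block position `(j, i)`:
```python
>>> # Illustrative only: a transposed block mapping
>>> show_invocations(x_shape=(4, 4), block_shape=(2, 2), grid=(2, 2),
...                  out_index_map=lambda i, j: (j, i))
[[ 0  0 10 10]
 [ 0  0 10 10]
 [ 1  1 11 11]
 [ 1  1 11 11]]
```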
When multiple invocations write to the same elements of the output
array, the result is platform dependent.
In the example below, we have a 3D grid whose last dimension is not
used in the block selection (`out_index_map=lambda i, j, k: (i, j)`).
Hence, we iterate over the same output block 10 times.
The output shown below was generated on CPU using `interpret=True`
mode, which at the moment executes the invocations sequentially.
On TPUs, programs are executed in a combination of parallel and sequential
order, and this function produces the output shown.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
... out_index_map=lambda i, j, k: (i, j))
[[  9   9   9  19  19  19]
 [  9   9   9  19  19  19]
 [109 109 109 119 119 119]
 [109 109 109 119 119 119]
 [209 209 209 219 219 219]
 [209 209 209 219 219 219]
 [309 309 309 319 319 319]
 [309 309 309 319 319 319]]
```

View File

@ -4,13 +4,15 @@ Pallas: a JAX kernel language
=============================
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.
This section contains tutorials, guides and examples for using Pallas.
See also the :mod:`jax.experimental.pallas` module API documentation.
.. toctree::
:caption: Guides
:maxdepth: 2
design
quickstart
design
grid_blockspec
.. toctree::
:caption: Platform Features

View File

@ -72,7 +72,7 @@
"\n",
"Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
"it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
"Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
"Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
"but we are given an `o_ref`, which corresponds to the desired output.\n",
"\n",
"**Reading from `Ref`s**\n",
@ -194,7 +194,7 @@
"live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
"that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
"\n",
"### Grids\n",
"### Grids by example\n",
"\n",
"To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
"`BlockSpec`s to `pallas_call`.\n",
@ -259,10 +259,10 @@
}
],
"source": [
"def iota(len: int):\n",
"def iota(size: int):\n",
" return pl.pallas_call(iota_kernel,\n",
" out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),\n",
" grid=(len,))()\n",
" out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
" grid=(size,))()\n",
"iota(8)"
]
},
@ -279,7 +279,9 @@
"operations like matrix multiplications really quickly.\n",
"\n",
"On TPUs, programs are executed in a combination of parallel and sequential\n",
"(depending on the architecture) so there are slightly different considerations."
"(depending on the architecture) so there are slightly different considerations.\n",
"\n",
"You can read more details at {ref}`pallas_grid`."
]
},
{
@ -287,7 +289,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Block specs"
"### Block specs by example"
]
},
{
@ -385,6 +387,8 @@
"\n",
"These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n",
"\n",
"For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.\n",
"\n",
"Underneath the hood, `pallas_call` will automatically carve up your inputs and\n",
"outputs into `Ref`s for each block that will be passed into the kernel."
]

View File

@ -53,7 +53,7 @@ def add_vectors_kernel(x_ref, y_ref, o_ref):
Let's dissect this function a bit. Unlike most JAX functions you've probably written,
it does not take in `jax.Array`s as inputs and doesn't return any values.
Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
but we are given an `o_ref`, which corresponds to the desired output.
**Reading from `Ref`s**
@ -133,7 +133,7 @@ Part of writing Pallas kernels is thinking about how to take big arrays that
live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations
that operate on "blocks" of those arrays that can fit in SRAM.
### Grids
### Grids by example
To automatically "carve" up the inputs and outputs, you provide a `grid` and
`BlockSpec`s to `pallas_call`.
@ -170,10 +170,10 @@ def iota_kernel(o_ref):
We now execute it using `pallas_call` with an additional `grid` argument.
```{code-cell} ipython3
def iota(len: int):
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),
grid=(len,))()
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
```
@ -187,9 +187,11 @@ operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.
You can read more details at {ref}`pallas_grid`.
+++
### Block specs
### Block specs by example
+++
@ -279,6 +281,8 @@ Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.
These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.
For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.
Underneath the hood, `pallas_call` will automatically carve up your inputs and
outputs into `Ref`s for each block that will be passed into the kernel.

View File

@ -65,7 +65,8 @@ Noteworthy properties and restrictions
``BlockSpec``\s and grid iteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``BlockSpec``\s generally behave as expected in Pallas --- every invocation of
``BlockSpec``\s (see :ref:`pallas_blockspec`) generally behave as expected
in Pallas --- every invocation of
the kernel body gets access to slices of the inputs and is meant to initialize a slice
of the output.

View File

@ -6,7 +6,7 @@
"id": "teoJ_fUwlu0l"
},
"source": [
"# Pipelining and `BlockSpec`s\n",
"# Pipelining\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->"
]
@ -262,83 +262,23 @@
"It seems like a complex sequence of asynchronous data operations and\n",
"executing kernels that would be a pain to implement manually.\n",
"Fear not! Pallas offers an API for expressing pipelines without too much\n",
"boilerplate, namely through `grid`s and `BlockSpec`s."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-LQKu8HwED7"
},
"source": [
"### `grid`, a.k.a. kernels in a loop\n",
"boilerplate, namely through `grid`s and `BlockSpec`s.\n",
"\n",
"See how in the above pipelined example, we are executing the same logic\n",
"multiple times: steps 3-5 and 8-10 both execute the same operations,\n",
"only on different inputs.\n",
"The generalized version of this is a loop in which the same kernel is\n",
"executed multiple times.\n",
"`pallas_call` provides an option to do exactly that.\n",
"The {func}`jax.experimental.pallas.pallas_call` provides a way to\n",
"execute a kernel multiple times, by using the `grid` argument.\n",
"See {ref}`pallas_grid`.\n",
"\n",
"The number of iterations in the loop is specified via the `grid` argument\n",
"to `pallas_call`. Conceptually:\n",
"```python\n",
"pl.pallas_call(some_kernel, grid=n)(...)\n",
"```\n",
"maps to\n",
"```python\n",
"for i in range(n):\n",
" # do HBM -> VMEM copies\n",
" some_kernel(...)\n",
" # do VMEM -> HBM copies\n",
"```\n",
"Grids can be generalized to be multi-dimensional, corresponding to nested\n",
"loops. For example,\n",
"We also use {class}`jax.experimental.pallas.BlockSpec` to specify\n",
"how to construct the input of each kernel invocation.\n",
"See {ref}`pallas_blockspec`.\n",
"\n",
"```python\n",
"pl.pallas_call(some_kernel, grid=(n, m))(...)\n",
"```\n",
"is equivalent to\n",
"```python\n",
"for i in range(n):\n",
" for j in range(m):\n",
" # do HBM -> VMEM copies\n",
" some_kernel(...)\n",
" # do VMEM -> HBM copies\n",
"```\n",
"This generalizes to any tuple of integers (a length `d` grid will correspond\n",
"to `d` nested loops)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hRLr5JeyyEwM"
},
"source": [
"### `BlockSpec`, a.k.a. how to chunk up inputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "miWgPkytyIIa"
},
"source": [
"The next piece of information we need to provide Pallas in order to\n",
"automatically pipeline our computation is information on how to chunk it up.\n",
"Specifically, we need to provide a mapping between *the iteration of the loop*\n",
"to *which block of our inputs and outputs to be operated on*.\n",
"A `BlockSpec` is exactly these two pieces of information.\n",
"\n",
"First we pick a `block_shape` for our inputs.\n",
"In the pipelining example above, we had `(512, 512)`-shaped arrays and\n",
"split them along the leading dimension into two `(256, 512)`-shaped arrays.\n",
"In this pipeline, our `block_shape` would be `(256, 512)`.\n",
"\n",
"We then provide an `index_map` function that maps the iteration space to the\n",
"blocks.\n",
"Specifically, in the aforementioned pipeline, on the 1st iteration we'd\n",
"In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.\n",
"On the 1st iteration we'd\n",
"like to select `x1` and on the second iteration we'd like to use `x2`.\n",
"This can be expressed with the following `index_map`:\n",
"\n",

View File

@ -13,7 +13,7 @@ kernelspec:
+++ {"id": "teoJ_fUwlu0l"}
# Pipelining and `BlockSpec`s
# Pipelining
<!--* freshness: { reviewed: '2024-04-08' } *-->
@ -204,66 +204,21 @@ executing kernels that would be a pain to implement manually.
Fear not! Pallas offers an API for expressing pipelines without too much
boilerplate, namely through `grid`s and `BlockSpec`s.
+++ {"id": "x-LQKu8HwED7"}
### `grid`, a.k.a. kernels in a loop
See how in the above pipelined example, we are executing the same logic
multiple times: steps 3-5 and 8-10 both execute the same operations,
only on different inputs.
The generalized version of this is a loop in which the same kernel is
executed multiple times.
`pallas_call` provides an option to do exactly that.
{func}`jax.experimental.pallas.pallas_call` provides a way to
execute a kernel multiple times, via the `grid` argument.
See {ref}`pallas_grid`.
The number of iterations in the loop is specified via the `grid` argument
to `pallas_call`. Conceptually:
```python
pl.pallas_call(some_kernel, grid=n)(...)
```
maps to
```python
for i in range(n):
# do HBM -> VMEM copies
some_kernel(...)
# do VMEM -> HBM copies
```
Grids can be generalized to be multi-dimensional, corresponding to nested
loops. For example,
We also use {class}`jax.experimental.pallas.BlockSpec` to specify
how to construct the input of each kernel invocation.
See {ref}`pallas_blockspec`.
```python
pl.pallas_call(some_kernel, grid=(n, m))(...)
```
is equivalent to
```python
for i in range(n):
for j in range(m):
# do HBM -> VMEM copies
some_kernel(...)
# do VMEM -> HBM copies
```
This generalizes to any tuple of integers (a length `d` grid will correspond
to `d` nested loops).
+++ {"id": "hRLr5JeyyEwM"}
### `BlockSpec`, a.k.a. how to chunk up inputs
+++ {"id": "miWgPkytyIIa"}
The next piece of information we need to provide Pallas in order to
automatically pipeline our computation is information on how to chunk it up.
Specifically, we need to provide a mapping between *the iteration of the loop*
to *which block of our inputs and outputs to be operated on*.
A `BlockSpec` is exactly these two pieces of information.
First we pick a `block_shape` for our inputs.
In the pipelining example above, we had `(512, 512)`-shaped arrays and
split them along the leading dimension into two `(256, 512)`-shaped arrays.
In this pipeline, our `block_shape` would be `(256, 512)`.
We then provide an `index_map` function that maps the iteration space to the
blocks.
Specifically, in the aforementioned pipeline, on the 1st iteration we'd
In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.
On the 1st iteration we'd
like to select `x1` and on the second iteration we'd like to use `x2`.
This can be expressed with the following `index_map`:

View File

@ -171,6 +171,10 @@ IndexingMode = Union[Blocked, Unblocked]
@dataclasses.dataclass(unsafe_hash=True)
class BlockSpec:
"""Specifies how an array should be sliced for each iteration of a kernel.
See :ref:`pallas_blockspec` for more details.
"""
index_map: Callable[..., Any] | None = None
block_shape: tuple[int | None, ...] | None = None
memory_space: Any | None = None

View File

@ -228,7 +228,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes,
# Pad values to evenly divide into block dimensions.
# This allows interpret mode to catch errors on OOB memory accesses
# by poisoning values with NaN. It also fixes an inconstency with
# by poisoning values with NaN. It also fixes an inconsistency with
# lax.dynamic_slice where if the slice goes out of bounds, it will instead
# move the start_index backwards so the slice will fit in memory.
carry = map(_pad_values_to_block_dimension, carry, block_shapes)
@ -1009,6 +1009,45 @@ def pallas_call(
name: str | None = None,
compiler_params: dict[str, Any] | None = None,
) -> Callable[..., Any]:
"""Invokes a Pallas kernel on some inputs.
See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.
Args:
f: the kernel function, which receives a Ref for each input and output.
The shapes of the Refs are given by the ``block_shape`` in the
corresponding ``in_specs`` and ``out_specs``.
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
and dtypes of the outputs.
grid_spec: TO BE DOCUMENTED.
debug: if True, Pallas prints various intermediate forms of the kernel
as it is being processed.
grid: the iteration space, as a tuple of integers. The kernel is executed
as many times as ``prod(grid)``. The default value ``None`` is equivalent
to ``()``.
See details at :ref:`pallas_grid`.
in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
a structure matching that of the positional arguments.
See details at :ref:`pallas_blockspec`.
out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
a structure matching that of the outputs.
See details at :ref:`pallas_blockspec`.
The default value for ``out_specs`` specifies the whole array,
e.g., as ``pl.BlockSpec(lambda *indices: indices, x.shape)``.
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them.
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
grid whose body is the kernel lowered as a JAX function. This does not
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
This is useful for debugging.
name: TO BE DOCUMENTED.
compiler_params: TO BE DOCUMENTED.
Returns:
A function that can be called on a number of positional array arguments to
invoke the Pallas kernel.
"""
name = _extract_function_name(f, name)
if compiler_params is None:
compiler_params = {}

View File

@ -49,7 +49,15 @@ zip, unsafe_zip = util.safe_zip, zip
program_id_p = jax_core.Primitive("program_id")
def program_id(axis: int) -> jax.Array:
"""Returns the kernel execution position along the given axis of the grid."""
"""Returns the kernel execution position along the given axis of the grid.
For example, with a 2D `grid`, during the kernel invocation at
grid coordinates `(1, 2)`,
`program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.
Args:
axis: the axis of the grid along which to count the program.
"""
return program_id_p.bind(axis=axis)
def program_id_bind(*, axis: int):

View File

@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for pallas, a JAX extension for custom kernels."""
"""Module for Pallas, a JAX extension for custom kernels.
See the Pallas documentation at https://jax.readthedocs.io/en/latest/pallas.html.
"""
from jax._src import pallas
from jax._src.pallas.core import BlockSpec