Changed `pl.BlockSpec` to accept `block_shape` before `index_map`

So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)
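
For example, a complete ``pallas_call`` under the new order looks like the
following sketch (the kernel, shapes, and grid are illustrative, adapted from
the design-note example updated later in this change):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_kernel(x_ref, y_ref, o_ref):
        # Each invocation sees one (2,)-shaped block of the inputs and output.
        o_ref[...] = x_ref[...] + y_ref[...]

    add = pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
        in_specs=[
            pl.BlockSpec((2,), lambda i: i),  # block_shape first, index_map second
            pl.BlockSpec((2,), lambda i: i),
        ],
        out_specs=pl.BlockSpec((2,), lambda i: i),
        grid=(4,),
    )
    add(jnp.arange(8, dtype=jnp.int32), jnp.arange(8, dtype=jnp.int32))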

I will update Pallas tests in a follow-up.

PiperOrigin-RevId: 648486321
Sergei Lebedev 2024-07-01 14:25:25 -07:00 committed by jax authors
parent 94ba6c3f98
commit a2a5068e5e
11 changed files with 127 additions and 81 deletions

View File

@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,

View File

@ -23,7 +23,7 @@ job of compiling user programs but inevitably some users hit XLA's
limitations.
In these cases, we need to provide an “escape hatch” to allow
experts to write hand-tuned kernels that outperform XLA at that
point in time.
point in time.
Furthermore, advances in ML systems research take some time to be
incorporated into XLA and users often want to run ahead with them.
Over time, the compiler can incorporate the optimizations that were proven
@ -431,10 +431,10 @@ add = pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
in_specs=[
pl.BlockSpec(lambda i: i, (2,)),
pl.BlockSpec(lambda i: i, (2,))
pl.BlockSpec((2,), lambda i: i),
pl.BlockSpec((2,), lambda i: i)
],
out_specs=pl.BlockSpec(lambda i: i, (2,)),
out_specs=pl.BlockSpec((2,), lambda i: i),
grid=(4,))
add(x, y)
```
@ -465,10 +465,10 @@ def matmul(x, y, *, block_shape, activation):
partial(matmul_kernel, block_k=block_k, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n))
pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
grid=(4, 4),
)
return fused_matmul(x, y)

View File

@ -2,6 +2,8 @@
# Grids and BlockSpecs
<!--* freshness: { reviewed: '2024-06-01' } *-->
(pallas_grid)=
### `grid`, a.k.a. kernels in a loop
@ -115,7 +117,7 @@ below:
```python
>>> def slices_for_invocation(x_shape: tuple[int, ...],
... x_spec: pl.BlockSpec,
... grid: tuple[int, ...],
... grid: tuple[int, ...],
... invocation_indices: tuple[int, ...]) -> tuple[slice, ...]:
... assert len(invocation_indices) == len(grid)
... assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
@ -135,15 +137,15 @@ below:
For example:
```python
>>> slices_for_invocation(x_shape=(100, 100),
... x_spec = pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
... grid = (10, 5),
... x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
... grid = (10, 5),
... invocation_indices = (2, 3))
[slice(20, 30, None), slice(60, 80, None)]
>>> # Same shape of the array and blocks, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
... x_spec = pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)),
... grid = (10, 5, 4),
... x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
... grid = (10, 5, 4),
... invocation_indices = (2, 3, 0))
[slice(20, 30, None), slice(60, 80, None)]
@ -166,7 +168,7 @@ over the second axis:
... out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
... grid=grid,
... in_specs=[],
... out_specs=pl.BlockSpec(out_index_map, block_shape),
... out_specs=pl.BlockSpec(block_shape, out_index_map),
... interpret=True)()
... print(res)
@ -195,7 +197,7 @@ Hence, we iterate over the same output block 10 times.
The output shown below was generated on CPU using `interpret=True`
mode, which at the moment executes the invocation sequentially.
On TPUs, programs are executed in a combination of parallel and sequential fashion,
and this function generates the output shown.
and this function generates the output shown.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
```python

View File

@ -121,9 +121,10 @@
"source": [
"@jax.jit\n",
"def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n",
" return pl.pallas_call(add_vectors_kernel,\n",
" out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
" )(x, y)\n",
" return pl.pallas_call(\n",
" add_vectors_kernel,\n",
" out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
" )(x, y)\n",
"add_vectors(jnp.arange(8), jnp.arange(8))"
]
},
@ -378,12 +379,12 @@
"each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.\n",
"To express this, we'd first use a `(2, 2)` grid (one block for each program).\n",
"\n",
"For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this\n",
"For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this\n",
"carves `x` up into \"row\" blocks.\n",
"To see this see how both program instances\n",
"`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n",
"For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`.\n",
"Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.\n",
"For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.\n",
"Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.\n",
"\n",
"These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n",
"\n",
@ -408,11 +409,11 @@
" out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n",
" grid=(2, 2),\n",
" in_specs=[\n",
" pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n",
" pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n",
" pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n",
" pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n",
" ],\n",
" out_specs=pl.BlockSpec(\n",
" lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
" (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),\n",
" )\n",
" )(x, y)\n",
"k1, k2 = jax.random.split(jax.random.key(0))\n",
@ -448,11 +449,11 @@
" out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n",
" grid=(2, 2),\n",
" in_specs=[\n",
" pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n",
" pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n",
" pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n",
" pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n",
" ],\n",
" out_specs=pl.BlockSpec(\n",
" lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
" (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)\n",
" ),\n",
" )(x, y)\n",
"k1, k2 = jax.random.split(jax.random.key(0))\n",

View File

@ -81,9 +81,10 @@ We use the `pallas_call` higher-order function.
```{code-cell} ipython3
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
return pl.pallas_call(
add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
```
@ -272,12 +273,12 @@ the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where
each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.
To express this, we'd first use a `(2, 2)` grid (one block for each program).
For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this
For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this
carves `x` up into "row" blocks.
To see this, see how both program instances
`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.
For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`.
Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.
For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.
Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.
These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.
@ -296,11 +297,11 @@ def matmul(x: jax.Array, y: jax.Array):
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
)
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
@ -325,11 +326,11 @@ def matmul(x: jax.Array, y: jax.Array, *, activation):
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))

View File

@ -70,7 +70,7 @@
" Compute units operate on values that live in SREGs and VREGs and output\n",
" values into those registers as well.\n",
"\n",
"In order to do a vectorized computation on our values `x` and `y` that live \n",
"In order to do a vectorized computation on our values `x` and `y` that live\n",
"in HBM, we need to:\n",
"\n",
"1. Copy the values `x` and `y` into VMEM.\n",
@ -174,7 +174,7 @@
"Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n",
"writing kernels utilizing them adds some considerations.\n",
"\n",
"1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB \n",
"1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB\n",
" and SMEM ranges in the tens to hundreds of KiB.\n",
" If our arrays are too big, we won't even be able to fit them into VMEM at all.\n",
" For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n",
@ -289,7 +289,7 @@
"\n",
"We'd then construct the `BlockSpec`:\n",
"```python\n",
"block_spec = pl.BlockSpec(x_index_map, (256, 512))\n",
"block_spec = pl.BlockSpec((256, 512), x_index_map)\n",
"```\n",
"\n",
"The `BlockSpec`s for `y` and `z` will be the same as the one for `x`."
@ -335,13 +335,14 @@
],
"source": [
"def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n",
" block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n",
" block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n",
" return pl.pallas_call(\n",
" add_matrices_kernel,\n",
" out_shape=x,\n",
" in_specs=[block_spec, block_spec],\n",
" out_specs=block_spec,\n",
" grid=(2,))(x, y)\n",
" grid=(2,)\n",
" )(x, y)\n",
"\n",
"add_matrices_pipelined(x, y)"
]
@ -403,8 +404,7 @@
" x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n",
") -> jax.Array:\n",
" m, n = x.shape\n",
" block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))\n",
"\n",
" block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n",
" return pl.pallas_call(\n",
" add_matrices_kernel,\n",
" out_shape=x,\n",
@ -413,7 +413,6 @@
" grid=(m // bm, n // bn),\n",
" )(x, y)\n",
"\n",
"\n",
"np.testing.assert_array_equal(\n",
" add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n",
")\n",
@ -526,10 +525,10 @@
" naive_sum_kernel,\n",
" grid=grid,\n",
" # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
" in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n",
" out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n",
" out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n",
" )(x)\n",
" in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n",
" out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n",
" out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n",
" )(x)\n",
"naive_sum(x)"
]
},
@ -603,10 +602,11 @@
" sum_kernel,\n",
" grid=grid,\n",
" # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
" in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n",
" out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n",
" in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n",
" out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n",
" out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n",
" )(x)\n",
" )(x)\n",
"\n",
"sum(x)"
]
},
@ -689,15 +689,15 @@
],
"source": [
"def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n",
" block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n",
" block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n",
" return pl.pallas_call(\n",
" add_matrices_kernel,\n",
" out_shape=x,\n",
" in_specs=[block_spec, block_spec],\n",
" out_specs=block_spec,\n",
" grid=(2,),\n",
" compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",))))(\n",
" x, y)\n",
" compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n",
" )(x, y)\n",
"\n",
"x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
"add_matrices_pipelined_megacore(x, y)"

View File

@ -62,7 +62,7 @@ Let's talk about the components of this diagram in more detail:
Compute units operate on values that live in SREGs and VREGs and output
values into those registers as well.
In order to do a vectorized computation on our values `x` and `y` that live
In order to do a vectorized computation on our values `x` and `y` that live
in HBM, we need to:
1. Copy the values `x` and `y` into VMEM.
@ -129,7 +129,7 @@ the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`.
Pallas exposes access to lower level memory spaces like VMEM and SMEM but
writing kernels utilizing them adds some considerations.
1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB
1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB
and SMEM ranges in the tens to hundreds of KiB.
If our arrays are too big, we won't even be able to fit them into VMEM at all.
For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't
@ -229,7 +229,7 @@ def x_index_map(i):
We'd then construct the `BlockSpec`:
```python
block_spec = pl.BlockSpec(x_index_map, (256, 512))
block_spec = pl.BlockSpec((256, 512), x_index_map)
```
The `BlockSpec`s for `y` and `z` will be the same as the one for `x`.
@ -247,13 +247,14 @@ and `out_specs` corresponds to the output).
:outputId: 504bab29-83f3-4e1f-8664-1860ad15b6de
def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,))(x, y)
grid=(2,)
)(x, y)
add_matrices_pipelined(x, y)
```
@ -295,8 +296,7 @@ def add_matrices_pipelined_2d(
x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
m, n = x.shape
block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))
block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
@ -305,7 +305,6 @@ def add_matrices_pipelined_2d(
grid=(m // bm, n // bn),
)(x, y)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
@ -359,10 +358,10 @@ def naive_sum(x: jax.Array) -> jax.Array:
naive_sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
)(x)
naive_sum(x)
```
@ -409,10 +408,11 @@ def sum(x: jax.Array) -> jax.Array:
sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
)(x)
sum(x)
```
@ -458,15 +458,15 @@ annotation to `pallas_call` called `dimension_semantics`.
:outputId: 385ed87c-d95c-466c-af77-df3845c979f2
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
x, y)
compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",)))
)(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)

View File

@ -16,12 +16,14 @@
from __future__ import annotations
from collections.abc import Callable, Iterator, Sequence
import copy
import contextlib
import copy
import dataclasses
import functools
import inspect
import threading
from typing import Any, Union
import warnings
import jax
from jax._src import api_util
@ -169,16 +171,47 @@ blocked = Blocked()
IndexingMode = Union[Blocked, Unblocked]
_BLOCK_SPECSIG = inspect.Signature.from_callable(
lambda block_shape, index_map: ...
)
@dataclasses.dataclass(unsafe_hash=True)
class BlockSpec:
"""Specifies how an array should be sliced for each iteration of a kernel.
See :ref:`pallas_blockspec` for more details.
"""
index_map: Callable[..., Any] | None = None
block_shape: tuple[int | None, ...] | None = None
memory_space: Any | None = None
indexing_mode: IndexingMode = blocked
index_map: Callable[..., Any] | None = None
memory_space: Any | None = dataclasses.field(kw_only=True, default=None)
indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked)
def __init__(
self,
*args,
memory_space: Any | None = None,
indexing_mode: IndexingMode = blocked,
**kwargs: Any,
) -> None:
bound_args = _BLOCK_SPECSIG.bind_partial(*args, **kwargs)
block_shape = bound_args.arguments.get("block_shape", None)
index_map = bound_args.arguments.get("index_map", None)
if callable(block_shape):
# TODO(slebedev): Remove this code path and update the signature of
# __init__ after October 1, 2024.
warnings.warn(
"BlockSpec now expects ``block_shape`` to be passed before"
" ``index_map``. Update your code by swapping the order of these"
" arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))``"
" should be written as ``pl.BlockSpec((42,), lambda i: i)``.",
DeprecationWarning,
)
index_map, block_shape = block_shape, index_map
self.block_shape = block_shape
self.index_map = index_map
self.memory_space = memory_space
self.indexing_mode = indexing_mode
def compute_index(self, *args):
assert self.index_map is not None

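A minimal sketch of the intended behavior of the compatibility shim above (the
`(8,)` block shape and the index map are made up for illustration): the old
argument order still builds the same spec but now emits a `DeprecationWarning`,
while the new order stays silent.

```python
import warnings

from jax.experimental import pallas as pl

new_style = pl.BlockSpec((8,), lambda i: i)      # new order: no warning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = pl.BlockSpec(lambda i: i, (8,))  # old order: deprecated
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# Both spellings end up with the same fields; the shim swaps them into place.
assert old_style.block_shape == new_style.block_shape == (8,)
assert callable(old_style.index_map) and callable(new_style.index_map)
```
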
View File

@ -1033,7 +1033,7 @@ def pallas_call(
a structure matching that of the outputs.
See details at :ref:`pallas_blockspec`.
The default value for `out_specs` specifies the whole array,
e.g., as `pl.BlockSpec(lambda *indices: indices, x.shape)`.
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them.
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the

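Spelled out in the new argument order, that documented default corresponds to a
whole-array spec along the lines of the sketch below (`x` here is just a
placeholder array; the spec is only constructed, not passed to `pallas_call`):

```python
import jax.numpy as jnp
from jax.experimental import pallas as pl

x = jnp.zeros((8, 128), jnp.float32)
# One block covering the whole array; the index map simply forwards the grid
# indices, mirroring the default described in the docstring above.
whole_array_spec = pl.BlockSpec(x.shape, lambda *indices: indices)
```
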
View File

@ -61,6 +61,8 @@ filterwarnings = [
"default:Special cases found for .* but none were parsed.*:UserWarning",
"default:.*is not JSON-serializable. Using the repr instead.",
# end array_api_tests-related warnings
# TODO(slebedev): Remove once we migrate all pl.BlockSpec usages in JAX.
"default:BlockSpec now expects .*:DeprecationWarning",
]
doctest_optionflags = [
"NUMBER",

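The entry above keeps the new warning visible but non-fatal in JAX's own test
suite. As a sketch of the opposite setting, downstream code migrating to the
new argument order can escalate the same warning to an error so that any
remaining old-order call sites fail loudly (the regex mirrors the warning text
added in this change):

```python
import warnings

warnings.filterwarnings(
    "error",
    message=r"BlockSpec now expects .*",
    category=DeprecationWarning,
)
```
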
View File

@ -604,8 +604,12 @@ class PallasCallDMATest(PallasTPUTest):
x = jnp.ones((8, 128), dtype=jnp.float32)
y = self.pallas_call(
kernel,
in_specs=[pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY)],
out_specs=pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY),
in_specs=[
pl.BlockSpec(None, None, memory_space=pltpu.TPUMemorySpace.ANY)
],
out_specs=pl.BlockSpec(
None, None, memory_space=pltpu.TPUMemorySpace.ANY
),
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
jax.block_until_ready(y)
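
For context on this test update: with the new `__init__`, `memory_space` is
keyword-only, so the previous positional spelling
`pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY)` no longer binds (the
two-parameter signature rejects a third positional argument). A hedged sketch
of the updated idiom, with the TPU memory space shown purely for illustration:

```python
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Leave block_shape/index_map unset and pass the memory space by keyword.
any_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)
```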