Changed `pl.BlockSpec` to accept `block_shape` before `index_map`
So, instead of `pl.BlockSpec(lambda i, j: ..., (42, 24))`, `pl.BlockSpec` now expects `pl.BlockSpec((42, 24), lambda i, j: ...)`. I will update the Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
parent 94ba6c3f98
commit a2a5068e5e
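For illustration (not part of the commit itself), a minimal sketch of the migration using a hypothetical copy kernel; only the `BlockSpec` argument order differs between the deprecated and the new spelling:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  # Copy the current input block into the corresponding output block.
  o_ref[...] = x_ref[...]

x = jnp.arange(8, dtype=jnp.int32)

# Deprecated order: index_map first, then block_shape.
# spec = pl.BlockSpec(lambda i: i, (2,))

# New order: block_shape first, then index_map.
spec = pl.BlockSpec((2,), lambda i: i)

out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    in_specs=[spec],
    out_specs=spec,
    grid=(4,),
)(x)
```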
@@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
 * `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
   installed either as a part of local CUDA installation, or via NVIDIA's CUDA
   pip wheels.
+* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
+  be passed *before* `index_map`. The old argument order is deprecated and
+  will be removed in a future release.
 * Deprecations
   * Removed a number of previously-deprecated internal APIs related to
     polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
@@ -23,7 +23,7 @@ job of compiling user programs but inevitably some users hit XLA's
 limitations.
 In these cases, we need to provide an “escape hatch” to allow
 experts to write hand-tuned kernels that outperform XLA at that
-point in time.
+point in time.
 Furthermore, advances in ML systems research take some time to be
 incorporated into XLA and users often want to run ahead with them.
 Over time, the compiler can incorporate the optimizations that were proven
@@ -431,10 +431,10 @@ add = pl.pallas_call(
     add_kernel,
     out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
     in_specs=[
-        pl.BlockSpec(lambda i: i, (2,)),
-        pl.BlockSpec(lambda i: i, (2,))
+        pl.BlockSpec((2,), lambda i: i),
+        pl.BlockSpec((2,), lambda i: i)
     ],
-    out_specs=pl.BlockSpec(lambda i: i, (2,)),
+    out_specs=pl.BlockSpec((2,), lambda i: i),
     grid=(4,))
 add(x, y)
 ```
@@ -465,10 +465,10 @@ def matmul(x, y, *, block_shape, activation):
     partial(matmul_kernel, block_k=block_k, activation=activation),
     out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
     in_specs=[
-        pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
-        pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n))
+        pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
+        pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
     ],
-    out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
+    out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
     grid=(4, 4),
   )
   return fused_matmul(x, y)
@@ -2,6 +2,8 @@

 # Grids and BlockSpecs

 <!--* freshness: { reviewed: '2024-06-01' } *-->

 (pallas_grid)=
 ### `grid`, a.k.a. kernels in a loop

@@ -115,7 +117,7 @@ below:
 ```python
 >>> def slices_for_invocation(x_shape: tuple[int, ...],
 ...                           x_spec: pl.BlockSpec,
-...                           grid: tuple[int, ...],
+...                           grid: tuple[int, ...],
 ...                           invocation_indices: tuple[int, ...]) -> tuple[slice, ...]:
 ...   assert len(invocation_indices) == len(grid)
 ...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
@@ -135,15 +137,15 @@ below:
 For example:
 ```python
 >>> slices_for_invocation(x_shape=(100, 100),
-...                       x_spec = pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
-...                       grid = (10, 5),
+...                       x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
+...                       grid = (10, 5),
 ...                       invocation_indices = (2, 3))
 [slice(20, 30, None), slice(60, 80, None)]

 >>> # Same shape of the array and blocks, but we iterate over each block 4 times
 >>> slices_for_invocation(x_shape=(100, 100),
-...                       x_spec = pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)),
-...                       grid = (10, 5, 4),
+...                       x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
+...                       grid = (10, 5, 4),
 ...                       invocation_indices = (2, 3, 0))
 [slice(20, 30, None), slice(60, 80, None)]

@@ -166,7 +168,7 @@ over the second axis:
 ...       out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
 ...       grid=grid,
 ...       in_specs=[],
-...       out_specs=pl.BlockSpec(out_index_map, block_shape),
+...       out_specs=pl.BlockSpec(block_shape, out_index_map),
 ...       interpret=True)()
 ...   print(res)

@@ -195,7 +197,7 @@ Hence, we iterate over the same output block 10 times.
 The output shown below was generated on CPU using `interpret=True`
 mode, which at the moment executes the invocation sequentially.
 On TPUs, programs are executed in a combination of parallel and sequential,
-and this function generates the output shown.
+and this function generates the output shown.
 See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).

 ```python
@@ -121,9 +121,10 @@
 "source": [
 "@jax.jit\n",
 "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n",
-"  return pl.pallas_call(add_vectors_kernel,\n",
-"                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
-"                        )(x, y)\n",
+"  return pl.pallas_call(\n",
+"      add_vectors_kernel,\n",
+"      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
+"  )(x, y)\n",
 "add_vectors(jnp.arange(8), jnp.arange(8))"
 ]
 },
@@ -378,12 +379,12 @@
 "each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.\n",
 "To express this, we'd first use a `(2, 2)` grid (one block for each program).\n",
 "\n",
-"For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this\n",
+"For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this\n",
 "carves `x` up into \"row\" blocks.\n",
 "To see this see how both program instances\n",
 "`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n",
-"For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`.\n",
-"Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.\n",
+"For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.\n",
+"Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.\n",
 "\n",
 "These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n",
 "\n",
@@ -408,11 +409,11 @@
 "      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n",
 "      grid=(2, 2),\n",
 "      in_specs=[\n",
-"          pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n",
-"          pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n",
+"          pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n",
+"          pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n",
 "      ],\n",
 "      out_specs=pl.BlockSpec(\n",
-"          lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
+"          (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),\n",
 "      )\n",
 "  )(x, y)\n",
 "k1, k2 = jax.random.split(jax.random.key(0))\n",
@@ -448,11 +449,11 @@
 "      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n",
 "      grid=(2, 2),\n",
 "      in_specs=[\n",
-"          pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n",
-"          pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n",
+"          pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n",
+"          pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n",
 "      ],\n",
 "      out_specs=pl.BlockSpec(\n",
-"          lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
+"          (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)\n",
 "      ),\n",
 "  )(x, y)\n",
 "k1, k2 = jax.random.split(jax.random.key(0))\n",
@@ -81,9 +81,10 @@ We use the `pallas_call` higher-order function.
 ```{code-cell} ipython3
 @jax.jit
 def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
-  return pl.pallas_call(add_vectors_kernel,
-                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
-                        )(x, y)
+  return pl.pallas_call(
+      add_vectors_kernel,
+      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
+  )(x, y)
 add_vectors(jnp.arange(8), jnp.arange(8))
 ```

@@ -272,12 +273,12 @@ the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where
 each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.
 To express this, we'd first use a `(2, 2)` grid (one block for each program).

-For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this
+For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this
 carves `x` up into "row" blocks.
 To see this see how both program instances
 `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.
-For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`.
-Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.
+For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.
+Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.

 These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.

@@ -296,11 +297,11 @@ def matmul(x: jax.Array, y: jax.Array):
       out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
       grid=(2, 2),
       in_specs=[
-          pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
-          pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
+          pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
+          pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
       ],
       out_specs=pl.BlockSpec(
-          lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
+          (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
       )
   )(x, y)
 k1, k2 = jax.random.split(jax.random.key(0))
@@ -325,11 +326,11 @@ def matmul(x: jax.Array, y: jax.Array, *, activation):
       out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
       grid=(2, 2),
       in_specs=[
-          pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
-          pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
+          pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
+          pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
       ],
       out_specs=pl.BlockSpec(
-          lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
+          (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
       ),
   )(x, y)
 k1, k2 = jax.random.split(jax.random.key(0))
@@ -70,7 +70,7 @@
 " Compute units operate on values that live in SREGs and VREGs and output\n",
 " values into those registers as well.\n",
 "\n",
-"In order to do a vectorized computation on our values `x` and `y` that live \n",
+"In order to do a vectorized computation on our values `x` and `y` that live\n",
 "in HBM, we need to:\n",
 "\n",
 "1. Copy the values `x` and `y` into VMEM.\n",
@@ -174,7 +174,7 @@
 "Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n",
 "writing kernels utilizing them adds some considerations.\n",
 "\n",
-"1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB \n",
+"1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB\n",
 "   and SMEM ranges in the tens to hundreds of KiB.\n",
 "   If our arrays are too big, we won't even be able to fit them into VMEM at all.\n",
 "   For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n",
@@ -289,7 +289,7 @@
 "\n",
 "We'd then construct the `BlockSpec`:\n",
 "```python\n",
-"block_spec = pl.BlockSpec(x_index_map, (256, 512))\n",
+"block_spec = pl.BlockSpec((256, 512), x_index_map)\n",
 "```\n",
 "\n",
 "The `BlockSpec`s for `y` and `z` will be the same as the one for `x`."
@@ -335,13 +335,14 @@
 ],
 "source": [
 "def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n",
-"  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n",
+"  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n",
 "  return pl.pallas_call(\n",
 "      add_matrices_kernel,\n",
 "      out_shape=x,\n",
 "      in_specs=[block_spec, block_spec],\n",
 "      out_specs=block_spec,\n",
-"      grid=(2,))(x, y)\n",
+"      grid=(2,)\n",
+"  )(x, y)\n",
 "\n",
 "add_matrices_pipelined(x, y)"
 ]
@@ -403,8 +404,7 @@
 "    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n",
 ") -> jax.Array:\n",
 "  m, n = x.shape\n",
-"  block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))\n",
-"\n",
+"  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n",
 "  return pl.pallas_call(\n",
 "      add_matrices_kernel,\n",
 "      out_shape=x,\n",
@@ -413,7 +413,6 @@
 "      grid=(m // bm, n // bn),\n",
 "  )(x, y)\n",
 "\n",
-"\n",
 "np.testing.assert_array_equal(\n",
 "    add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n",
 ")\n",
@@ -526,10 +525,10 @@
 "      naive_sum_kernel,\n",
 "      grid=grid,\n",
 "      # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
-"      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n",
-"      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n",
-"      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n",
-"  )(x)\n",
+"      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n",
+"      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n",
+"      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n",
+"  )(x)\n",
 "naive_sum(x)"
 ]
 },
@@ -603,10 +602,11 @@
 "      sum_kernel,\n",
 "      grid=grid,\n",
 "      # None in `block_shape` means we pick a size of 1 and squeeze it away\n",
-"      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n",
-"      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n",
+"      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n",
+"      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n",
 "      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n",
-"  )(x)\n",
+"  )(x)\n",
 "\n",
 "sum(x)"
 ]
 },
@@ -689,15 +689,15 @@
 ],
 "source": [
 "def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n",
-"  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n",
+"  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n",
 "  return pl.pallas_call(\n",
 "      add_matrices_kernel,\n",
 "      out_shape=x,\n",
 "      in_specs=[block_spec, block_spec],\n",
 "      out_specs=block_spec,\n",
 "      grid=(2,),\n",
-"      compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",))))(\n",
-"          x, y)\n",
+"      compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n",
+"  )(x, y)\n",
 "\n",
 "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
 "add_matrices_pipelined_megacore(x, y)"
@@ -62,7 +62,7 @@ Let's talk about the components of this diagram in more detail:
 Compute units operate on values that live in SREGs and VREGs and output
 values into those registers as well.

-In order to do a vectorized computation on our values `x` and `y` that live
+In order to do a vectorized computation on our values `x` and `y` that live
 in HBM, we need to:

 1. Copy the values `x` and `y` into VMEM.
@@ -129,7 +129,7 @@ the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`.
 Pallas exposes access to lower level memory spaces like VMEM and SMEM but
 writing kernels utilizing them adds some considerations.

-1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB
+1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB
    and SMEM ranges in the tens to hundreds of KiB.
    If our arrays are too big, we won't even be able to fit them into VMEM at all.
    For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't
@@ -229,7 +229,7 @@ def x_index_map(i):

 We'd then construct the `BlockSpec`:
 ```python
-block_spec = pl.BlockSpec(x_index_map, (256, 512))
+block_spec = pl.BlockSpec((256, 512), x_index_map)
 ```

 The `BlockSpec`s for `y` and `z` will be the same as the one for `x`.
@@ -247,13 +247,14 @@ and `out_specs` corresponds to the output).
 :outputId: 504bab29-83f3-4e1f-8664-1860ad15b6de

 def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
-  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
+  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
   return pl.pallas_call(
       add_matrices_kernel,
       out_shape=x,
       in_specs=[block_spec, block_spec],
       out_specs=block_spec,
-      grid=(2,))(x, y)
+      grid=(2,)
+  )(x, y)

 add_matrices_pipelined(x, y)
 ```
@@ -295,8 +296,7 @@ def add_matrices_pipelined_2d(
     x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
 ) -> jax.Array:
   m, n = x.shape
-  block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))
-
+  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
   return pl.pallas_call(
       add_matrices_kernel,
       out_shape=x,
@@ -305,7 +305,6 @@ def add_matrices_pipelined_2d(
       grid=(m // bm, n // bn),
   )(x, y)

-
 np.testing.assert_array_equal(
     add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
 )
@@ -359,10 +358,10 @@ def naive_sum(x: jax.Array) -> jax.Array:
       naive_sum_kernel,
       grid=grid,
       # None in `block_shape` means we pick a size of 1 and squeeze it away
-      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
-      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
-      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
-  )(x)
+      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
+      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
+      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
+  )(x)
 naive_sum(x)
 ```

@@ -409,10 +408,11 @@ def sum(x: jax.Array) -> jax.Array:
       sum_kernel,
       grid=grid,
       # None in `block_shape` means we pick a size of 1 and squeeze it away
-      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
-      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
+      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
+      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
       out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
-  )(x)
+  )(x)

 sum(x)
 ```

@@ -458,15 +458,15 @@ annotation to `pallas_call` called `dimension_semantics`.
 :outputId: 385ed87c-d95c-466c-af77-df3845c979f2

 def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
-  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
+  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
   return pl.pallas_call(
       add_matrices_kernel,
       out_shape=x,
       in_specs=[block_spec, block_spec],
       out_specs=block_spec,
       grid=(2,),
-      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
-          x, y)
+      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",)))
+  )(x, y)

 x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
 add_matrices_pipelined_megacore(x, y)
@@ -16,12 +16,14 @@
 from __future__ import annotations

 from collections.abc import Callable, Iterator, Sequence
-import copy
 import contextlib
+import copy
 import dataclasses
 import functools
+import inspect
 import threading
 from typing import Any, Union
+import warnings

 import jax
 from jax._src import api_util
@@ -169,16 +171,47 @@ blocked = Blocked()
 IndexingMode = Union[Blocked, Unblocked]


+_BLOCK_SPECSIG = inspect.Signature.from_callable(
+    lambda index_map, block_shape: ...
+)
+
 @dataclasses.dataclass(unsafe_hash=True)
 class BlockSpec:
   """Specifies how an array should be sliced for each iteration of a kernel.

   See :ref:`pallas_blockspec` for more details.
   """
-  index_map: Callable[..., Any] | None = None
   block_shape: tuple[int | None, ...] | None = None
-  memory_space: Any | None = None
-  indexing_mode: IndexingMode = blocked
+  index_map: Callable[..., Any] | None = None
+  memory_space: Any | None = dataclasses.field(kw_only=True, default=None)
+  indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked)
+
+  def __init__(
+      self,
+      *args,
+      memory_space: Any | None = None,
+      indexing_mode: IndexingMode = blocked,
+      **kwargs: Any,
+  ) -> None:
+    bound_args = _BLOCK_SPECSIG.bind_partial(*args, **kwargs)
+    block_shape = bound_args.arguments.get("block_shape", None)
+    index_map = bound_args.arguments.get("index_map", None)
+    if callable(block_shape):
+      # TODO(slebedev): Remove this code path and update the signature of
+      # __init__ after October 1, 2024.
+      warnings.warn(
+          "BlockSpec now expects ``block_shape`` to be passed before"
+          " ``index_map``. Update your code by swapping the order of these"
+          " arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))``"
+          " should be written as ``pl.BlockSpec((42,), lambda i: i)``.",
+          DeprecationWarning,
+      )
+      index_map, block_shape = block_shape, index_map
+
+    self.block_shape = block_shape
+    self.index_map = index_map
+    self.memory_space = memory_space
+    self.indexing_mode = indexing_mode
+
   def compute_index(self, *args):
     assert self.index_map is not None
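For context (not part of the diff), a minimal sketch of the intended behavior of the compatibility shim above: the new argument order constructs the spec directly, while the old order is still accepted for now but deprecated:

```python
import warnings

from jax.experimental import pallas as pl

# New, supported order: block_shape first, then index_map.
new_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))

# Old order: still accepted for now, but expected to emit a DeprecationWarning.
with warnings.catch_warnings():
  warnings.simplefilter("ignore", DeprecationWarning)
  old_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))

# Both should end up describing the same blocking.
assert new_spec.block_shape == old_spec.block_shape == (256, 512)
assert callable(new_spec.index_map) and callable(old_spec.index_map)
```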
@@ -1033,7 +1033,7 @@ def pallas_call(
       a structure matching that of the outputs.
       See details at :ref:`pallas_blockspec`.
       The default value for `out_specs` specifies the whole array,
-      e.g., as `pl.BlockSpec(lambda *indices: indices, x.shape)`.
+      e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
     input_output_aliases: a dictionary mapping the index of some inputs to
       the index of the output that aliases them.
     interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
@@ -61,6 +61,8 @@ filterwarnings = [
     "default:Special cases found for .* but none were parsed.*:UserWarning",
     "default:.*is not JSON-serializable. Using the repr instead.",
     # end array_api_tests-related warnings
+    # TODO(slebedev): Remove once we migrate all pl.BlockSpec usages in JAX.
+    "default:BlockSpec now expects .*:DeprecationWarning",
 ]
 doctest_optionflags = [
     "NUMBER",
@@ -604,8 +604,12 @@ class PallasCallDMATest(PallasTPUTest):
     x = jnp.ones((8, 128), dtype=jnp.float32)
     y = self.pallas_call(
         kernel,
-        in_specs=[pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY)],
-        out_specs=pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY),
+        in_specs=[
+            pl.BlockSpec(None, None, memory_space=pltpu.TPUMemorySpace.ANY)
+        ],
+        out_specs=pl.BlockSpec(
+            None, None, memory_space=pltpu.TPUMemorySpace.ANY
+        ),
         out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
     )(x)
     jax.block_until_ready(y)
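A small illustration (an inference from the keyword-only fields introduced above, not part of the diff): `memory_space` now has to be passed by keyword, which is what the updated test does:

```python
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Positional memory_space, as in pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY),
# no longer matches the constructor; pass it by keyword instead.
spec = pl.BlockSpec(None, None, memory_space=pltpu.TPUMemorySpace.ANY)
```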