From a2a5068e5e1dbdde15cd34e48a325a5acc37d7a6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 1 Jul 2024 14:25:25 -0700 Subject: [PATCH] Changed ``pl.BlockSpec`` to accept ``block_shape`` before ``index_map`` So, instead of pl.BlockSpec(lambda i, j: ..., (42, 24)) ``pl.BlockSpec`` now expects pl.BlockSpec((42, 24), lambda i, j: ...) I will update Pallas tests in a follow up. PiperOrigin-RevId: 648486321 --- CHANGELOG.md | 3 ++ docs/pallas/design.md | 14 +++++----- docs/pallas/grid_blockspec.md | 16 ++++++----- docs/pallas/quickstart.ipynb | 25 +++++++++-------- docs/pallas/quickstart.md | 25 +++++++++-------- docs/pallas/tpu/pipelining.ipynb | 36 ++++++++++++------------ docs/pallas/tpu/pipelining.md | 36 ++++++++++++------------ jax/_src/pallas/core.py | 41 +++++++++++++++++++++++++--- jax/_src/pallas/pallas_call.py | 2 +- pyproject.toml | 2 ++ tests/pallas/pallas_call_tpu_test.py | 8 ++++-- 11 files changed, 127 insertions(+), 81 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f8e1a084..6c5cd34d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list * `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be installed either as a part of local CUDA installation, or via NVIDIA's CUDA pip wheels. + * {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to + be passed *before* `index_map`. The old argument order is deprecated and + will be removed in a future release. * Deprecations * Removed a number of previously-deprecated internal APIs related to polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`, diff --git a/docs/pallas/design.md b/docs/pallas/design.md index 53a7cfb7e..3ba32f25a 100644 --- a/docs/pallas/design.md +++ b/docs/pallas/design.md @@ -23,7 +23,7 @@ job of compiling user programs but inevitably some users hit XLA's limitations. In these cases, we need to provide an “escape hatch” to allow experts to write hand-tuned kernels that outperform XLA at that -point in time. +point in time. Furthermore, advances in ML systems research take some time to be incorporated into XLA and users often want to run ahead with them. Over time, the compiler can incorporate the optimizations that were proven @@ -431,10 +431,10 @@ add = pl.pallas_call( add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), in_specs=[ - pl.BlockSpec(lambda i: i, (2,)), - pl.BlockSpec(lambda i: i, (2,)) + pl.BlockSpec((2,), lambda i: i), + pl.BlockSpec((2,), lambda i: i) ], - out_specs=pl.BlockSpec(lambda i: i, (2,)), + out_specs=pl.BlockSpec((2,), lambda i: i), grid=(4,)) add(x, y) ``` @@ -465,10 +465,10 @@ def matmul(x, y, *, block_shape, activation): partial(matmul_kernel, block_k=block_k, activation=activation), out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32), in_specs=[ - pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])), - pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n)) + pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j)) ], - out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)), + out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)), grid=(4, 4), ) return fused_matmul(x, y) diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index 0fbc8e602..c89c536a7 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -2,6 +2,8 @@ # Grids and BlockSpecs + + (pallas_grid)= ### `grid`, a.k.a. kernels in a loop @@ -115,7 +117,7 @@ below: ```python >>> def slices_for_invocation(x_shape: tuple[int, ...], ... x_spec: pl.BlockSpec, -... grid: tuple[int, ...], +... grid: tuple[int, ...], ... invocation_indices: tuple[int, ...]) -> tuple[slice, ...]: ... assert len(invocation_indices) == len(grid) ... assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid)) @@ -135,15 +137,15 @@ below: For example: ```python >>> slices_for_invocation(x_shape=(100, 100), -... x_spec = pl.BlockSpec(lambda i, j: (i, j), (10, 20)), -... grid = (10, 5), +... x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)), +... grid = (10, 5), ... invocation_indices = (2, 3)) [slice(20, 30, None), slice(60, 80, None)] >>> # Same shape of the array and blocks, but we iterate over each block 4 times >>> slices_for_invocation(x_shape=(100, 100), -... x_spec = pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)), -... grid = (10, 5, 4), +... x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)), +... grid = (10, 5, 4), ... invocation_indices = (2, 3, 0)) [slice(20, 30, None), slice(60, 80, None)] @@ -166,7 +168,7 @@ over the second axis: ... out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32), ... grid=grid, ... in_specs=[], -... out_specs=pl.BlockSpec(out_index_map, block_shape), +... out_specs=pl.BlockSpec(block_shape, out_index_map), ... interpret=True)() ... print(res) @@ -195,7 +197,7 @@ Hence, we iterate over the same output block 10 times. The output shown below was generated on CPU using `interpret=True` mode, which at the moment executes the invocation sequentially. On TPUs, programs are executed in a combination of parallel and sequential, -and this function generates the output shown. +and this function generates the output shown. See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions). ```python diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index ea3209a57..5a8608f49 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -121,9 +121,10 @@ "source": [ "@jax.jit\n", "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " return pl.pallas_call(add_vectors_kernel,\n", - " out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " )(x, y)\n", + " return pl.pallas_call(\n", + " add_vectors_kernel,\n", + " out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + " )(x, y)\n", "add_vectors(jnp.arange(8), jnp.arange(8))" ] }, @@ -378,12 +379,12 @@ "each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.\n", "To express this, we'd first use a `(2, 2)` grid (one block for each program).\n", "\n", - "For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this\n", + "For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this\n", "carves `x` up into \"row\" blocks.\n", "To see this see how both program instances\n", "`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n", - "For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`.\n", - "Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.\n", + "For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.\n", + "Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.\n", "\n", "These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n", "\n", @@ -408,11 +409,11 @@ " out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n", " grid=(2, 2),\n", " in_specs=[\n", - " pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n", - " pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n", + " pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n", + " pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n", " ],\n", " out_specs=pl.BlockSpec(\n", - " lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n", + " (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),\n", " )\n", " )(x, y)\n", "k1, k2 = jax.random.split(jax.random.key(0))\n", @@ -448,11 +449,11 @@ " out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),\n", " grid=(2, 2),\n", " in_specs=[\n", - " pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),\n", - " pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))\n", + " pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),\n", + " pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))\n", " ],\n", " out_specs=pl.BlockSpec(\n", - " lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n", + " (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)\n", " ),\n", " )(x, y)\n", "k1, k2 = jax.random.split(jax.random.key(0))\n", diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 05685d0f6..36cc14bf5 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -81,9 +81,10 @@ We use the `pallas_call` higher-order function. ```{code-cell} ipython3 @jax.jit def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: - return pl.pallas_call(add_vectors_kernel, - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) + return pl.pallas_call( + add_vectors_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) add_vectors(jnp.arange(8), jnp.arange(8)) ``` @@ -272,12 +273,12 @@ the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication. To express this, we'd first use a `(2, 2)` grid (one block for each program). -For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this +For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this carves `x` up into "row" blocks. To see this see how both program instances `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. -For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`. -Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`. +For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`. +Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`. These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`. @@ -296,11 +297,11 @@ def matmul(x: jax.Array, y: jax.Array): out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), grid=(2, 2), in_specs=[ - pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])), - pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2)) + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) ], out_specs=pl.BlockSpec( - lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2) + (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), ) )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) @@ -325,11 +326,11 @@ def matmul(x: jax.Array, y: jax.Array, *, activation): out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), grid=(2, 2), in_specs=[ - pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])), - pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2)) + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) ], out_specs=pl.BlockSpec( - lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2) + (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j) ), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index ed37f88a0..275a72f38 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -70,7 +70,7 @@ " Compute units operate on values that live in SREGs and VREGs and output\n", " values into those registers as well.\n", "\n", - "In order to do a vectorized computation on our values `x` and `y` that live \n", + "In order to do a vectorized computation on our values `x` and `y` that live\n", "in HBM, we need to:\n", "\n", "1. Copy the values `x` and `y` into VMEM.\n", @@ -174,7 +174,7 @@ "Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n", "writing kernels utilizing them adds some considerations.\n", "\n", - "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB \n", + "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB\n", " and SMEM ranges in the tens to hundreds of KiB.\n", " If our arrays are too big, we won't even be able to fit them into VMEM at all.\n", " For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n", @@ -289,7 +289,7 @@ "\n", "We'd then construct the `BlockSpec`:\n", "```python\n", - "block_spec = pl.BlockSpec(x_index_map, (256, 512))\n", + "block_spec = pl.BlockSpec((256, 512), x_index_map)\n", "```\n", "\n", "The `BlockSpec`s for `y` and `z` will be the same as the one for `x`." @@ -335,13 +335,14 @@ ], "source": [ "def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n", + " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", " return pl.pallas_call(\n", " add_matrices_kernel,\n", " out_shape=x,\n", " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", - " grid=(2,))(x, y)\n", + " grid=(2,)\n", + " )(x, y)\n", "\n", "add_matrices_pipelined(x, y)" ] @@ -403,8 +404,7 @@ " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", ") -> jax.Array:\n", " m, n = x.shape\n", - " block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))\n", - "\n", + " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", " return pl.pallas_call(\n", " add_matrices_kernel,\n", " out_shape=x,\n", @@ -413,7 +413,6 @@ " grid=(m // bm, n // bn),\n", " )(x, y)\n", "\n", - "\n", "np.testing.assert_array_equal(\n", " add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n", ")\n", @@ -526,10 +525,10 @@ " naive_sum_kernel,\n", " grid=grid,\n", " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n", - " out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n", - " )(x)\n", + " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", + " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", "naive_sum(x)" ] }, @@ -603,10 +602,11 @@ " sum_kernel,\n", " grid=grid,\n", " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],\n", - " out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),\n", + " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", + " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n", - " )(x)\n", + " )(x)\n", + "\n", "sum(x)" ] }, @@ -689,15 +689,15 @@ ], "source": [ "def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))\n", + " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", " return pl.pallas_call(\n", " add_matrices_kernel,\n", " out_shape=x,\n", " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",))))(\n", - " x, y)\n", + " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n", + " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", "add_matrices_pipelined_megacore(x, y)" diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index cab4d6b28..d753b404d 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -62,7 +62,7 @@ Let's talk about the components of this diagram in more detail: Compute units operate on values that live in SREGs and VREGs and output values into those registers as well. -In order to do a vectorized computation on our values `x` and `y` that live +In order to do a vectorized computation on our values `x` and `y` that live in HBM, we need to: 1. Copy the values `x` and `y` into VMEM. @@ -129,7 +129,7 @@ the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. Pallas exposes access to lower level memory spaces like VMEM and SMEM but writing kernels utilizing them adds some considerations. -1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB +1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won't even be able to fit them into VMEM at all. For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't @@ -229,7 +229,7 @@ def x_index_map(i): We'd then construct the `BlockSpec`: ```python -block_spec = pl.BlockSpec(x_index_map, (256, 512)) +block_spec = pl.BlockSpec((256, 512), x_index_map) ``` The `BlockSpec`s for `y` and `z` will be the same as the one for `x`. @@ -247,13 +247,14 @@ and `out_specs` corresponds to the output). :outputId: 504bab29-83f3-4e1f-8664-1860ad15b6de def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array: - block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512)) + block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) return pl.pallas_call( add_matrices_kernel, out_shape=x, in_specs=[block_spec, block_spec], out_specs=block_spec, - grid=(2,))(x, y) + grid=(2,) + )(x, y) add_matrices_pipelined(x, y) ``` @@ -295,8 +296,7 @@ def add_matrices_pipelined_2d( x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256 ) -> jax.Array: m, n = x.shape - block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn)) - + block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j)) return pl.pallas_call( add_matrices_kernel, out_shape=x, @@ -305,7 +305,6 @@ def add_matrices_pipelined_2d( grid=(m // bm, n // bn), )(x, y) - np.testing.assert_array_equal( add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y ) @@ -359,10 +358,10 @@ def naive_sum(x: jax.Array) -> jax.Array: naive_sum_kernel, grid=grid, # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))], - out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape), - out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype) - )(x) + in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], + out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), + out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype), + )(x) naive_sum(x) ``` @@ -409,10 +408,11 @@ def sum(x: jax.Array) -> jax.Array: sum_kernel, grid=grid, # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))], - out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape), + in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], + out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype) - )(x) + )(x) + sum(x) ``` @@ -458,15 +458,15 @@ annotation to `pallas_call` called `dimension_semantics`. :outputId: 385ed87c-d95c-466c-af77-df3845c979f2 def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: - block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512)) + block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) return pl.pallas_call( add_matrices_kernel, out_shape=x, in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))( - x, y) + compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))) + )(x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) add_matrices_pipelined_megacore(x, y) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 85a2eda28..2af17fc51 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -16,12 +16,14 @@ from __future__ import annotations from collections.abc import Callable, Iterator, Sequence -import copy import contextlib +import copy import dataclasses import functools +import inspect import threading from typing import Any, Union +import warnings import jax from jax._src import api_util @@ -169,16 +171,47 @@ blocked = Blocked() IndexingMode = Union[Blocked, Unblocked] +_BLOCK_SPECSIG = inspect.Signature.from_callable( + lambda index_map, block_shape: ... +) + @dataclasses.dataclass(unsafe_hash=True) class BlockSpec: """Specifies how an array should be sliced for each iteration of a kernel. See :ref:`pallas_blockspec` for more details. """ - index_map: Callable[..., Any] | None = None block_shape: tuple[int | None, ...] | None = None - memory_space: Any | None = None - indexing_mode: IndexingMode = blocked + index_map: Callable[..., Any] | None = None + memory_space: Any | None = dataclasses.field(kw_only=True, default=None) + indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked) + + def __init__( + self, + *args, + memory_space: Any | None = None, + indexing_mode: IndexingMode = blocked, + **kwargs: Any, + ) -> None: + bound_args = _BLOCK_SPECSIG.bind_partial(*args, **kwargs) + block_shape = bound_args.arguments.get("block_shape", None) + index_map = bound_args.arguments.get("index_map", None) + if callable(block_shape): + # TODO(slebedev): Remove this code path and update the signature of + # __init__ after October 1, 2024. + warnings.warn( + "BlockSpec now expects ``block_shape`` to be passed before" + " ``index_map``. Update your code by swapping the order of these" + " arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))``" + " should be written as ``pl.BlockSpec((42,), lambda i: i)``.", + DeprecationWarning, + ) + index_map, block_shape = block_shape, index_map + + self.block_shape = block_shape + self.index_map = index_map + self.memory_space = memory_space + self.indexing_mode = indexing_mode def compute_index(self, *args): assert self.index_map is not None diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index ea61f1c7f..41b2d2e61 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1033,7 +1033,7 @@ def pallas_call( a structure matching that of the outputs. See details at :ref:`pallas_blockspec`. The default value for `out_specs` specifies the whole array, - e.g., as `pl.BlockSpec(lambda *indices: indices, x.shape)`. + e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the diff --git a/pyproject.toml b/pyproject.toml index e81320c13..21b32fc92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,8 @@ filterwarnings = [ "default:Special cases found for .* but none were parsed.*:UserWarning", "default:.*is not JSON-serializable. Using the repr instead.", # end array_api_tests-related warnings + # TODO(slebedev): Remove once we migrate all pl.BlockSpec usages in JAX. + "default:BlockSpec now expects .*:DeprecationWarning", ] doctest_optionflags = [ "NUMBER", diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index b09d4b307..17431dc33 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -604,8 +604,12 @@ class PallasCallDMATest(PallasTPUTest): x = jnp.ones((8, 128), dtype=jnp.float32) y = self.pallas_call( kernel, - in_specs=[pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY)], - out_specs=pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY), + in_specs=[ + pl.BlockSpec(None, None, memory_space=pltpu.TPUMemorySpace.ANY) + ], + out_specs=pl.BlockSpec( + None, None, memory_space=pltpu.TPUMemorySpace.ANY + ), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) jax.block_until_ready(y)