diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index 5b37e7b05..ac3a0dad2 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -299,7 +299,7 @@ " ):\n", " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n", " del idxs_k_ref\n", - " blk_idx = pl.program_id(0)\n", + " blk_idx = pl.program_id(1)\n", " is_start = blk_idx == 0\n", " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n", " @pl.when(is_start | changed_blocks)\n", @@ -314,13 +314,13 @@ " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n", "\n", "\n", - "def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", " del j, blk_idxs_i, blk_idxs_k\n", " return (blk_idx, 0, 0)\n", - "def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", " del blk_idxs_i\n", " return (blk_idxs_k[blk_idx], j)\n", - "def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", " del blk_idxs_k\n", " return (blk_idxs_i[blk_idx], j)\n", "\n", @@ -335,7 +335,7 @@ " num_scalar_prefetch=2,\n", " # Note that while num_blocks is static here, Pallas does support\n", " # dynamic grid sizes.\n", - " grid=(num_blocks, N // blk_N),\n", + " grid=(N // blk_N, num_blocks),\n", " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n", " pl.BlockSpec((blk_K, blk_N), y_map),\n", " # Placeholder for a zeros-array used by input_output_aliases.\n", diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 36a6e07e9..113f31d8b 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs. ): """A DSD (Dense = Sparse @ Dense) matmul kernel.""" del idxs_k_ref - blk_idx = pl.program_id(0) + blk_idx = pl.program_id(1) is_start = blk_idx == 0 changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)]) @pl.when(is_start | changed_blocks) @@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs. o_ref[...] = accum_scratch[...].astype(o_ref.dtype) -def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k): +def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k): del j, blk_idxs_i, blk_idxs_k return (blk_idx, 0, 0) -def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k): +def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k): del blk_idxs_i return (blk_idxs_k[blk_idx], j) -def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k): +def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k): del blk_idxs_k return (blk_idxs_i[blk_idx], j) @@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=2, # Note that while num_blocks is static here, Pallas does support # dynamic grid sizes. - grid=(num_blocks, N // blk_N), + grid=(N // blk_N, num_blocks), in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map), pl.BlockSpec((blk_K, blk_N), y_map), # Placeholder for a zeros-array used by input_output_aliases.