mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00
Fix error in pallas tutorial
PiperOrigin-RevId: 737727935
This commit is contained in:
parent
20658fabb3
commit
4f70471310
@ -299,7 +299,7 @@
|
||||
" ):\n",
|
||||
" \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
|
||||
" del idxs_k_ref\n",
|
||||
" blk_idx = pl.program_id(0)\n",
|
||||
" blk_idx = pl.program_id(1)\n",
|
||||
" is_start = blk_idx == 0\n",
|
||||
" changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
|
||||
" @pl.when(is_start | changed_blocks)\n",
|
||||
@ -314,13 +314,13 @@
|
||||
" o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
||||
"def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||
" del j, blk_idxs_i, blk_idxs_k\n",
|
||||
" return (blk_idx, 0, 0)\n",
|
||||
"def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
||||
"def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||
" del blk_idxs_i\n",
|
||||
" return (blk_idxs_k[blk_idx], j)\n",
|
||||
"def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
||||
"def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||
" del blk_idxs_k\n",
|
||||
" return (blk_idxs_i[blk_idx], j)\n",
|
||||
"\n",
|
||||
@ -335,7 +335,7 @@
|
||||
" num_scalar_prefetch=2,\n",
|
||||
" # Note that while num_blocks is static here, Pallas does support\n",
|
||||
" # dynamic grid sizes.\n",
|
||||
" grid=(num_blocks, N // blk_N),\n",
|
||||
" grid=(N // blk_N, num_blocks),\n",
|
||||
" in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
|
||||
" pl.BlockSpec((blk_K, blk_N), y_map),\n",
|
||||
" # Placeholder for a zeros-array used by input_output_aliases.\n",
|
||||
|
@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
|
||||
):
|
||||
"""A DSD (Dense = Sparse @ Dense) matmul kernel."""
|
||||
del idxs_k_ref
|
||||
blk_idx = pl.program_id(0)
|
||||
blk_idx = pl.program_id(1)
|
||||
is_start = blk_idx == 0
|
||||
changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
|
||||
@pl.when(is_start | changed_blocks)
|
||||
@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
|
||||
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
|
||||
|
||||
|
||||
def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
||||
def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||
del j, blk_idxs_i, blk_idxs_k
|
||||
return (blk_idx, 0, 0)
|
||||
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
||||
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||
del blk_idxs_i
|
||||
return (blk_idxs_k[blk_idx], j)
|
||||
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
||||
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||
del blk_idxs_k
|
||||
return (blk_idxs_i[blk_idx], j)
|
||||
|
||||
@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=2,
|
||||
# Note that while num_blocks is static here, Pallas does support
|
||||
# dynamic grid sizes.
|
||||
grid=(num_blocks, N // blk_N),
|
||||
grid=(N // blk_N, num_blocks),
|
||||
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
|
||||
pl.BlockSpec((blk_K, blk_N), y_map),
|
||||
# Placeholder for a zeros-array used by input_output_aliases.
|
||||
|
Loading…
x
Reference in New Issue
Block a user