mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix errata in block-sparse kernel tutorial.
Correct M//blk_M to N//blk_N. It was ok because both values happen to be same. In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'. PiperOrigin-RevId: 677817478
This commit is contained in:
parent
c05706b7a9
commit
91f16419bb
@ -312,13 +312,13 @@
|
||||
" o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||
"def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
||||
" del j, blk_idxs_i, blk_idxs_k\n",
|
||||
" return (blk_idx, 0, 0)\n",
|
||||
"def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||
"def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
||||
" del blk_idxs_i\n",
|
||||
" return (blk_idxs_k[blk_idx], j)\n",
|
||||
"def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||
"def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
||||
" del blk_idxs_k\n",
|
||||
" return (blk_idxs_i[blk_idx], j)\n",
|
||||
"\n",
|
||||
@ -333,7 +333,7 @@
|
||||
" num_scalar_prefetch=2,\n",
|
||||
" # Note that while num_blocks is static here, Pallas does support\n",
|
||||
" # dynamic grid sizes.\n",
|
||||
" grid=(M // blk_M, num_blocks),\n",
|
||||
" grid=(num_blocks, N // blk_N),\n",
|
||||
" in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
|
||||
" pl.BlockSpec((blk_K, blk_N), y_map),\n",
|
||||
" # Placeholder for a zeros-array used by input_output_aliases.\n",
|
||||
|
@ -252,13 +252,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
|
||||
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
|
||||
|
||||
|
||||
def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||
def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
||||
del j, blk_idxs_i, blk_idxs_k
|
||||
return (blk_idx, 0, 0)
|
||||
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
||||
del blk_idxs_i
|
||||
return (blk_idxs_k[blk_idx], j)
|
||||
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
||||
del blk_idxs_k
|
||||
return (blk_idxs_i[blk_idx], j)
|
||||
|
||||
@ -273,7 +273,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=2,
|
||||
# Note that while num_blocks is static here, Pallas does support
|
||||
# dynamic grid sizes.
|
||||
grid=(M // blk_M, num_blocks),
|
||||
grid=(num_blocks, N // blk_N),
|
||||
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
|
||||
pl.BlockSpec((blk_K, blk_N), y_map),
|
||||
# Placeholder for a zeros-array used by input_output_aliases.
|
||||
|
Loading…
x
Reference in New Issue
Block a user