mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix error in pallas tutorial
PiperOrigin-RevId: 737727935
This commit is contained in:
parent
20658fabb3
commit
4f70471310
@ -299,7 +299,7 @@
|
|||||||
" ):\n",
|
" ):\n",
|
||||||
" \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
|
" \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
|
||||||
" del idxs_k_ref\n",
|
" del idxs_k_ref\n",
|
||||||
" blk_idx = pl.program_id(0)\n",
|
" blk_idx = pl.program_id(1)\n",
|
||||||
" is_start = blk_idx == 0\n",
|
" is_start = blk_idx == 0\n",
|
||||||
" changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
|
" changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
|
||||||
" @pl.when(is_start | changed_blocks)\n",
|
" @pl.when(is_start | changed_blocks)\n",
|
||||||
@ -314,13 +314,13 @@
|
|||||||
" o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
|
" o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
"def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||||
" del j, blk_idxs_i, blk_idxs_k\n",
|
" del j, blk_idxs_i, blk_idxs_k\n",
|
||||||
" return (blk_idx, 0, 0)\n",
|
" return (blk_idx, 0, 0)\n",
|
||||||
"def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
"def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||||
" del blk_idxs_i\n",
|
" del blk_idxs_i\n",
|
||||||
" return (blk_idxs_k[blk_idx], j)\n",
|
" return (blk_idxs_k[blk_idx], j)\n",
|
||||||
"def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
|
"def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
|
||||||
" del blk_idxs_k\n",
|
" del blk_idxs_k\n",
|
||||||
" return (blk_idxs_i[blk_idx], j)\n",
|
" return (blk_idxs_i[blk_idx], j)\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -335,7 +335,7 @@
|
|||||||
" num_scalar_prefetch=2,\n",
|
" num_scalar_prefetch=2,\n",
|
||||||
" # Note that while num_blocks is static here, Pallas does support\n",
|
" # Note that while num_blocks is static here, Pallas does support\n",
|
||||||
" # dynamic grid sizes.\n",
|
" # dynamic grid sizes.\n",
|
||||||
" grid=(num_blocks, N // blk_N),\n",
|
" grid=(N // blk_N, num_blocks),\n",
|
||||||
" in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
|
" in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
|
||||||
" pl.BlockSpec((blk_K, blk_N), y_map),\n",
|
" pl.BlockSpec((blk_K, blk_N), y_map),\n",
|
||||||
" # Placeholder for a zeros-array used by input_output_aliases.\n",
|
" # Placeholder for a zeros-array used by input_output_aliases.\n",
|
||||||
|
@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
|
|||||||
):
|
):
|
||||||
"""A DSD (Dense = Sparse @ Dense) matmul kernel."""
|
"""A DSD (Dense = Sparse @ Dense) matmul kernel."""
|
||||||
del idxs_k_ref
|
del idxs_k_ref
|
||||||
blk_idx = pl.program_id(0)
|
blk_idx = pl.program_id(1)
|
||||||
is_start = blk_idx == 0
|
is_start = blk_idx == 0
|
||||||
changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
|
changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
|
||||||
@pl.when(is_start | changed_blocks)
|
@pl.when(is_start | changed_blocks)
|
||||||
@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
|
|||||||
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
|
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
|
||||||
|
|
||||||
|
|
||||||
def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||||
del j, blk_idxs_i, blk_idxs_k
|
del j, blk_idxs_i, blk_idxs_k
|
||||||
return (blk_idx, 0, 0)
|
return (blk_idx, 0, 0)
|
||||||
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||||
del blk_idxs_i
|
del blk_idxs_i
|
||||||
return (blk_idxs_k[blk_idx], j)
|
return (blk_idxs_k[blk_idx], j)
|
||||||
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
|
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
|
||||||
del blk_idxs_k
|
del blk_idxs_k
|
||||||
return (blk_idxs_i[blk_idx], j)
|
return (blk_idxs_i[blk_idx], j)
|
||||||
|
|
||||||
@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
|
|||||||
num_scalar_prefetch=2,
|
num_scalar_prefetch=2,
|
||||||
# Note that while num_blocks is static here, Pallas does support
|
# Note that while num_blocks is static here, Pallas does support
|
||||||
# dynamic grid sizes.
|
# dynamic grid sizes.
|
||||||
grid=(num_blocks, N // blk_N),
|
grid=(N // blk_N, num_blocks),
|
||||||
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
|
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
|
||||||
pl.BlockSpec((blk_K, blk_N), y_map),
|
pl.BlockSpec((blk_K, blk_N), y_map),
|
||||||
# Placeholder for a zeros-array used by input_output_aliases.
|
# Placeholder for a zeros-array used by input_output_aliases.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user