mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mosaic:gpu] Relax constraint on stages in matmul example.
PiperOrigin-RevId: 638993045
This commit is contained in:
parent
8deed95c7f
commit
3c64066097
@@ -220,7 +220,6 @@ class WGMMACvtRhsImpl:
|
||||
return wgmma(acc, a_slice, smem_scratch["cvt"], b_order=b_order)
|
||||
|
||||
|
||||
|
||||
def mlir_context(f):
|
||||
def wrap(*args, **kw):
|
||||
with mlir.make_ir_context(), ir.Location.unknown():
|
||||
@@ -262,15 +261,13 @@ def build_kernel(
|
||||
f" {((lhs_128b_elems, lhs_dtype), (rhs_128b_elems, rhs_dtype))}"
|
||||
)
|
||||
|
||||
-  if k % (stages * tile_k) != 0:
-    raise ValueError(
-        f"k must be divisible by {stages=} * {tile_k=} (={stages * tile_k}),"
-        f" but got {k=}"
-    )
+  if k % tile_k != 0:
+    raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}")
|
||||
|
||||
block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k)
|
||||
tma_tiling = Tiling(m=64, n=rhs_128b_elems, k=lhs_128b_elems)
|
||||
k_steps = k // block_tiling.k
|
||||
+  stages = min(stages, k_steps)
|
||||
|
||||
f32 = ir.F32Type.get()
|
||||
index = ir.IndexType.get()
|
||||
|
Loading…
x
Reference in New Issue
Block a user