[mosaic:gpu] Relax constraint on stages in matmul example.

PiperOrigin-RevId: 638993045
This commit is contained in:
Chris Jones 2024-05-31 04:39:24 -07:00 committed by jax authors
parent 8deed95c7f
commit 3c64066097

View File

@@ -220,7 +220,6 @@ class WGMMACvtRhsImpl:
return wgmma(acc, a_slice, smem_scratch["cvt"], b_order=b_order)
def mlir_context(f):
def wrap(*args, **kw):
with mlir.make_ir_context(), ir.Location.unknown():
@@ -262,15 +261,13 @@ def build_kernel(
f" {((lhs_128b_elems, lhs_dtype), (rhs_128b_elems, rhs_dtype))}"
)
-if k % (stages * tile_k) != 0:
-raise ValueError(
-f"k must be divisible by {stages=} * {tile_k=} (={stages * tile_k}),"
-f" but got {k=}"
-)
+if k % tile_k != 0:
+raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}")
 block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k)
 tma_tiling = Tiling(m=64, n=rhs_128b_elems, k=lhs_128b_elems)
 k_steps = k // block_tiling.k
+stages = min(stages, k_steps)
f32 = ir.F32Type.get()
index = ir.IndexType.get()