[Mosaic GPU] Raise a NotImplementedError if swizzle=16.

Unswizzled MMAs don't lower correctly, and are not currently intended to be
supported.

PiperOrigin-RevId: 737981373
This commit is contained in:
Benjamin Chetioui 2025-03-18 06:28:05 -07:00 committed by jax authors
parent 8da93249d2
commit 1e36cbe597
2 changed files with 4 additions and 0 deletions

View File

@ -83,6 +83,8 @@ def mma(
accumulate: ir.Value | bool = True,
collective: bool = False,
):
if a_swizzle == 16 or b_swizzle == 16:
raise NotImplementedError("No swizzle is not supported")
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
if isinstance(accumulate, bool):

View File

@ -259,6 +259,8 @@ def wgmma(
The refs must be contiguous or be contiguous except for having their two minor
dimensions swapped.
"""
if swizzle == 16:
raise NotImplementedError("No swizzle is not supported")
# Step 1. Establish the shape and element type of the operation.
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")