diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index e5a2d3aa5..3330500cd 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -83,6 +83,8 @@ def mma( accumulate: ir.Value | bool = True, collective: bool = False, ): + if a_swizzle == 16 or b_swizzle == 16: + raise NotImplementedError("No swizzle is not supported") i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) if isinstance(accumulate, bool): diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index ce0c5946a..8baa16d8a 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -259,6 +259,8 @@ def wgmma( The refs must be contiguous or be contiguous except for having their two minor dimensions swapped. """ + if swizzle == 16: + raise NotImplementedError("No swizzle is not supported") # Step 1. Establish the shape and element type of the operation. if not ir.MemRefType.isinstance(b.type): raise ValueError(f"B must be a memref, got: {b.type}")