mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[Mosaic GPU] Raise a NotImplementedError
if swizzle=16
.
Unswizzled MMAs don't lower correctly, and are not currently intended to be supported. PiperOrigin-RevId: 737981373
This commit is contained in:
parent
8da93249d2
commit
1e36cbe597
@ -83,6 +83,8 @@ def mma(
|
||||
accumulate: ir.Value | bool = True,
|
||||
collective: bool = False,
|
||||
):
|
||||
if a_swizzle == 16 or b_swizzle == 16:
|
||||
raise NotImplementedError("No swizzle is not supported")
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
if isinstance(accumulate, bool):
|
||||
|
@ -259,6 +259,8 @@ def wgmma(
|
||||
The refs must be contiguous or be contiguous except for having their two minor
|
||||
dimensions swapped.
|
||||
"""
|
||||
if swizzle == 16:
|
||||
raise NotImplementedError("No swizzle is not supported")
|
||||
# Step 1. Establish the shape and element type of the operation.
|
||||
if not ir.MemRefType.isinstance(b.type):
|
||||
raise ValueError(f"B must be a memref, got: {b.type}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user