[Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect

PiperOrigin-RevId: 718788390
This commit is contained in:
Dimitar (Mitko) Asenov 2025-01-23 03:09:38 -08:00 committed by jax authors
parent 3a411d883a
commit 6b747b4109

View File

@ -321,10 +321,10 @@ def MosaicGPU_WGMMALayout :
let genSpecializedAttr = 0;
}
def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
let summary = "Multiply two matrices asynchronously using warpgroup level matrix multiply operations.";
let description = [{
Schedules WGMMA operations that perform the following matrix multiple and
Schedules WGMMA operations that perform the following matrix multiply and
accumulate:
accumulator = a * b + accumulator
@ -346,6 +346,8 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
Where:
- `k == swizzle/element_bytewidth` (for `kNoSwizzle`, `swizzle` is 16).
The output has the same shape and type as the input accumulator.
The `accumulator` is always in registers and `b` is always in shared memory.
The last two dimensions of any input in shared memory may be physically
transposed in memory. This is inferred from the strides of the provided
@ -372,12 +374,32 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
// Attributes
DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::k128ByteSwizzle">:$swizzle
);
let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>);
let assemblyFormat = [{
`accumulator` `(` $accumulator `:` type($accumulator) `)`
`a` `(` $a `:` type($a) `)`
`b` `(` $b `:` type($b) `)`
attr-dict
`->` type(results)
}];
let extraClassDeclaration = [{
// Return-type inference hook: reports a single result type that is equal
// to the type of the first operand.
//
// Emits "expected non-empty operands" (attached to `location` when one is
// provided) and fails if the operand list is empty. The context,
// attributes, properties and regions arguments are not consulted.
//
// NOTE(review): operand 0 is presumably the accumulator, since it is listed
// first in the assembly format; confirm against the op's `arguments` list.
static llvm::LogicalResult inferReturnTypes(
    mlir::MLIRContext *,
    std::optional<mlir::Location> location,
    mlir::ValueRange operands,
    mlir::DictionaryAttr attributes,
    mlir::OpaqueProperties properties,
    mlir::RegionRange regions,
    llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
  if (!operands.empty()) {
    // Replace whatever the caller passed in with exactly one entry.
    mlir::Type firstOperandType = operands[0].getType();
    inferredReturnTypes.clear();
    inferredReturnTypes.push_back(firstOperandType);
    return ::mlir::success();
  }

  // Without operands there is nothing to mirror the result type from.
  return ::mlir::emitOptionalError(
      location, "expected non-empty operands");
}
}];
let hasVerifier = 1;