mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect
PiperOrigin-RevId: 718788390
This commit is contained in:
parent
3a411d883a
commit
6b747b4109
@ -321,10 +321,10 @@ def MosaicGPU_WGMMALayout :
|
||||
let genSpecializedAttr = 0;
|
||||
}
|
||||
|
||||
def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
|
||||
def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
|
||||
let summary = "Multiply two matrices asyncronously using warpgroup level matrix multiply operations.";
|
||||
let description = [{
|
||||
Schedules WGMMA operations that perform the following matrix multiple and
|
||||
Schedules WGMMA operations that perform the following matrix multiply and
|
||||
accumulate:
|
||||
|
||||
accumulator = a * b + accumulator
|
||||
@ -346,6 +346,8 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
|
||||
Where:
|
||||
- `k == swizzle/element_bytediwth` (for `kNoSwizzle`, `swizzle` is 16.)
|
||||
|
||||
The output has an identical shape and type as the input accumulator.
|
||||
|
||||
The `accumulator` is always in registers and `b` is always in shared memory.
|
||||
The last two dimensions of any input in shared memory may be physically
|
||||
transposed in memory. This is inferred from the strides of the provided
|
||||
@ -372,12 +374,32 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
|
||||
// Attributes
|
||||
DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::k128ByteSwizzle">:$swizzle
|
||||
);
|
||||
let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`accumulator` `(` $accumulator `:` type($accumulator) `)`
|
||||
`a` `(` $a `:` type($a) `)`
|
||||
`b` `(` $b `:` type($b) `)`
|
||||
attr-dict
|
||||
`->` type(results)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static llvm::LogicalResult inferReturnTypes(
|
||||
mlir::MLIRContext *,
|
||||
std::optional<mlir::Location> location,
|
||||
mlir::ValueRange operands,
|
||||
mlir::DictionaryAttr attributes,
|
||||
mlir::OpaqueProperties properties,
|
||||
mlir::RegionRange regions,
|
||||
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
|
||||
if (operands.empty()) {
|
||||
return ::mlir::emitOptionalError(
|
||||
location, "expected non-empty operands");
|
||||
}
|
||||
inferredReturnTypes.assign({operands[0].getType()});
|
||||
return ::mlir::success();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
Loading…
x
Reference in New Issue
Block a user