mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 04:36:07 +00:00
[mlir][spirv] Add TransposeOp
Add Transpose operation to SPIRV dialect. Differential Revision: https://reviews.llvm.org/D82308
This commit is contained in:
parent
090c108d04
commit
2bcb620868
@ -3141,6 +3141,7 @@ def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>
|
||||
def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
|
||||
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
|
||||
def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
|
||||
def SPV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>;
|
||||
def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>;
|
||||
def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>;
|
||||
def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>;
|
||||
@ -3265,20 +3266,21 @@ def SPV_OpcodeAttr :
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
|
||||
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
|
||||
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
|
||||
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
|
||||
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
|
||||
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
|
||||
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
|
||||
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
|
||||
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
|
||||
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
|
||||
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
|
||||
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
|
||||
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
|
||||
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
|
||||
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
|
||||
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
|
||||
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
|
||||
SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
|
||||
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
|
||||
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
|
||||
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
|
||||
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
|
||||
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
|
||||
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
||||
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
|
||||
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
|
||||
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
|
||||
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
|
||||
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
|
||||
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
|
||||
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
|
||||
|
@ -45,6 +45,13 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Matrix]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyMatrix:$matrix,
|
||||
SPV_Float:$scalar
|
||||
@ -72,4 +79,58 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_TransposeOp : SPV_Op<"Transpose", []> {
|
||||
let summary = "Transpose a matrix.";
|
||||
|
||||
let description = [{
|
||||
Result Type must be an OpTypeMatrix.
|
||||
|
||||
Matrix must be an object of type OpTypeMatrix. The number of columns and
|
||||
the column size of Matrix must be the reverse of those in Result Type.
|
||||
The types of the scalar components in Matrix and Result Type must be the
|
||||
same.
|
||||
|
||||
Matrix must have of type of OpTypeMatrix.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
transpose-op ::= ssa-id `=` `spv.Transpose` ssa-use `:` matrix-type `->`
|
||||
matrix-type
|
||||
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%0 = spv.Transpose %matrix: !spv.matrix<2 x vector<3xf32>> ->
|
||||
!spv.matrix<3 x vector<2xf32>>
|
||||
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Matrix]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyMatrix:$matrix
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyMatrix:$result
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` type($matrix) `->` type($result)
|
||||
}];
|
||||
|
||||
let verifier = [{ return verifyTranspose(*this); }];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#endif // SPIRV_MATRIX_OPS
|
@ -2815,6 +2815,36 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Transpose
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyTranspose(spirv::TransposeOp op) {
|
||||
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
|
||||
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
|
||||
|
||||
// Verify that the input and output matrices have correct shapes.
|
||||
if (auto inputMatrixColumns =
|
||||
inputMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||
if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
|
||||
return op.emitError("input matrix rows count must be equal to "
|
||||
"output matrix columns count");
|
||||
if (auto resultMatrixColumns =
|
||||
resultMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||
if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
|
||||
return op.emitError("input matrix columns count must be equal "
|
||||
"to output matrix rows count");
|
||||
|
||||
// Verify that the input and output matrices have the same component type
|
||||
if (inputMatrixColumns.getElementType() !=
|
||||
resultMatrixColumns.getElementType())
|
||||
return op.emitError("input and output matrices must have the "
|
||||
"same component type");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
|
@ -22,6 +22,13 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_transpose_1
|
||||
spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
|
||||
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
|
||||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -2,11 +2,25 @@
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
// CHECK-LABEL: @matrix_times_scalar
|
||||
spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
|
||||
spv.func @matrix_times_scalar(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
|
||||
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
|
||||
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_transpose_1
|
||||
spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
|
||||
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
|
||||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_transpose_2
|
||||
spv.func @matrix_transpose_2(%arg0 : !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None" {
|
||||
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -37,5 +51,26 @@ func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 :
|
||||
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @transpose_op_shape_mismatch_1(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
|
||||
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
|
||||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @transpose_op_shape_mismatch_2(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
|
||||
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
|
||||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<2 x vector<4xf32>>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
|
||||
// expected-error @+1 {{input and output matrices must have the same component type}}
|
||||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
|
||||
spv.Return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user