[mlir][spirv] Add TransposeOp

Add Transpose operation to SPIRV dialect.

Differential Revision: https://reviews.llvm.org/D82308
This commit is contained in:
HazemAbdelhafez 2020-06-24 20:34:34 -04:00 committed by Lei Zhang
parent 090c108d04
commit 2bcb620868
5 changed files with 150 additions and 15 deletions

View File

@ -3141,6 +3141,7 @@ def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>
def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
def SPV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>;
def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>;
def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>;
def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>;
@ -3265,20 +3266,21 @@ def SPV_OpcodeAttr :
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

View File

@ -45,6 +45,13 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Matrix]>
];
let arguments = (ins
SPV_AnyMatrix:$matrix,
SPV_Float:$scalar
@ -72,4 +79,58 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
// -----
def SPV_TransposeOp : SPV_Op<"Transpose", []> {
let summary = "Transpose a matrix.";
let description = [{
Result Type must be an OpTypeMatrix.
Matrix must be an object of type OpTypeMatrix. The number of columns and
the column size of Matrix must be the reverse of those in Result Type.
The types of the scalar components in Matrix and Result Type must be the
same.
Matrix must have of type of OpTypeMatrix.
<!-- End of AutoGen section -->
```
transpose-op ::= ssa-id `=` `spv.Transpose` ssa-use `:` matrix-type `->`
matrix-type
```mlir
#### Example:
```
%0 = spv.Transpose %matrix: !spv.matrix<2 x vector<3xf32>> ->
!spv.matrix<3 x vector<2xf32>>
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Matrix]>
];
let arguments = (ins
SPV_AnyMatrix:$matrix
);
let results = (outs
SPV_AnyMatrix:$result
);
let assemblyFormat = [{
operands attr-dict `:` type($matrix) `->` type($result)
}];
let verifier = [{ return verifyTranspose(*this); }];
}
// -----
#endif // SPIRV_MATRIX_OPS

View File

@ -2815,6 +2815,36 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.Transpose
//===----------------------------------------------------------------------===//
static LogicalResult verifyTranspose(spirv::TransposeOp op) {
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
// Verify that the input and output matrices have correct shapes.
if (auto inputMatrixColumns =
inputMatrix.getElementType().dyn_cast<VectorType>()) {
if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
return op.emitError("input matrix rows count must be equal to "
"output matrix columns count");
if (auto resultMatrixColumns =
resultMatrix.getElementType().dyn_cast<VectorType>()) {
if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
return op.emitError("input matrix columns count must be equal "
"to output matrix rows count");
// Verify that the input and output matrices have the same component type
if (inputMatrixColumns.getElementType() !=
resultMatrixColumns.getElementType())
return op.emitError("input and output matrices must have the "
"same component type");
}
}
return success();
}
namespace mlir {
namespace spirv {

View File

@ -22,6 +22,13 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
}
// CHECK-LABEL: @matrix_transpose_1
spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
}
}
// -----

View File

@ -2,11 +2,25 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @matrix_times_scalar
spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
spv.func @matrix_times_scalar(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_transpose_1
spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_transpose_2
spv.func @matrix_transpose_2(%arg0 : !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
}
// -----
@ -37,5 +51,26 @@ func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 :
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
}
// -----
func @transpose_op_shape_mismatch_1(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.Return
}
// -----
func @transpose_op_shape_mismatch_2(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<2 x vector<4xf32>>
spv.Return
}
// -----
func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
// expected-error @+1 {{input and output matrices must have the same component type}}
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
spv.Return
}