From 50b2b26e70fd904c44b4e80788e1cb64ce2b7c9d Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 3 Dec 2019 16:43:40 -0800 Subject: [PATCH] [spirv] Add spv.GroupNonUniformBallot This CL also did the following cleanup: - Moved the test for spv.SubgroupBallotKHR to its own file - Wrapped generated canonicalization patterns in anonymous namespace - Updated header comments in SPVOps.td PiperOrigin-RevId: 283650091 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 7 +- .../mlir/Dialect/SPIRV/SPIRVNonUniformOps.td | 78 +++++++++++++++++++ mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 12 +-- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 40 ++++++++++ .../SPIRV/Serialization/non-uniform-ops.mlir | 10 +++ mlir/test/Dialect/SPIRV/group-ops.mlir | 11 +++ mlir/test/Dialect/SPIRV/non-uniform-ops.mlir | 19 +++++ 7 files changed, 170 insertions(+), 7 deletions(-) create mode 100644 mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td create mode 100644 mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir create mode 100644 mlir/test/Dialect/SPIRV/group-ops.mlir create mode 100644 mlir/test/Dialect/SPIRV/non-uniform-ops.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index bfb7497aadac..2ee8f3bdd43f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -953,7 +953,9 @@ class SPV_ScalarOrVectorOf : def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; -def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>; +class SPV_Vec4 : VectorOfLengthAndType<[4], [type]>; +def SPV_IntVec4 : SPV_Vec4; +def SPV_I32Vec4 : SPV_Vec4; // TODO(antiagainst): Use a more appropriate way to model optional operands class SPV_Optional : Variadic; @@ -1109,6 +1111,7 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>; def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; +def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>; def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; def SPV_OpcodeAttr : @@ -1150,7 +1153,7 @@ def SPV_OpcodeAttr : SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed, - SPV_OC_OpSubgroupBallotKHR + SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR ]> { let cppNamespace = "::mlir::spirv"; } diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td new file mode 100644 index 000000000000..a37f5b576fd6 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -0,0 +1,78 @@ +//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to +// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_NON_UNIFORM_OPS +#define SPIRV_NON_UNIFORM_OPS + +// ----- + +def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> { + let summary = [{ + Returns a bitfield value combining the Predicate value from all + invocations in the group that execute the same dynamic instance of this + instruction. The bit is set to one if the corresponding invocation is + active and the Predicate for that invocation evaluated to true; + otherwise, it is set to zero. + }]; + + let description = [{ + Result Type must be a vector of four components of integer type scalar, + whose Signedness operand is 0. + + Result is a set of bitfields where the first invocation is represented + in the lowest bit of the first vector component and the last (up to the + size of the group) is the higher bit number of the last bitmask needed + to represent all bits of the group invocations. + + Execution must be Workgroup or Subgroup Scope. + + Predicate must be a Boolean type. + + ### Custom assembly form + + ``` {.ebnf} + scope ::= `"Workgroup"` | `"Subgroup"` + non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope + ssa-use `:` `vector` `<` 4 `x` `integer-type` `>` + ``` + + For example: + + ``` + %0 = spv.GroupNonUniformBallot "SubGroup" %predicate : vector<4xi32> + ``` + }]; + + let arguments = (ins + SPV_ScopeAttr:$execution_scope, + SPV_Bool:$predicate + ); + + let results = (outs + SPV_IntVec4:$result + ); +} + +// ----- + +#endif // SPIRV_NON_UNIFORM_OPS + diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 178db0add4ee..149c2359fdaa 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -20,11 +20,12 @@ // //===----------------------------------------------------------------------===// -// Note that for each op in this file, we use a tool to automatically generate -// certain sections in its definition: basic structure, summary, description. -// So modifications to these sections will not be respected. Modifications to -// op traits, arguments, results, and sections after the results are retained. -// Besides, ops in this file must be separated via the '// -----' marker. +// Note that for each op in this file and the included files for specific op +// categories, we use a tool to automatically generate certain sections in its +// definition: basic structure, summary, description. So modifications to these +// sections will not be respected. Modifications to op traits, arguments, +// results, and sections after the results are retained. Besides, ops must be +// separated via the '// -----' marker. #ifndef SPIRV_OPS #define SPIRV_OPS @@ -37,6 +38,7 @@ include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" +include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td" include "mlir/Dialect/SPIRV/SPIRVStructureOps.td" // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 6e115f7ba763..89abbe894e65 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -385,7 +385,9 @@ static inline bool isMergeBlock(Block &block) { // TableGen'erated canonicalizers //===----------------------------------------------------------------------===// +namespace { #include "SPIRVCanonicalization.inc" +} //===----------------------------------------------------------------------===// // Common parsers and printers @@ -1551,6 +1553,44 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformBallotOp +//===----------------------------------------------------------------------===// + +static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser, + OperationState &state) { + spirv::Scope executionScope; + OpAsmParser::OperandType operandInfo; + Type resultType; + IntegerType i1Type = parser.getBuilder().getI1Type(); + if (parseEnumAttribute(executionScope, parser, state, + kExecutionScopeAttrName) || + parser.parseOperand(operandInfo) || parser.parseColonType(resultType) || + parser.resolveOperand(operandInfo, i1Type, state.operands)) + return failure(); + + return parser.addTypeToList(resultType, state.types); +} + +static void print(spirv::GroupNonUniformBallotOp ballotOp, + OpAsmPrinter &printer) { + printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \"" + << stringifyScope(ballotOp.execution_scope()) << "\" "; + printer.printOperand(ballotOp.predicate()); + printer << " : " << ballotOp.getType(); +} + +static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { + // TODO(antiagainst): check the result integer type's signedness bit is 0. + + spirv::Scope scope = ballotOp.execution_scope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return ballotOp.emitOpError( + "execution scope must be 'Workgroup' or 'Subgroup'"); + + return success(); +} + //===----------------------------------------------------------------------===// // spv.IAdd //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir new file mode 100644 index 000000000000..282811ec6ec8 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s + +spv.module "Logical" "GLSL450" { + // CHECK-LABEL: @group_non_uniform_ballot + func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> { + // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32> + %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32> + spv.ReturnValue %0: vector<4xi32> + } +} diff --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir new file mode 100644 index 000000000000..ba5e79209e31 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/group-ops.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spv.SubgroupBallotKHR +//===----------------------------------------------------------------------===// + +func @subgroup_ballot(%predicate: i1) -> vector<4xi32> { + // CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}} : vector<4xi32> + %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32> + return %0: vector<4xi32> +} diff --git a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir new file mode 100644 index 000000000000..483a7319ab6f --- /dev/null +++ b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformBallot +//===----------------------------------------------------------------------===// + +func @subgroup_ballot(%predicate: i1) -> vector<4xi32> { + // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32> + %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32> + return %0: vector<4xi32> +} + +// ----- + +func @subgroup_ballot(%predicate: i1) -> vector<4xi32> { + // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} + %0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32> + return %0: vector<4xi32> +}