[spirv] Add spv.GroupNonUniformBallot

This CL also did the following cleanup:
- Moved the test for spv.SubgroupBallotKHR to its own file
- Wrapped generated canonicalization patterns in anonymous namespace
- Updated header comments in SPVOps.td

PiperOrigin-RevId: 283650091
This commit is contained in:
Lei Zhang 2019-12-03 16:43:40 -08:00 committed by A. Unique TensorFlower
parent c5ba37b6ae
commit 50b2b26e70
7 changed files with 170 additions and 7 deletions

View File

@ -953,7 +953,9 @@ class SPV_ScalarOrVectorOf<Type type> :
def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>;
class SPV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
def SPV_IntVec4 : SPV_Vec4<SPV_Integer>;
def SPV_I32Vec4 : SPV_Vec4<I32>;
// TODO(antiagainst): Use a more appropriate way to model optional operands
class SPV_Optional<Type type> : Variadic<type>;
@ -1109,6 +1111,7 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPV_OpcodeAttr :
@ -1150,7 +1153,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
SPV_OC_OpSubgroupBallotKHR
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
]> {
let cppNamespace = "::mlir::spirv";
}

View File

@ -0,0 +1,78 @@
//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to
// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification.
//
//===----------------------------------------------------------------------===//
#ifndef SPIRV_NON_UNIFORM_OPS
#define SPIRV_NON_UNIFORM_OPS
// -----
def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
let summary = [{
Returns a bitfield value combining the Predicate value from all
invocations in the group that execute the same dynamic instance of this
instruction. The bit is set to one if the corresponding invocation is
active and the Predicate for that invocation evaluated to true;
otherwise, it is set to zero.
}];
let description = [{
Result Type must be a vector of four components of integer type scalar,
whose Signedness operand is 0.
Result is a set of bitfields where the first invocation is represented
in the lowest bit of the first vector component and the last (up to the
size of the group) is the higher bit number of the last bitmask needed
to represent all bits of the group invocations.
Execution must be Workgroup or Subgroup Scope.
Predicate must be a Boolean type.
### Custom assembly form
``` {.ebnf}
scope ::= `"Workgroup"` | `"Subgroup"`
non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope
ssa-use `:` `vector` `<` 4 `x` `integer-type` `>`
```
For example:
```
%0 = spv.GroupNonUniformBallot "SubGroup" %predicate : vector<4xi32>
```
}];
let arguments = (ins
SPV_ScopeAttr:$execution_scope,
SPV_Bool:$predicate
);
let results = (outs
SPV_IntVec4:$result
);
}
// -----
#endif // SPIRV_NON_UNIFORM_OPS

View File

@ -20,11 +20,12 @@
//
//===----------------------------------------------------------------------===//
// Note that for each op in this file, we use a tool to automatically generate
// certain sections in its definition: basic structure, summary, description.
// So modifications to these sections will not be respected. Modifications to
// op traits, arguments, results, and sections after the results are retained.
// Besides, ops in this file must be separated via the '// -----' marker.
// Note that for each op in this file and the included files for specific op
// categories, we use a tool to automatically generate certain sections in its
// definition: basic structure, summary, description. So modifications to these
// sections will not be respected. Modifications to op traits, arguments,
// results, and sections after the results are retained. Besides, ops must be
// separated via the '// -----' marker.
#ifndef SPIRV_OPS
#define SPIRV_OPS
@ -37,6 +38,7 @@ include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
// -----

View File

@ -385,7 +385,9 @@ static inline bool isMergeBlock(Block &block) {
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
namespace {
#include "SPIRVCanonicalization.inc"
}
//===----------------------------------------------------------------------===//
// Common parsers and printers
@ -1551,6 +1553,44 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformBallotOp
//===----------------------------------------------------------------------===//
static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
OperationState &state) {
spirv::Scope executionScope;
OpAsmParser::OperandType operandInfo;
Type resultType;
IntegerType i1Type = parser.getBuilder().getI1Type();
if (parseEnumAttribute(executionScope, parser, state,
kExecutionScopeAttrName) ||
parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
parser.resolveOperand(operandInfo, i1Type, state.operands))
return failure();
return parser.addTypeToList(resultType, state.types);
}
static void print(spirv::GroupNonUniformBallotOp ballotOp,
OpAsmPrinter &printer) {
printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
<< stringifyScope(ballotOp.execution_scope()) << "\" ";
printer.printOperand(ballotOp.predicate());
printer << " : " << ballotOp.getType();
}
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
// TODO(antiagainst): check the result integer type's signedness bit is 0.
spirv::Scope scope = ballotOp.execution_scope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return ballotOp.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
return success();
}
//===----------------------------------------------------------------------===//
// spv.IAdd
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,10 @@
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
spv.module "Logical" "GLSL450" {
// CHECK-LABEL: @group_non_uniform_ballot
func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
}

View File

@ -0,0 +1,11 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// spv.SubgroupBallotKHR
//===----------------------------------------------------------------------===//
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}} : vector<4xi32>
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
return %0: vector<4xi32>
}

View File

@ -0,0 +1,19 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformBallot
//===----------------------------------------------------------------------===//
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
return %0: vector<4xi32>
}
// -----
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32>
return %0: vector<4xi32>
}