mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 03:16:07 +00:00
[spirv] Add spv.GroupNonUniformBallot
This CL also did the following cleanup: - Moved the test for spv.SubgroupBallotKHR to its own file - Wrapped generated canonicalization patterns in anonymous namespace - Updated header comments in SPVOps.td PiperOrigin-RevId: 283650091
This commit is contained in:
parent
c5ba37b6ae
commit
50b2b26e70
@ -953,7 +953,9 @@ class SPV_ScalarOrVectorOf<Type type> :
|
||||
def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
|
||||
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
|
||||
|
||||
def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>;
|
||||
class SPV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
|
||||
def SPV_IntVec4 : SPV_Vec4<SPV_Integer>;
|
||||
def SPV_I32Vec4 : SPV_Vec4<I32>;
|
||||
|
||||
// TODO(antiagainst): Use a more appropriate way to model optional operands
|
||||
class SPV_Optional<Type type> : Variadic<type>;
|
||||
@ -1109,6 +1111,7 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
||||
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
|
||||
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
|
||||
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
|
||||
def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
|
||||
def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
|
||||
|
||||
def SPV_OpcodeAttr :
|
||||
@ -1150,7 +1153,7 @@ def SPV_OpcodeAttr :
|
||||
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
|
||||
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
|
||||
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
|
||||
SPV_OC_OpSubgroupBallotKHR
|
||||
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
|
||||
]> {
|
||||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
78
mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
Normal file
78
mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
Normal file
@ -0,0 +1,78 @@
|
||||
//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to
|
||||
// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef SPIRV_NON_UNIFORM_OPS
|
||||
#define SPIRV_NON_UNIFORM_OPS
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
|
||||
let summary = [{
|
||||
Returns a bitfield value combining the Predicate value from all
|
||||
invocations in the group that execute the same dynamic instance of this
|
||||
instruction. The bit is set to one if the corresponding invocation is
|
||||
active and the Predicate for that invocation evaluated to true;
|
||||
otherwise, it is set to zero.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type must be a vector of four components of integer type scalar,
|
||||
whose Signedness operand is 0.
|
||||
|
||||
Result is a set of bitfields where the first invocation is represented
|
||||
in the lowest bit of the first vector component and the last (up to the
|
||||
size of the group) is the higher bit number of the last bitmask needed
|
||||
to represent all bits of the group invocations.
|
||||
|
||||
Execution must be Workgroup or Subgroup Scope.
|
||||
|
||||
Predicate must be a Boolean type.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
scope ::= `"Workgroup"` | `"Subgroup"`
|
||||
non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope
|
||||
ssa-use `:` `vector` `<` 4 `x` `integer-type` `>`
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.GroupNonUniformBallot "SubGroup" %predicate : vector<4xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScopeAttr:$execution_scope,
|
||||
SPV_Bool:$predicate
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_IntVec4:$result
|
||||
);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#endif // SPIRV_NON_UNIFORM_OPS
|
||||
|
@ -20,11 +20,12 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Note that for each op in this file, we use a tool to automatically generate
|
||||
// certain sections in its definition: basic structure, summary, description.
|
||||
// So modifications to these sections will not be respected. Modifications to
|
||||
// op traits, arguments, results, and sections after the results are retained.
|
||||
// Besides, ops in this file must be separated via the '// -----' marker.
|
||||
// Note that for each op in this file and the included files for specific op
|
||||
// categories, we use a tool to automatically generate certain sections in its
|
||||
// definition: basic structure, summary, description. So modifications to these
|
||||
// sections will not be respected. Modifications to op traits, arguments,
|
||||
// results, and sections after the results are retained. Besides, ops must be
|
||||
// separated via the '// -----' marker.
|
||||
|
||||
#ifndef SPIRV_OPS
|
||||
#define SPIRV_OPS
|
||||
@ -37,6 +38,7 @@ include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
|
||||
|
||||
// -----
|
||||
|
@ -385,7 +385,9 @@ static inline bool isMergeBlock(Block &block) {
|
||||
// TableGen'erated canonicalizers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
#include "SPIRVCanonicalization.inc"
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common parsers and printers
|
||||
@ -1551,6 +1553,44 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.GroupNonUniformBallotOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
spirv::Scope executionScope;
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
Type resultType;
|
||||
IntegerType i1Type = parser.getBuilder().getI1Type();
|
||||
if (parseEnumAttribute(executionScope, parser, state,
|
||||
kExecutionScopeAttrName) ||
|
||||
parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
|
||||
parser.resolveOperand(operandInfo, i1Type, state.operands))
|
||||
return failure();
|
||||
|
||||
return parser.addTypeToList(resultType, state.types);
|
||||
}
|
||||
|
||||
static void print(spirv::GroupNonUniformBallotOp ballotOp,
|
||||
OpAsmPrinter &printer) {
|
||||
printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
|
||||
<< stringifyScope(ballotOp.execution_scope()) << "\" ";
|
||||
printer.printOperand(ballotOp.predicate());
|
||||
printer << " : " << ballotOp.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
|
||||
// TODO(antiagainst): check the result integer type's signedness bit is 0.
|
||||
|
||||
spirv::Scope scope = ballotOp.execution_scope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return ballotOp.emitOpError(
|
||||
"execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.IAdd
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
10
mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
Normal file
10
mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
Normal file
@ -0,0 +1,10 @@
|
||||
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
// CHECK-LABEL: @group_non_uniform_ballot
|
||||
func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
|
||||
// CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
|
||||
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
|
||||
spv.ReturnValue %0: vector<4xi32>
|
||||
}
|
||||
}
|
11
mlir/test/Dialect/SPIRV/group-ops.mlir
Normal file
11
mlir/test/Dialect/SPIRV/group-ops.mlir
Normal file
@ -0,0 +1,11 @@
|
||||
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.SubgroupBallotKHR
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
|
||||
// CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}} : vector<4xi32>
|
||||
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
|
||||
return %0: vector<4xi32>
|
||||
}
|
19
mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
Normal file
19
mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
Normal file
@ -0,0 +1,19 @@
|
||||
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.GroupNonUniformBallot
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
|
||||
// CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
|
||||
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
|
||||
return %0: vector<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
|
||||
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
|
||||
%0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32>
|
||||
return %0: vector<4xi32>
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user