[mlir][spirv] Add spv.GroupNonUniformElect and spv.GroupNonUniformIAdd

Differential Revision: https://reviews.llvm.org/D73349
This commit is contained in:
Lei Zhang 2020-01-26 10:19:24 -05:00
parent 377e86d12e
commit ae21e37eb4
5 changed files with 385 additions and 8 deletions

View File

@ -2353,6 +2353,53 @@ def SPV_FunctionControlAttr :
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
]>;
def SPV_GO_Reduce : I32EnumAttrCase<"Reduce", 0> {
list<Availability> availability = [
Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
];
}
def SPV_GO_InclusiveScan : I32EnumAttrCase<"InclusiveScan", 1> {
list<Availability> availability = [
Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
];
}
def SPV_GO_ExclusiveScan : I32EnumAttrCase<"ExclusiveScan", 2> {
list<Availability> availability = [
Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
];
}
def SPV_GO_ClusteredReduce : I32EnumAttrCase<"ClusteredReduce", 3> {
list<Availability> availability = [
MinVersion<SPV_V_1_3>,
Capability<[SPV_C_GroupNonUniformClustered]>
];
}
def SPV_GO_PartitionedReduceNV : I32EnumAttrCase<"PartitionedReduceNV", 6> {
list<Availability> availability = [
Extension<[SPV_NV_shader_subgroup_partitioned]>,
Capability<[SPV_C_GroupNonUniformPartitionedNV]>
];
}
def SPV_GO_PartitionedInclusiveScanNV : I32EnumAttrCase<"PartitionedInclusiveScanNV", 7> {
list<Availability> availability = [
Extension<[SPV_NV_shader_subgroup_partitioned]>,
Capability<[SPV_C_GroupNonUniformPartitionedNV]>
];
}
def SPV_GO_PartitionedExclusiveScanNV : I32EnumAttrCase<"PartitionedExclusiveScanNV", 8> {
list<Availability> availability = [
Extension<[SPV_NV_shader_subgroup_partitioned]>,
Capability<[SPV_C_GroupNonUniformPartitionedNV]>
];
}
def SPV_GroupOperationAttr :
SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", [
SPV_GO_Reduce, SPV_GO_InclusiveScan, SPV_GO_ExclusiveScan,
SPV_GO_ClusteredReduce, SPV_GO_PartitionedReduceNV,
SPV_GO_PartitionedInclusiveScanNV, SPV_GO_PartitionedExclusiveScanNV
]>;
def SPV_IF_Unknown : I32EnumAttrCase<"Unknown", 0>;
def SPV_IF_Rgba32f : I32EnumAttrCase<"Rgba32f", 1> {
list<Availability> availability = [
@ -3108,7 +3155,9 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPV_OC_OpGroupNonUniformIAdd : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPV_OpcodeAttr :
@ -3155,7 +3204,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpSubgroupBallotKHR
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

View File

@ -72,5 +72,118 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
// -----
#endif // SPIRV_NON_UNIFORM_OPS
def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {
let summary = [{
Result is true only in the active invocation with the lowest id in the
group, otherwise result is false.
}];
let description = [{
Result Type must be a Boolean type.
Execution must be Workgroup or Subgroup Scope.
### Custom assembly form
```
scope ::= `"Workgroup"` | `"Subgroup"`
non-uniform-elect-op ::= ssa-id `=` `spv.GroupNonUniformElect` scope
`:` `i1`
```
For example:
```
%0 = spv.GroupNonUniformElect : i1
```
}];
let availability = [
MinVersion<SPV_V_1_3>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_GroupNonUniform]>
];
let arguments = (ins
SPV_ScopeAttr:$execution_scope
);
let results = (outs
SPV_Bool:$result
);
let builders = [
OpBuilder<[{Builder *builder, OperationState &state, spirv::Scope}]>
];
}
// -----
def SPV_GroupNonUniformIAddOp : SPV_Op<"GroupNonUniformIAdd", []> {
let summary = [{
An integer add group operation of all Value operands contributed active
by invocations in the group.
}];
let description = [{
Result Type must be a scalar or vector of integer type.
Execution must be Workgroup or Subgroup Scope.
The identity I for Operation is 0. If Operation is ClusteredReduce,
ClusterSize must be specified.
The type of Value must be the same as Result Type.
ClusterSize is the size of cluster to use. ClusterSize must be a scalar
of integer type, whose Signedness operand is 0. ClusterSize must come
from a constant instruction. ClusterSize must be at least 1, and must be
a power of 2. If ClusterSize is greater than the declared SubGroupSize,
executing this instruction results in undefined behavior.
### Custom assembly form
```
scope ::= `"Workgroup"` | `"Subgroup"`
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
non-uniform-iadd-op ::= ssa-id `=` `spv.GroupNonUniformIAdd` scope operation
ssa-use ( `cluster_size` `(` ssa_use `)` )?
`:` integer-scalar-vector-type
```
For example:
```
%four = spv.constant 4 : i32
%scalar = ... : i32
%vector = ... : vector<4xi32>
%0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %scalar : i32
%1 = spv.GroupNonUniformIAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
```
}];
let availability = [
MinVersion<SPV_V_1_3>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformClustered, SPV_C_GroupNonUniformPartitionedNV]>
];
let arguments = (ins
SPV_ScopeAttr:$execution_scope,
SPV_GroupOperationAttr:$group_operation,
SPV_ScalarOrVectorOf<SPV_Integer>:$value,
SPV_Optional<SPV_Integer>:$cluster_size
);
let results = (outs
SPV_ScalarOrVectorOf<SPV_Integer>:$result
);
}
// -----
#endif // SPIRV_NON_UNIFORM_OPS

View File

@ -32,10 +32,12 @@ using namespace mlir;
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kClusterSize[] = "cluster_size";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kGroupOperationAttrName[] = "group_operation";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kInitializerAttrName[] = "initializer";
static constexpr const char kInterfaceAttrName[] = "interface";
@ -53,9 +55,8 @@ static constexpr const char kVariableAttrName[] = "variable";
// Common utility functions
//===----------------------------------------------------------------------===//
static LogicalResult extractValueFromConstOp(Operation *op,
int32_t &indexValue) {
auto constOp = dyn_cast<spirv::ConstantOp>(op);
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
if (!constOp) {
return failure();
}
@ -64,7 +65,7 @@ static LogicalResult extractValueFromConstOp(Operation *op,
if (!integerValueAttr) {
return failure();
}
indexValue = integerValueAttr.getInt();
value = integerValueAttr.getInt();
return success();
}
@ -1888,6 +1889,122 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformElectOp
//===----------------------------------------------------------------------===//
void spirv::GroupNonUniformElectOp::build(Builder *builder,
OperationState &state,
spirv::Scope scope) {
build(builder, state, builder->getI1Type(), scope);
}
static ParseResult parseGroupNonUniformElectOp(OpAsmParser &parser,
OperationState &state) {
spirv::Scope executionScope;
Type resultType;
if (parseEnumAttribute(executionScope, parser, state,
kExecutionScopeAttrName) ||
parser.parseColonType(resultType))
return failure();
return parser.addTypeToList(resultType, state.types);
}
static void print(spirv::GroupNonUniformElectOp groupOp,
OpAsmPrinter &printer) {
printer << spirv::GroupNonUniformElectOp::getOperationName() << " \""
<< stringifyScope(groupOp.execution_scope())
<< "\" : " << groupOp.getType();
}
static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
spirv::Scope scope = groupOp.execution_scope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
return success();
}
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformIAddOp
//===----------------------------------------------------------------------===//
static ParseResult parseGroupNonUniformIAddOp(OpAsmParser &parser,
OperationState &state) {
spirv::Scope executionScope;
spirv::GroupOperation groupOperation;
OpAsmParser::OperandType valueInfo;
if (parseEnumAttribute(executionScope, parser, state,
kExecutionScopeAttrName) ||
parseEnumAttribute(groupOperation, parser, state,
kGroupOperationAttrName) ||
parser.parseOperand(valueInfo))
return failure();
Optional<OpAsmParser::OperandType> clusterSizeInfo;
if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
clusterSizeInfo = OpAsmParser::OperandType();
if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
parser.parseRParen())
return failure();
}
Type resultType;
if (parser.parseColonType(resultType))
return failure();
if (parser.resolveOperand(valueInfo, resultType, state.operands))
return failure();
if (clusterSizeInfo.hasValue()) {
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
return failure();
}
return parser.addTypeToList(resultType, state.types);
}
static void print(spirv::GroupNonUniformIAddOp groupOp, OpAsmPrinter &printer) {
printer << spirv::GroupNonUniformIAddOp::getOperationName() << " \""
<< stringifyScope(groupOp.execution_scope()) << "\" \""
<< stringifyGroupOperation(groupOp.group_operation()) << "\" "
<< groupOp.value();
if (!groupOp.cluster_size().empty())
printer << " " << kClusterSize << '(' << groupOp.cluster_size() << ')';
printer << " : " << groupOp.getType();
}
static LogicalResult verify(spirv::GroupNonUniformIAddOp groupOp) {
spirv::Scope scope = groupOp.execution_scope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
spirv::GroupOperation operation = groupOp.group_operation();
if (operation == spirv::GroupOperation::ClusteredReduce &&
groupOp.cluster_size().empty())
return groupOp.emitOpError("cluster size operand must be provided for "
"'ClusteredReduce' group operation");
if (!groupOp.cluster_size().empty()) {
Operation *sizeOp = (*groupOp.cluster_size().begin()).getDefiningOp();
int32_t clusterSize = 0;
// TODO(antiagainst): support specialization constant here.
if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
return groupOp.emitOpError(
"cluster size operand must come from a constant op");
if (!llvm::isPowerOf2_32(clusterSize))
return groupOp.emitOpError("cluster size operand must be a power of two");
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.IAdd
//===----------------------------------------------------------------------===//

View File

@ -7,4 +7,26 @@ spv.module "Logical" "GLSL450" {
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
// CHECK-LABEL: @group_non_uniform_elect
func @group_non_uniform_elect() -> i1 {
// CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1
%0 = spv.GroupNonUniformElect "Workgroup" : i1
spv.ReturnValue %0: i1
}
// CHECK-LABEL: @group_non_uniform_iadd_reduce
func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
// CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
%0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %val : i32
spv.ReturnValue %0: i32
}
// CHECK-LABEL: @group_non_uniform_iadd_clustered_reduce
func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
%four = spv.constant 4 : i32
// CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
%0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
spv.ReturnValue %0: vector<2xi32>
}
}

View File

@ -4,7 +4,7 @@
// spv.GroupNonUniformBallot
//===----------------------------------------------------------------------===//
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
return %0: vector<4xi32>
@ -12,8 +12,83 @@ func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// -----
func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32>
return %0: vector<4xi32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformElect
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @group_non_uniform_elect
func @group_non_uniform_elect() -> i1 {
// CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1
%0 = spv.GroupNonUniformElect "Workgroup" : i1
return %0: i1
}
// -----
func @group_non_uniform_elect() -> i1 {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupNonUniformElect "CrossDevice" : i1
return %0: i1
}
// -----
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformIAdd
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @group_non_uniform_iadd_reduce
func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
// CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
%0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %val : i32
return %0: i32
}
// CHECK-LABEL: @group_non_uniform_iadd_clustered_reduce
func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
%four = spv.constant 4 : i32
// CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
%0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
return %0: vector<2xi32>
}
// -----
func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupNonUniformIAdd "Device" "Reduce" %val : i32
return %0: i32
}
// -----
func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
// expected-error @+1 {{cluster size operand must be provided for 'ClusteredReduce' group operation}}
%0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val : vector<2xi32>
return %0: vector<2xi32>
}
// -----
func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>, %size: i32) -> vector<2xi32> {
// expected-error @+1 {{cluster size operand must come from a constant op}}
%0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%size) : vector<2xi32>
return %0: vector<2xi32>
}
// -----
func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
%five = spv.constant 5 : i32
// expected-error @+1 {{cluster size operand must be a power of two}}
%0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%five) : vector<2xi32>
return %0: vector<2xi32>
}