//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestDialect.h" #include "TestTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; void mlir::registerTestDialect(DialectRegistry ®istry) { registry.insert(); } //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// namespace { // Test support for interacting with the AsmPrinter. struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { StringAttr strAttr = attr.dyn_cast(); if (!strAttr) return failure(); // Check the contents of the string attribute to see what the test alias // should be named. Optional aliasName = StringSwitch>(strAttr.getValue()) .Case("alias_test:dot_in_name", StringRef("test.alias")) .Case("alias_test:trailing_digit", StringRef("test_alias0")) .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) .Case("alias_test:sanitize_conflict_a", StringRef("test_alias_conflict0")) .Case("alias_test:sanitize_conflict_b", StringRef("test_alias_conflict0_")) .Default(llvm::None); if (!aliasName) return failure(); os << *aliasName; return success(); } void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const final { if (auto asmOp = dyn_cast(op)) setNameFn(asmOp, "result"); } void getAsmBlockArgumentNames(Block *block, OpAsmSetValueNameFn setNameFn) const final { auto op = block->getParentOp(); auto arrayAttr = op->getAttrOfType("arg_names"); if (!arrayAttr) return; auto args = block->getArguments(); auto e = std::min(arrayAttr.size(), args.size()); for (unsigned i = 0; i < e; ++i) { if (auto strAttr = arrayAttr[i].dyn_cast()) setNameFn(args[i], strAttr.getValue()); } } }; struct TestDialectFoldInterface : public DialectFoldInterface { using DialectFoldInterface::DialectFoldInterface; /// Registered hook to check if the given region, which is attached to an /// operation that is *not* isolated from above, should be used when /// materializing constants. bool shouldMaterializeInto(Region *region) const final { // If this is a one region operation, then insert into it. return isa(region->getParentOp()); } }; /// This class defines the interface for handling inlining with standard /// operations. struct TestInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final { // Don't allow inlining calls that are marked `noinline`. return !call->hasAttr("noinline"); } bool isLegalToInline(Region *, Region *, bool, BlockAndValueMapping &) const final { // Inlining into test dialect regions is legal. return true; } bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } bool shouldAnalyzeRecursively(Operation *op) const final { // Analyze recursively if this is not a functional region operation, it // froms a separate functional scope. return !isa(op); } //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, ArrayRef valuesToRepl) const final { // Only handle "test.return" here. auto returnOp = dyn_cast(op); if (!returnOp) return; // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } /// Attempt to materialize a conversion for a type mismatch between a call /// from this dialect, and a callable region. This method should generate an /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { // Only allow conversion for i16/i32 types. if (!(resultType.isSignlessInteger(16) || resultType.isSignlessInteger(32)) || !(input.getType().isSignlessInteger(16) || input.getType().isSignlessInteger(32))) return nullptr; return builder.create(conversionLoc, resultType, input); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// void TestDialect::initialize() { addOperations< #define GET_OP_LIST #include "TestOps.cpp.inc" >(); addInterfaces(); addTypes(); allowUnknownOperations(); } static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); auto genType = generatedTypeParser(ctxt, parser, typeTag); if (genType != Type()) return genType; if (typeTag == "test_type") return TestType::get(parser.getBuilder().getContext()); if (typeTag != "test_rec") return Type(); StringRef name; if (parser.parseLess() || parser.parseKeyword(&name)) return Type(); auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); // If this type already has been parsed above in the stack, expect just the // name. if (stack.contains(rec)) { if (failed(parser.parseGreater())) return Type(); return rec; } // Otherwise, parse the body and update the type. if (failed(parser.parseComma())) return Type(); stack.insert(rec); Type subtype = parseTestType(ctxt, parser, stack); stack.pop_back(); if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) return Type(); return rec; } Type TestDialect::parseType(DialectAsmParser &parser) const { llvm::SetVector stack; return parseTestType(getContext(), parser, stack); } static void printTestType(Type type, DialectAsmPrinter &printer, llvm::SetVector &stack) { if (succeeded(generatedTypePrinter(type, printer))) return; if (type.isa()) { printer << "test_type"; return; } auto rec = type.cast(); printer << "test_rec<" << rec.getName(); if (!stack.contains(rec)) { printer << ", "; stack.insert(rec); printTestType(rec.getBody(), printer, stack); stack.pop_back(); } printer << ">"; } void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { llvm::SetVector stack; printTestType(type, printer, stack); } LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, unsigned regionIndex, unsigned argIndex, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } LogicalResult TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, unsigned resultIndex, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } //===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// Optional TestBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); return targetOperandsMutable(); } //===----------------------------------------------------------------------===// // TestFoldToCallOp //===----------------------------------------------------------------------===// namespace { struct FoldToCallOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FoldToCallOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, TypeRange(), op.calleeAttr(), ValueRange()); return success(); } }; } // end anonymous namespace void FoldToCallOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // Test Format* operations //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Parsing static ParseResult parseCustomDirectiveOperands( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, SmallVectorImpl &varOperands) { if (parser.parseOperand(operand)) return failure(); if (succeeded(parser.parseOptionalComma())) { optOperand.emplace(); if (parser.parseOperand(*optOperand)) return failure(); } if (parser.parseArrow() || parser.parseLParen() || parser.parseOperandList(varOperands) || parser.parseRParen()) return failure(); return success(); } static ParseResult parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parser.parseColon()) return failure(); if (parser.parseType(operandType)) return failure(); if (succeeded(parser.parseOptionalComma())) { if (parser.parseType(optOperandType)) return failure(); } if (parser.parseArrow() || parser.parseLParen() || parser.parseTypeList(varOperandTypes) || parser.parseRParen()) return failure(); return success(); } static ParseResult parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, Type optOperandType, const SmallVectorImpl &varOperandTypes) { if (parser.parseKeyword("type_refs_capture")) return failure(); Type operandType2, optOperandType2; SmallVector varOperandTypes2; if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, varOperandTypes2)) return failure(); if (operandType != operandType2 || optOperandType != optOperandType2 || varOperandTypes != varOperandTypes2) return failure(); return success(); } static ParseResult parseCustomDirectiveOperandsAndTypes( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, SmallVectorImpl &varOperands, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || parseCustomDirectiveResults(parser, operandType, optOperandType, varOperandTypes)) return failure(); return success(); } static ParseResult parseCustomDirectiveRegions( OpAsmParser &parser, Region ®ion, SmallVectorImpl> &varRegions) { if (parser.parseRegion(region)) return failure(); if (failed(parser.parseOptionalComma())) return success(); std::unique_ptr varRegion = std::make_unique(); if (parser.parseRegion(*varRegion)) return failure(); varRegions.emplace_back(std::move(varRegion)); return success(); } static ParseResult parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, SmallVectorImpl &varSuccessors) { if (parser.parseSuccessor(successor)) return failure(); if (failed(parser.parseOptionalComma())) return success(); Block *varSuccessor; if (parser.parseSuccessor(varSuccessor)) return failure(); varSuccessors.append(2, varSuccessor); return success(); } static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, IntegerAttr &attr, IntegerAttr &optAttr) { if (parser.parseAttribute(attr)) return failure(); if (succeeded(parser.parseOptionalComma())) { if (parser.parseAttribute(optAttr)) return failure(); } return success(); } static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, NamedAttrList &attrs) { return parser.parseOptionalAttrDict(attrs); } //===----------------------------------------------------------------------===// // Printing static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, Value operand, Value optOperand, OperandRange varOperands) { printer << operand; if (optOperand) printer << ", " << optOperand; printer << " -> (" << varOperands << ")"; } static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " : " << operandType; if (optOperandType) printer << ", " << optOperandType; printer << " -> (" << varOperandTypes << ")"; } static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, Operation *op, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " type_refs_capture "; printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } static void printCustomDirectiveOperandsAndTypes( OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, OperandRange varOperands, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, Region ®ion, MutableArrayRef varRegions) { printer.printRegion(region); if (!varRegions.empty()) { printer << ", "; for (Region ®ion : varRegions) printer.printRegion(region); } } static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, Block *successor, SuccessorRange varSuccessors) { printer << successor; if (!varSuccessors.empty()) printer << ", " << varSuccessors.front(); } static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, Attribute attribute, Attribute optAttribute) { printer << attribute; if (optAttribute) printer << ", " << optAttribute; } static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, MutableDictionaryAttr attrs) { printer.printOptionalAttrDict(attrs.getAttrs()); } //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType argInfo; Type argType = parser.getBuilder().getIndexType(); // Parse the input operand. if (parser.parseOperand(argInfo) || parser.resolveOperand(argInfo, argType, result.operands)) return failure(); // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, argInfo, argType, /*enableNameShadowing=*/true); } static void print(OpAsmPrinter &p, IsolatedRegionOp op) { p << "test.isolated_region "; p.printOperand(op.getOperand()); p.shadowRegionArgs(op.region(), op.getOperand()); p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// // Test SSACFGRegionOp //===----------------------------------------------------------------------===// RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { return RegionKind::SSACFG; } //===----------------------------------------------------------------------===// // Test GraphRegionOp //===----------------------------------------------------------------------===// static ParseResult parseGraphRegionOp(OpAsmParser &parser, OperationState &result) { // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); } static void print(OpAsmPrinter &p, GraphRegionOp op) { p << "test.graph_region "; p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } RegionKind GraphRegionOp::getRegionKind(unsigned index) { return RegionKind::Graph; } //===----------------------------------------------------------------------===// // Test AffineScopeOp //===----------------------------------------------------------------------===// static ParseResult parseAffineScopeOp(OpAsmParser &parser, OperationState &result) { // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); } static void print(OpAsmPrinter &p, AffineScopeOp op) { p << "test.affine_scope "; p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// // Test parser. //===----------------------------------------------------------------------===// static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); return success(); } static void print(OpAsmPrinter &p, WrappedKeywordOp op) { p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); } //===----------------------------------------------------------------------===// // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. static ParseResult parseWrappingRegionOp(OpAsmParser &parser, OperationState &result) { if (parser.parseKeyword("wraps")) return failure(); // Parse the wrapped op in a region Region &body = *result.addRegion(); body.push_back(new Block); Block &block = body.back(); Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); if (!wrapped_op) return failure(); // Create a return terminator in the inner region, pass as operand to the // terminator the returned values from the wrapped operation. SmallVector return_operands(wrapped_op->getResults()); OpBuilder builder(parser.getBuilder().getContext()); builder.setInsertionPointToEnd(&block); builder.create(wrapped_op->getLoc(), return_operands); // Get the results type for the wrapping op from the terminator operands. Operation &return_op = body.back().back(); result.types.append(return_op.operand_type_begin(), return_op.operand_type_end()); // Use the location of the wrapped op for the "test.wrapping_region" op. result.location = wrapped_op->getLoc(); return success(); } static void print(OpAsmPrinter &p, WrappingRegionOp op) { p << op.getOperationName() << " wraps "; p.printGenericOp(&op.region().front().front()); } //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { SmallVector ivsInfo; // Parse list of region arguments without a delimiter. if (parser.parseRegionArgumentList(ivsInfo)) return failure(); // Parse the body region. Region *body = result.addRegion(); auto &builder = parser.getBuilder(); SmallVector argTypes(ivsInfo.size(), builder.getIndexType()); return parser.parseRegion(*body, ivsInfo, argTypes); } //===----------------------------------------------------------------------===// // Test removing op with inner ops. //===----------------------------------------------------------------------===// namespace { struct TestRemoveOpWithInnerOps : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TestOpWithRegionPattern op, PatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } }; } // end anonymous namespace void TestOpWithRegionPattern::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { return operand(); } OpFoldResult TestOpConstant::fold(ArrayRef operands) { return getValue(); } LogicalResult TestOpWithVariadicResultsAndFolder::fold( ArrayRef operands, SmallVectorImpl &results) { for (Value input : this->operands()) { results.push_back(input); } return success(); } OpFoldResult TestOpInPlaceFold::fold(ArrayRef operands) { assert(operands.size() == 1); if (operands.front()) { setAttr("attr", operands.front()); return getResult(); } return {}; } LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() != operands[1].getType()) { return emitOptionalError(location, "operand type mismatch ", operands[0].getType(), " vs ", operands[1].getType()); } inferredReturnTypes.assign({operands[0].getType()}); return success(); } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Create return type consisting of the last element of the first operand. auto operandType = *operands.getTypes().begin(); auto sval = operandType.dyn_cast(); if (!sval) { return emitOptionalError(location, "only shaped type operands allowed"); } int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; auto type = IntegerType::get(17, context); inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); return success(); } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( OpBuilder &builder, llvm::SmallVectorImpl &shapes) { shapes = SmallVector{ builder.createOrFold(getLoc(), getOperand(0), 0)}; return success(); } //===----------------------------------------------------------------------===// // Test SideEffect interfaces //===----------------------------------------------------------------------===// namespace { /// A test resource for side effects. struct TestResource : public SideEffects::Resource::Base { StringRef getName() final { return ""; } }; } // end anonymous namespace void SideEffectOp::getEffects( SmallVectorImpl &effects) { // Check for an effects attribute on the op instance. ArrayAttr effectsAttr = getAttrOfType("effects"); if (!effectsAttr) return; // If there is one, it is an array of dictionary attributes that hold // information on the effects of this operation. for (Attribute element : effectsAttr) { DictionaryAttr effectElement = element.cast(); // Get the specific memory effect. MemoryEffects::Effect *effect = StringSwitch( effectElement.get("effect").cast().getValue()) .Case("allocate", MemoryEffects::Allocate::get()) .Case("free", MemoryEffects::Free::get()) .Case("read", MemoryEffects::Read::get()) .Case("write", MemoryEffects::Write::get()); // Check for a result to affect. Value value; if (effectElement.get("on_result")) value = getResult(); // Check for a non-default resource to use. SideEffects::Resource *resource = SideEffects::DefaultResource::get(); if (effectElement.get("test_resource")) resource = TestResource::get(); effects.emplace_back(effect, value, resource); } } //===----------------------------------------------------------------------===// // StringAttrPrettyNameOp //===----------------------------------------------------------------------===// // This op has fancy handling of its SSA result name. static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, OperationState &result) { // Add the result types. for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) result.addTypes(parser.getBuilder().getIntegerType(32)); if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return failure(); // If the attribute dictionary contains no 'names' attribute, infer it from // the SSA name (if specified). bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { return attr.first == "names"; }); // If there was no name specified, check to see if there was a useful name // specified in the asm file. if (hadNames || parser.getNumResults() == 0) return success(); SmallVector names; auto *context = result.getContext(); for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { auto resultName = parser.getResultName(i); StringRef nameStr; if (!resultName.first.empty() && !isdigit(resultName.first[0])) nameStr = resultName.first; names.push_back(nameStr); } auto namesAttr = parser.getBuilder().getStrArrayAttr(names); result.attributes.push_back({Identifier::get("names", context), namesAttr}); return success(); } static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { p << "test.string_attr_pretty_name"; // Note that we only need to print the "name" attribute if the asmprinter // result name disagrees with it. This can happen in strange cases, e.g. // when there are conflicts. bool namesDisagree = op.names().size() != op.getNumResults(); SmallString<32> resultNameStr; for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { resultNameStr.clear(); llvm::raw_svector_ostream tmpStream(resultNameStr); p.printOperand(op.getResult(i), tmpStream); auto expectedName = op.names()[i].dyn_cast(); if (!expectedName || tmpStream.str().drop_front() != expectedName.getValue()) { namesDisagree = true; } } if (namesDisagree) p.printOptionalAttrDictWithKeyword(op.getAttrs()); else p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); } // We set the SSA name in the asm syntax to the contents of the name // attribute. void StringAttrPrettyNameOp::getAsmResultNames( function_ref setNameFn) { auto value = names(); for (size_t i = 0, e = value.size(); i != e; ++i) if (auto str = value[i].dyn_cast()) if (!str.getValue().empty()) setNameFn(getResult(i), str.getValue()); } //===----------------------------------------------------------------------===// // RegionIfOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RegionIfOp op) { p << RegionIfOp::getOperationName() << " "; p.printOperands(op.getOperands()); p << ": " << op.getOperandTypes(); p.printArrowTypeList(op.getResultTypes()); p << " then"; p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " else"; p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " join"; p.printRegion(op.joinRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); } static ParseResult parseRegionIfOp(OpAsmParser &parser, OperationState &result) { SmallVector operandInfos; SmallVector operandTypes; result.regions.reserve(3); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); Region *joinRegion = result.addRegion(); // Parse operand, type and arrow type lists. if (parser.parseOperandList(operandInfos) || parser.parseColonTypeList(operandTypes) || parser.parseArrowTypeList(result.types)) return failure(); // Parse all attached regions. if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) return failure(); return parser.resolveOperands(operandInfos, operandTypes, parser.getCurrentLocation(), result.operands); } OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { assert(index < 2 && "invalid region index"); return getOperands(); } void RegionIfOp::getSuccessorRegions( Optional index, ArrayRef operands, SmallVectorImpl ®ions) { // We always branch to the join region. if (index.hasValue()) { if (index.getValue() < 2) regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); else regions.push_back(RegionSuccessor(getResults())); return; } // The then and else regions are the entry regions of this op. regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); } #include "TestOpEnums.cpp.inc" #include "TestOpStructs.cpp.inc" #include "TestTypeInterfaces.cpp.inc" #define GET_OP_CLASSES #include "TestOps.cpp.inc"