//===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "MLIRServer.h" #include "lsp/Logging.h" #include "lsp/Protocol.h" #include "mlir/IR/Operation.h" #include "mlir/Parser.h" #include "mlir/Parser/AsmParserState.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; /// Returns a language server position for the given source location. static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) { std::pair lineAndCol = mgr.getLineAndColumn(loc); lsp::Position pos; pos.line = lineAndCol.first - 1; pos.character = lineAndCol.second; return pos; } /// Returns a source location from the given language server position. static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) { return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1, pos.character); } /// Returns a language server range for the given source range. static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) { // lsp::Range is an inclusive range, SMRange is half-open. llvm::SMLoc inclusiveEnd = llvm::SMLoc::getFromPointer(range.End.getPointer() - 1); return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, inclusiveEnd)}; } /// Returns a language server location from the given source range. static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range, const lsp::URIForFile &uri) { return lsp::Location{uri, getRangeFromLoc(mgr, range)}; } /// Returns a language server location from the given MLIR file location. static Optional getLocationFromLoc(FileLineColLoc loc) { llvm::Expected sourceURI = lsp::URIForFile::fromFile(loc.getFilename()); if (!sourceURI) { lsp::Logger::error("Failed to create URI for file `{0}`: {1}", loc.getFilename(), llvm::toString(sourceURI.takeError())); return llvm::None; } lsp::Position position; position.line = loc.getLine() - 1; position.character = loc.getColumn(); return lsp::Location{*sourceURI, lsp::Range(position)}; } /// Returns a language server location from the given MLIR location, or None if /// one couldn't be created. `uri` is an optional additional filter that, when /// present, is used to filter sub locations that do not share the same uri. static Optional getLocationFromLoc(Location loc, const lsp::URIForFile *uri = nullptr) { Optional location; loc->walk([&](Location nestedLoc) { FileLineColLoc fileLoc = nestedLoc.dyn_cast(); if (!fileLoc) return WalkResult::advance(); Optional sourceLoc = getLocationFromLoc(fileLoc); if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { location = *sourceLoc; return WalkResult::interrupt(); } return WalkResult::advance(); }); return location; } /// Collect all of the locations from the given MLIR location that are not /// contained within the given URI. static void collectLocationsFromLoc(Location loc, std::vector &locations, const lsp::URIForFile &uri) { SetVector visitedLocs; loc->walk([&](Location nestedLoc) { FileLineColLoc fileLoc = nestedLoc.dyn_cast(); if (!fileLoc || !visitedLocs.insert(nestedLoc)) return WalkResult::advance(); Optional sourceLoc = getLocationFromLoc(fileLoc); if (sourceLoc && sourceLoc->uri != uri) locations.push_back(*sourceLoc); return WalkResult::advance(); }); } /// Returns true if the given range contains the given source location. Note /// that this has slightly different behavior than SMRange because it is /// inclusive of the end location. static bool contains(llvm::SMRange range, llvm::SMLoc loc) { return range.Start.getPointer() <= loc.getPointer() && loc.getPointer() <= range.End.getPointer(); } /// Returns true if the given location is contained by the definition or one of /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to /// the range within `def` that the provided `loc` overlapped with. static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc, llvm::SMRange *overlappedRange = nullptr) { // Check the main definition. if (contains(def.loc, loc)) { if (overlappedRange) *overlappedRange = def.loc; return true; } // Check the uses. auto useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) { return contains(range, loc); }); if (useIt != def.uses.end()) { if (overlappedRange) *overlappedRange = *useIt; return true; } return false; } /// Given a location pointing to a result, return the result number it refers /// to or None if it refers to all of the results. static Optional getResultNumberFromLoc(llvm::SMLoc loc) { // Skip all of the identifier characters. auto isIdentifierChar = [](char c) { return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || c == '-'; }; const char *curPtr = loc.getPointer(); while (isIdentifierChar(*curPtr)) ++curPtr; // Check to see if this location indexes into the result group, via `#`. If it // doesn't, we can't extract a sub result number. if (*curPtr != '#') return llvm::None; // Compute the sub result number from the remaining portion of the string. const char *numberStart = ++curPtr; while (llvm::isDigit(*curPtr)) ++curPtr; StringRef numberStr(numberStart, curPtr - numberStart); unsigned resultNumber = 0; return numberStr.consumeInteger(10, resultNumber) ? Optional() : resultNumber; } /// Given a source location range, return the text covered by the given range. /// If the range is invalid, returns None. static Optional getTextFromRange(llvm::SMRange range) { if (!range.isValid()) return None; const char *startPtr = range.Start.getPointer(); return StringRef(startPtr, range.End.getPointer() - startPtr); } /// Given a block, return its position in its parent region. static unsigned getBlockNumber(Block *block) { return std::distance(block->getParent()->begin(), block->getIterator()); } /// Given a block and source location, print the source name of the block to the /// given output stream. static void printDefBlockName(raw_ostream &os, Block *block, llvm::SMRange loc = {}) { // Try to extract a name from the source location. Optional text = getTextFromRange(loc); if (text && text->startswith("^")) { os << *text; return; } // Otherwise, we don't have a name so print the block number. os << ""; } static void printDefBlockName(raw_ostream &os, const AsmParserState::BlockDefinition &def) { printDefBlockName(os, def.block, def.definition.loc); } /// Convert the given MLIR diagnostic to the LSP form. static lsp::Diagnostic getLspDiagnoticFromDiag(Diagnostic &diag, const lsp::URIForFile &uri) { lsp::Diagnostic lspDiag; lspDiag.source = "mlir"; // Note: Right now all of the diagnostics are treated as parser issues, but // some are parser and some are verifier. lspDiag.category = "Parse Error"; // Try to grab a file location for this diagnostic. // TODO: For simplicity, we just grab the first one. It may be likely that we // will need a more interesting heuristic here.' Optional lspLocation = getLocationFromLoc(diag.getLocation(), &uri); if (lspLocation) lspDiag.range = lspLocation->range; // Convert the severity for the diagnostic. switch (diag.getSeverity()) { case DiagnosticSeverity::Note: llvm_unreachable("expected notes to be handled separately"); case DiagnosticSeverity::Warning: lspDiag.severity = lsp::DiagnosticSeverity::Warning; break; case DiagnosticSeverity::Error: lspDiag.severity = lsp::DiagnosticSeverity::Error; break; case DiagnosticSeverity::Remark: lspDiag.severity = lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.str(); // Attach any notes to the main diagnostic as related information. std::vector relatedDiags; for (Diagnostic ¬e : diag.getNotes()) { lsp::Location noteLoc; if (Optional loc = getLocationFromLoc(note.getLocation())) noteLoc = *loc; else noteLoc.uri = uri; relatedDiags.emplace_back(noteLoc, note.str()); } if (!relatedDiags.empty()) lspDiag.relatedInformation = std::move(relatedDiags); return lspDiag; } //===----------------------------------------------------------------------===// // MLIRDocument //===----------------------------------------------------------------------===// namespace { /// This class represents all of the information pertaining to a specific MLIR /// document. struct MLIRDocument { MLIRDocument(const lsp::URIForFile &uri, StringRef contents, DialectRegistry ®istry, std::vector &diagnostics); //===--------------------------------------------------------------------===// // Definitions and References //===--------------------------------------------------------------------===// void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, std::vector &locations); void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, std::vector &references); //===--------------------------------------------------------------------===// // Hover //===--------------------------------------------------------------------===// Optional findHover(const lsp::URIForFile &uri, const lsp::Position &hoverPos); Optional buildHoverForOperation(const AsmParserState::OperationDefinition &op); lsp::Hover buildHoverForOperationResult(llvm::SMRange hoverRange, Operation *op, unsigned resultStart, unsigned resultEnd, llvm::SMLoc posLoc); lsp::Hover buildHoverForBlock(llvm::SMRange hoverRange, const AsmParserState::BlockDefinition &block); lsp::Hover buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg, const AsmParserState::BlockDefinition &block); /// The context used to hold the state contained by the parsed document. MLIRContext context; /// The high level parser state used to find definitions and references within /// the source file. AsmParserState asmState; /// The container for the IR parsed from the input file. Block parsedIR; /// The source manager containing the contents of the input file. llvm::SourceMgr sourceMgr; }; } // namespace MLIRDocument::MLIRDocument(const lsp::URIForFile &uri, StringRef contents, DialectRegistry ®istry, std::vector &diagnostics) : context(registry) { context.allowUnregisteredDialects(); ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { diagnostics.push_back(getLspDiagnoticFromDiag(diag, uri)); }); // Try to parsed the given IR string. auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); if (!memBuffer) { lsp::Logger::error("Failed to create memory buffer for file", uri.file()); return; } sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc()); if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr, &asmState))) { // If parsing failed, clear out any of the current state. parsedIR.clear(); asmState = AsmParserState(); return; } } //===----------------------------------------------------------------------===// // MLIRDocument: Definitions and References //===----------------------------------------------------------------------===// void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, std::vector &locations) { llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, defPos); // Functor used to check if an SM definition contains the position. auto containsPosition = [&](const AsmParserState::SMDefinition &def) { if (!isDefOrUse(def, posLoc)) return false; locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); return true; }; // Check all definitions related to operations. for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { if (contains(op.loc, posLoc)) return collectLocationsFromLoc(op.op->getLoc(), locations, uri); for (const auto &result : op.resultGroups) if (containsPosition(result.second)) return collectLocationsFromLoc(op.op->getLoc(), locations, uri); } // Check all definitions related to blocks. for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { if (containsPosition(block.definition)) return; for (const AsmParserState::SMDefinition &arg : block.arguments) if (containsPosition(arg)) return; } } void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, std::vector &references) { // Functor used to append all of the definitions/uses of the given SM // definition to the reference list. auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri)); for (const llvm::SMRange &use : def.uses) references.push_back(getLocationFromLoc(sourceMgr, use, uri)); }; llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, pos); // Check all definitions related to operations. for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { if (contains(op.loc, posLoc)) { for (const auto &result : op.resultGroups) appendSMDef(result.second); return; } for (const auto &result : op.resultGroups) if (isDefOrUse(result.second, posLoc)) return appendSMDef(result.second); } // Check all definitions related to blocks. for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { if (isDefOrUse(block.definition, posLoc)) return appendSMDef(block.definition); for (const AsmParserState::SMDefinition &arg : block.arguments) if (isDefOrUse(arg, posLoc)) return appendSMDef(arg); } } //===----------------------------------------------------------------------===// // MLIRDocument: Hover //===----------------------------------------------------------------------===// Optional MLIRDocument::findHover(const lsp::URIForFile &uri, const lsp::Position &hoverPos) { llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos); llvm::SMRange hoverRange; // Check for Hovers on operations and results. for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { // Check if the position points at this operation. if (contains(op.loc, posLoc)) return buildHoverForOperation(op); // Check if the position points at a result group. for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { const auto &result = op.resultGroups[i]; if (!isDefOrUse(result.second, posLoc, &hoverRange)) continue; // Get the range of results covered by the over position. unsigned resultStart = result.first; unsigned resultEnd = (i == e - 1) ? op.op->getNumResults() : op.resultGroups[i + 1].first; return buildHoverForOperationResult(hoverRange, op.op, resultStart, resultEnd, posLoc); } } // Check to see if the hover is over a block argument. for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { if (isDefOrUse(block.definition, posLoc, &hoverRange)) return buildHoverForBlock(hoverRange, block); for (const auto &arg : llvm::enumerate(block.arguments)) { if (!isDefOrUse(arg.value(), posLoc, &hoverRange)) continue; return buildHoverForBlockArgument( hoverRange, block.block->getArgument(arg.index()), block); } } return llvm::None; } Optional MLIRDocument::buildHoverForOperation( const AsmParserState::OperationDefinition &op) { // Don't show hovers for operations with regions to avoid huge hover blocks. // TODO: Should we add support for printing an op without its regions? if (llvm::any_of(op.op->getRegions(), [](Region ®ion) { return !region.empty(); })) return llvm::None; lsp::Hover hover(getRangeFromLoc(sourceMgr, op.loc)); llvm::raw_string_ostream os(hover.contents.value); // For hovers on an operation, show the generic form. os << "```mlir\n"; op.op->print( os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); os << "\n```\n"; return hover; } lsp::Hover MLIRDocument::buildHoverForOperationResult(llvm::SMRange hoverRange, Operation *op, unsigned resultStart, unsigned resultEnd, llvm::SMLoc posLoc) { lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); llvm::raw_string_ostream os(hover.contents.value); // Add the parent operation name to the hover. os << "Operation: \"" << op->getName() << "\"\n\n"; // Check to see if the location points to a specific result within the // group. if (Optional resultNumber = getResultNumberFromLoc(posLoc)) { if ((resultStart + *resultNumber) < resultEnd) { resultStart += *resultNumber; resultEnd = resultStart + 1; } } // Add the range of results and their types to the hover info. if ((resultStart + 1) == resultEnd) { os << "Result #" << resultStart << "\n\n" << "Type: `" << op->getResult(resultStart).getType() << "`\n\n"; } else { os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" << "Types: "; llvm::interleaveComma( op->getResults().slice(resultStart, resultEnd), os, [&](Value result) { os << "`" << result.getType() << "`"; }); } return hover; } lsp::Hover MLIRDocument::buildHoverForBlock(llvm::SMRange hoverRange, const AsmParserState::BlockDefinition &block) { lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); llvm::raw_string_ostream os(hover.contents.value); // Print the given block to the hover output stream. auto printBlockToHover = [&](Block *newBlock) { if (const auto *def = asmState.getBlockDef(newBlock)) printDefBlockName(os, *def); else printDefBlockName(os, newBlock); }; // Display the parent operation, block number, predecessors, and successors. os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" << "Block #" << getBlockNumber(block.block) << "\n\n"; if (!block.block->hasNoPredecessors()) { os << "Predecessors: "; llvm::interleaveComma(block.block->getPredecessors(), os, printBlockToHover); os << "\n\n"; } if (!block.block->hasNoSuccessors()) { os << "Successors: "; llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover); os << "\n\n"; } return hover; } lsp::Hover MLIRDocument::buildHoverForBlockArgument( llvm::SMRange hoverRange, BlockArgument arg, const AsmParserState::BlockDefinition &block) { lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange)); llvm::raw_string_ostream os(hover.contents.value); // Display the parent operation, block, the argument number, and the type. os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" << "Block: "; printDefBlockName(os, block); os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" << "Type: `" << arg.getType() << "`\n\n"; return hover; } //===----------------------------------------------------------------------===// // MLIRServer::Impl //===----------------------------------------------------------------------===// struct lsp::MLIRServer::Impl { Impl(DialectRegistry ®istry) : registry(registry) {} /// The registry containing dialects that can be recognized in parsed .mlir /// files. DialectRegistry ®istry; /// The documents held by the server, mapped by their URI file name. llvm::StringMap> documents; }; //===----------------------------------------------------------------------===// // MLIRServer //===----------------------------------------------------------------------===// lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) : impl(std::make_unique(registry)) {} lsp::MLIRServer::~MLIRServer() {} void lsp::MLIRServer::addOrUpdateDocument( const URIForFile &uri, StringRef contents, std::vector &diagnostics) { impl->documents[uri.file()] = std::make_unique( uri, contents, impl->registry, diagnostics); } void lsp::MLIRServer::removeDocument(const URIForFile &uri) { impl->documents.erase(uri.file()); } void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, const Position &defPos, std::vector &locations) { auto fileIt = impl->documents.find(uri.file()); if (fileIt != impl->documents.end()) fileIt->second->getLocationsOf(uri, defPos, locations); } void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, const Position &pos, std::vector &references) { auto fileIt = impl->documents.find(uri.file()); if (fileIt != impl->documents.end()) fileIt->second->findReferencesOf(uri, pos, references); } Optional lsp::MLIRServer::findHover(const URIForFile &uri, const Position &hoverPos) { auto fileIt = impl->documents.find(uri.file()); if (fileIt != impl->documents.end()) return fileIt->second->findHover(uri, hoverPos); return llvm::None; }