2022-07-14 13:32:13 -07:00

411 lines
12 KiB
C++

//===- Lexer.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::pdll;
//===----------------------------------------------------------------------===//
// Token
//===----------------------------------------------------------------------===//
std::string Token::getStringValue() const {
assert(getKind() == string || getKind() == string_block ||
getKind() == code_complete_string);
// Start by dropping the quotes.
StringRef bytes = getSpelling();
if (is(string))
bytes = bytes.drop_front().drop_back();
else if (is(string_block))
bytes = bytes.drop_front(2).drop_back(2);
std::string result;
result.reserve(bytes.size());
for (unsigned i = 0, e = bytes.size(); i != e;) {
auto c = bytes[i++];
if (c != '\\') {
result.push_back(c);
continue;
}
assert(i + 1 <= e && "invalid string should be caught by lexer");
auto c1 = bytes[i++];
switch (c1) {
case '"':
case '\\':
result.push_back(c1);
continue;
case 'n':
result.push_back('\n');
continue;
case 't':
result.push_back('\t');
continue;
default:
break;
}
assert(i + 1 <= e && "invalid string should be caught by lexer");
auto c2 = bytes[i++];
assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
}
return result;
}
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
/// Create a lexer positioned at the start of the main file of `mgr`.
///
/// If `codeCompleteContext` is non-null, the lexer remembers its completion
/// location so that a code completion token can be produced there. If
/// `diagEngine` has no handler installed, a default handler that prints
/// diagnostics (and their notes) through the SourceMgr is installed. That
/// handler is removed again by the destructor.
Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
             CodeCompleteContext *codeCompleteContext)
    : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
      codeCompletionLocation(
          codeCompleteContext
              ? codeCompleteContext->getCodeCompleteLoc().getPointer()
              : nullptr) {
  // Begin lexing at the first character of the main file.
  curBufferID = mgr.getMainFileID();
  curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
  curPtr = curBuffer.begin();

  // Respect any handler the caller already installed.
  if (diagEngine.getHandlerFn())
    return;

  diagEngine.setHandlerFn([this](const ast::Diagnostic &diag) {
    srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
                        diag.getMessage());
    for (const ast::Diagnostic &note : diag.getNotes())
      srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
                          note.getMessage());
  });
  // Remember that the handler is ours so the destructor can remove it.
  addedHandlerToDiagEngine = true;
}
/// Remove the default diagnostic handler if this lexer installed it.
///
/// That handler captures this lexer, so it must not outlive the lexer.
Lexer::~Lexer() {
  if (!addedHandlerToDiagEngine)
    return;
  diagEngine.setHandlerFn(nullptr);
}
/// Switch lexing to the file `filename`, which is included at `includeLoc`.
///
/// Returns failure if the SourceMgr could not add the include file.
/// Otherwise the new buffer becomes current and lexing restarts at its
/// beginning.
LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
  // Out-parameter required by AddIncludeFile; its value is not used here.
  std::string resolvedPath;
  int newBufferID =
      srcMgr.AddIncludeFile(filename.str(), includeLoc.End, resolvedPath);
  if (newBufferID == 0)
    return failure();

  curBufferID = newBufferID;
  curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
  curPtr = curBuffer.begin();
  return success();
}
/// Emit `msg` as an error covering `loc`, and return an error token that
/// starts at the beginning of that range.
Token Lexer::emitError(SMRange loc, const Twine &msg) {
  diagEngine.emitError(loc, msg);
  const char *errorStart = loc.Start.getPointer();
  return formToken(Token::error, errorStart);
}
/// Emit `msg` as an error covering `loc`, with an attached note `note` at
/// `noteLoc`. Returns an error token starting at the beginning of `loc`.
Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
                              const Twine &note) {
  // Keep this as one full-expression so the diagnostic is finalized before
  // the token is formed, as before.
  diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
  const char *errorStart = loc.Start.getPointer();
  return formToken(Token::error, errorStart);
}
/// Emit `msg` as an error pointing at the single character at `loc`.
Token Lexer::emitError(const char *loc, const Twine &msg) {
  SMLoc begin = SMLoc::getFromPointer(loc);
  SMLoc end = SMLoc::getFromPointer(loc + 1);
  return emitError(SMRange(begin, end), msg);
}
/// Consume and return the next character of the current buffer.
///
/// Returns EOF (without consuming) at the terminating nul of the buffer, and
/// returns 0 for a nul byte that occurs elsewhere. A `\n` or `\r` is returned
/// as `\n`. A `\r\n` or `\n\r` pair is consumed as a single newline. Every
/// other character is returned as an unsigned char value.
int Lexer::getNextChar() {
  char curChar = *curPtr++;

  if (curChar == 0) {
    // Distinguish the end of the buffer from a stray nul inside the file.
    if (curPtr - 1 == curBuffer.end()) {
      --curPtr;
      return EOF;
    }
    return 0;
  }

  if (curChar == '\n' || curChar == '\r') {
    // Fold a mixed two-character line break into one newline.
    char next = *curPtr;
    if ((next == '\n' || next == '\r') && next != curChar)
      ++curPtr;
    return '\n';
  }

  return static_cast<unsigned char>(curChar);
}
/// Lex and return the next token.
///
/// Whitespace (including stray nul bytes) and `//` comments are skipped. If
/// the code completion location is reached at the start of a token, a
/// `code_complete` token is returned. At the end of an included buffer an
/// `eof` token is returned and lexing resumes in the parent buffer at the
/// include location.
Token Lexer::lexToken() {
  while (true) {
    const char *tokStart = curPtr;

    // Check to see if this token is at the code completion location.
    if (tokStart == codeCompletionLocation)
      return formToken(Token::code_complete, tokStart);

    // This consumes at least one character, except at the end of the buffer
    // where getNextChar leaves curPtr in place and returns EOF.
    int curChar = getNextChar();
    switch (curChar) {
    default:
      // Handle identifiers: [a-zA-Z_]
      if (isalpha(curChar) || curChar == '_')
        return lexIdentifier(tokStart);

      // Unknown character, emit an error.
      return emitError(tokStart, "unexpected character");
    case EOF: {
      // Return EOF denoting the end of lexing.
      Token eof = formToken(Token::eof, tokStart);

      // Check to see if we are in an included file.
      SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
      if (parentIncludeLoc.isValid()) {
        curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
        curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
        curPtr = parentIncludeLoc.getPointer();
      }

      return eof;
    }

    // Lex punctuation.
    case '-':
      if (*curPtr == '>') {
        ++curPtr;
        return formToken(Token::arrow, tokStart);
      }
      return emitError(tokStart, "unexpected character");
    case ':':
      return formToken(Token::colon, tokStart);
    case ',':
      return formToken(Token::comma, tokStart);
    case '.':
      return formToken(Token::dot, tokStart);
    case '=':
      if (*curPtr == '>') {
        ++curPtr;
        return formToken(Token::equal_arrow, tokStart);
      }
      return formToken(Token::equal, tokStart);
    case ';':
      return formToken(Token::semicolon, tokStart);
    case '[':
      if (*curPtr == '{') {
        ++curPtr;
        return lexString(tokStart, /*isStringBlock=*/true);
      }
      return formToken(Token::l_square, tokStart);
    case ']':
      return formToken(Token::r_square, tokStart);

    case '<':
      return formToken(Token::less, tokStart);
    case '>':
      return formToken(Token::greater, tokStart);
    case '{':
      return formToken(Token::l_brace, tokStart);
    case '}':
      return formToken(Token::r_brace, tokStart);
    case '(':
      return formToken(Token::l_paren, tokStart);
    case ')':
      return formToken(Token::r_paren, tokStart);
    case '/':
      if (*curPtr == '/') {
        lexComment();
        continue;
      }
      return emitError(tokStart, "unexpected character");

    // Ignore whitespace characters. Loop instead of recursing so that long
    // runs of whitespace cannot grow the call stack without bound.
    case 0:
    case ' ':
    case '\t':
    case '\n':
      continue;

    case '#':
      return lexDirective(tokStart);
    case '"':
      return lexString(tokStart, /*isStringBlock=*/false);

    case '0':
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      return lexNumber(tokStart);
    }
  }
}
/// Skip the remainder of a `//` line comment.
///
/// On entry, curPtr points at the second `/` (lexToken has already consumed
/// the first). The comment ends after the next `\n` or `\r`. If the end of the
/// buffer is reached first, curPtr is left on the terminating nul so that the
/// next read reports EOF. Stray nul bytes inside the comment are skipped.
void Lexer::lexComment() {
  assert(*curPtr == '/');
  ++curPtr;

  for (;;) {
    const char c = *curPtr++;
    if (c == '\n' || c == '\r')
      return;
    if (c == 0 && curPtr - 1 == curBuffer.end()) {
      --curPtr;
      return;
    }
  }
}
/// Lex a `#`-prefixed directive token.
///
/// `tokStart` points at the `#`, which has already been consumed. The
/// directive spelling extends over the following run of `[0-9a-zA-Z_]`.
Token Lexer::lexDirective(const char *tokStart) {
  // Match the rest with an identifier regex: [0-9a-zA-Z_]*
  // The cast matters: passing a negative `char` (e.g. a byte of a UTF-8
  // sequence) to `isalnum` is undefined behavior.
  while (isalnum(static_cast<unsigned char>(*curPtr)) || *curPtr == '_')
    ++curPtr;

  StringRef str(tokStart, curPtr - tokStart);
  return Token(Token::directive, str);
}
/// Lex an identifier or keyword.
///
/// `tokStart` points at the first character, `[a-zA-Z_]`, which has already
/// been consumed. The spelling extends over the following run of
/// `[0-9a-zA-Z_]`. If it matches a keyword (or is `_`), the corresponding
/// keyword token is returned; otherwise an `identifier` token is returned.
Token Lexer::lexIdentifier(const char *tokStart) {
  // Match the rest of the identifier regex: [0-9a-zA-Z_]*
  // The cast matters: passing a negative `char` (e.g. a byte of a UTF-8
  // sequence) to `isalnum` is undefined behavior.
  while (isalnum(static_cast<unsigned char>(*curPtr)) || *curPtr == '_')
    ++curPtr;

  // Check to see if this identifier is a keyword.
  StringRef str(tokStart, curPtr - tokStart);
  Token::Kind kind = StringSwitch<Token::Kind>(str)
                         .Case("attr", Token::kw_attr)
                         .Case("Attr", Token::kw_Attr)
                         .Case("erase", Token::kw_erase)
                         .Case("let", Token::kw_let)
                         .Case("Constraint", Token::kw_Constraint)
                         .Case("op", Token::kw_op)
                         .Case("Op", Token::kw_Op)
                         .Case("OpName", Token::kw_OpName)
                         .Case("Pattern", Token::kw_Pattern)
                         .Case("replace", Token::kw_replace)
                         .Case("return", Token::kw_return)
                         .Case("rewrite", Token::kw_rewrite)
                         .Case("Rewrite", Token::kw_Rewrite)
                         .Case("type", Token::kw_type)
                         .Case("Type", Token::kw_Type)
                         .Case("TypeRange", Token::kw_TypeRange)
                         .Case("Value", Token::kw_Value)
                         .Case("ValueRange", Token::kw_ValueRange)
                         .Case("with", Token::kw_with)
                         .Case("_", Token::underscore)
                         .Default(Token::identifier);
  return Token(kind, str);
}
/// Lex a decimal integer literal.
///
/// `tokStart` points at the first digit, which has already been consumed. The
/// token extends over the following run of `[0-9]`.
Token Lexer::lexNumber(const char *tokStart) {
  // Casts avoid undefined behavior in `isdigit` for negative `char` values.
  assert(isdigit(static_cast<unsigned char>(curPtr[-1])));

  // Handle the normal decimal case.
  while (isdigit(static_cast<unsigned char>(*curPtr)))
    ++curPtr;

  return formToken(Token::integer, tokStart);
}
/// Lex a string literal whose opening delimiter has already been consumed.
///
/// `tokStart` points at the opening delimiter: `"` for a plain string, or `[{`
/// when `isStringBlock` is true. A plain string ends at `"` and may not
/// contain a line break. A string block ends at `}]` and may span lines.
/// Accepted escapes are `\"`, `\\`, `\n`, `\t`, and a backslash followed by
/// two hex digits. If the code completion location is reached inside the
/// string, a `code_complete_string` token covering the text lexed so far
/// (without the opening delimiter) is returned.
Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
  while (true) {
    // Check to see if there is a code completion location within the string. In
    // these cases we generate a completion location and place the currently
    // lexed string within the token (without the quotes). This allows for the
    // parser to use the partially lexed string when computing the completion
    // results.
    if (curPtr == codeCompletionLocation) {
      return formToken(Token::code_complete_string,
                       tokStart + (isStringBlock ? 2 : 1));
    }

    switch (*curPtr++) {
    case '"':
      // If this is a string block, we only end the string when we encounter a
      // `}]`.
      if (!isStringBlock)
        return formToken(Token::string, tokStart);
      continue;
    case '}':
      // If this is a string block, we only end the string when we encounter a
      // `}]`.
      if (!isStringBlock || *curPtr != ']')
        continue;
      ++curPtr;
      return formToken(Token::string_block, tokStart);
    case 0: {
      // If this is a random nul character in the middle of a string, just
      // include it. If it is the end of file, then it is an error.
      if (curPtr - 1 != curBuffer.end())
        continue;
      --curPtr;

      StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
      return emitError(curPtr - 1,
                       "expected '" + expectedEndStr + "' in string literal");
    }

    case '\n':
    case '\r':
    case '\v':
    case '\f':
      // String blocks allow multiple lines. `\r` is included so a plain
      // string cannot run past a carriage-return line break. getNextChar and
      // lexComment also treat `\r` as ending a line.
      if (!isStringBlock)
        return emitError(curPtr - 1, "expected '\"' in string literal");
      continue;

    case '\\':
      // Handle explicitly a few escapes.
      if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
          *curPtr == 't') {
        ++curPtr;
      } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
        // Support \xx for two hex digits.
        curPtr += 2;
      } else {
        return emitError(curPtr - 1, "unknown escape in string literal");
      }
      continue;

    default:
      continue;
    }
  }
}