River Riddle 52b34df9d6 [mlir][PDLL] Add an initial language server for PDLL
This commits adds a basic language server for PDLL to enable providing
language features in IDEs such as VSCode. This initial commit only
adds support for tracking definitions, references, and diagnostics, but
followup commits will build upon this to provide more significant behavior.

In addition to the server, this commit also updates mlir-vscode to support
the PDLL language and invoke the server.

Differential Revision: https://reviews.llvm.org/D121541
2022-03-19 13:28:23 -07:00

540 lines
21 KiB
C++

//===- Nodes.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::pdll::ast;
/// Copy a string reference into the context with a null terminator.
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
if (str.empty())
return str;
char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
std::copy(str.begin(), str.end(), data);
data[str.size()] = 0;
return StringRef(data, str.size());
}
//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//
const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
return *new (ctx.getAllocator().Allocate<Name>())
Name(copyStringWithNull(ctx, name), location);
}
//===----------------------------------------------------------------------===//
// Node
//===----------------------------------------------------------------------===//
namespace {
class NodeVisitor {
public:
explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
: visitFn(visitFn) {}
void visit(const Node *node) {
if (!node || !alreadyVisited.insert(node).second)
return;
visitFn(node);
TypeSwitch<const Node *>(node)
.Case<
// Statements.
const CompoundStmt, const EraseStmt, const LetStmt,
const ReplaceStmt, const ReturnStmt, const RewriteStmt,
// Expressions.
const AttributeExpr, const CallExpr, const DeclRefExpr,
const MemberAccessExpr, const OperationExpr, const TupleExpr,
const TypeExpr,
// Core Constraint Decls.
const AttrConstraintDecl, const OpConstraintDecl,
const TypeConstraintDecl, const TypeRangeConstraintDecl,
const ValueConstraintDecl, const ValueRangeConstraintDecl,
// Decls.
const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,
const Module>(
[&](auto derivedNode) { this->visitImpl(derivedNode); })
.Default([](const Node *) { llvm_unreachable("unknown AST node"); });
}
private:
void visitImpl(const CompoundStmt *stmt) {
for (const Node *child : stmt->getChildren())
visit(child);
}
void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
void visitImpl(const ReplaceStmt *stmt) {
visit(stmt->getRootOpExpr());
for (const Node *child : stmt->getReplExprs())
visit(child);
}
void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
void visitImpl(const RewriteStmt *stmt) {
visit(stmt->getRootOpExpr());
visit(stmt->getRewriteBody());
}
void visitImpl(const AttributeExpr *expr) {}
void visitImpl(const CallExpr *expr) {
visit(expr->getCallableExpr());
for (const Node *child : expr->getArguments())
visit(child);
}
void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
void visitImpl(const OperationExpr *expr) {
visit(expr->getNameDecl());
for (const Node *child : expr->getOperands())
visit(child);
for (const Node *child : expr->getResultTypes())
visit(child);
for (const Node *child : expr->getAttributes())
visit(child);
}
void visitImpl(const TupleExpr *expr) {
for (const Node *child : expr->getElements())
visit(child);
}
void visitImpl(const TypeExpr *expr) {}
void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
void visitImpl(const TypeConstraintDecl *decl) {}
void visitImpl(const TypeRangeConstraintDecl *decl) {}
void visitImpl(const ValueConstraintDecl *decl) {
visit(decl->getTypeExpr());
}
void visitImpl(const ValueRangeConstraintDecl *decl) {
visit(decl->getTypeExpr());
}
void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
void visitImpl(const OpNameDecl *decl) {}
void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
void visitImpl(const UserConstraintDecl *decl) {
for (const Node *child : decl->getInputs())
visit(child);
for (const Node *child : decl->getResults())
visit(child);
visit(decl->getBody());
}
void visitImpl(const UserRewriteDecl *decl) {
for (const Node *child : decl->getInputs())
visit(child);
for (const Node *child : decl->getResults())
visit(child);
visit(decl->getBody());
}
void visitImpl(const VariableDecl *decl) {
visit(decl->getInitExpr());
for (const ConstraintRef &child : decl->getConstraints())
visit(child.constraint);
}
void visitImpl(const Module *module) {
for (const Node *child : module->getChildren())
visit(child);
}
function_ref<void(const Node *)> visitFn;
SmallPtrSet<const Node *, 16> alreadyVisited;
};
} // namespace
void Node::walk(function_ref<void(const Node *)> walkFn) const {
return NodeVisitor(walkFn).visit(this);
}
//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//
void DeclScope::add(Decl *decl) {
const Name *name = decl->getName();
assert(name && "expected a named decl");
assert(!decls.count(name->getName()) && "decl with this name already exists");
decls.try_emplace(name->getName(), decl);
}
Decl *DeclScope::lookup(StringRef name) {
if (Decl *decl = decls.lookup(name))
return decl;
return parent ? parent->lookup(name) : nullptr;
}
//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//
CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
ArrayRef<Stmt *> children) {
unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
std::uninitialized_copy(children.begin(), children.end(),
stmt->getChildren().begin());
return stmt;
}
//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//
LetStmt *LetStmt::create(Context &ctx, SMRange loc,
VariableDecl *varDecl) {
return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
}
//===----------------------------------------------------------------------===//
// OpRewriteStmt
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// EraseStmt
EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
}
//===----------------------------------------------------------------------===//
// ReplaceStmt
ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
ArrayRef<Expr *> replExprs) {
unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
std::uninitialized_copy(replExprs.begin(), replExprs.end(),
stmt->getReplExprs().begin());
return stmt;
}
//===----------------------------------------------------------------------===//
// RewriteStmt
RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
CompoundStmt *rewriteBody) {
return new (ctx.getAllocator().Allocate<RewriteStmt>())
RewriteStmt(loc, rootOp, rewriteBody);
}
//===----------------------------------------------------------------------===//
// ReturnStmt
//===----------------------------------------------------------------------===//
ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
return new (ctx.getAllocator().Allocate<ReturnStmt>())
ReturnStmt(loc, resultExpr);
}
//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//
AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
StringRef value) {
return new (ctx.getAllocator().Allocate<AttributeExpr>())
AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
}
//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//
CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
ArrayRef<Expr *> arguments, Type resultType) {
unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
CallExpr *expr =
new (rawData) CallExpr(loc, resultType, callable, arguments.size());
std::uninitialized_copy(arguments.begin(), arguments.end(),
expr->getArguments().begin());
return expr;
}
//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
Type type) {
return new (ctx.getAllocator().Allocate<DeclRefExpr>())
DeclRefExpr(loc, decl, type);
}
//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
const Expr *parentExpr,
StringRef memberName, Type type) {
return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
}
//===----------------------------------------------------------------------===//
// OperationExpr
//===----------------------------------------------------------------------===//
OperationExpr *OperationExpr::create(
Context &ctx, SMRange loc, const OpNameDecl *name,
ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
ArrayRef<NamedAttributeDecl *> attributes) {
unsigned allocSize =
OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
operands.size() + resultTypes.size(), attributes.size());
void *rawData =
ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
Type resultType = OperationType::get(ctx, name->getName());
OperationExpr *opExpr = new (rawData)
OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
attributes.size(), name->getLoc());
std::uninitialized_copy(operands.begin(), operands.end(),
opExpr->getOperands().begin());
std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
opExpr->getResultTypes().begin());
std::uninitialized_copy(attributes.begin(), attributes.end(),
opExpr->getAttributes().begin());
return opExpr;
}
Optional<StringRef> OperationExpr::getName() const {
return getNameDecl()->getName();
}
//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//
TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
ArrayRef<Expr *> elements,
ArrayRef<StringRef> names) {
unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
auto elementTypes = llvm::map_range(
elements, [](const Expr *expr) { return expr->getType(); });
TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
TupleExpr *expr = new (rawData) TupleExpr(loc, type);
std::uninitialized_copy(elements.begin(), elements.end(),
expr->getElements().begin());
return expr;
}
//===----------------------------------------------------------------------===//
// TypeExpr
//===----------------------------------------------------------------------===//
TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
return new (ctx.getAllocator().Allocate<TypeExpr>())
TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
}
//===----------------------------------------------------------------------===//
// AttrConstraintDecl
//===----------------------------------------------------------------------===//
AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
Expr *typeExpr) {
return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
AttrConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// OpConstraintDecl
//===----------------------------------------------------------------------===//
OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
const OpNameDecl *nameDecl) {
if (!nameDecl)
nameDecl = OpNameDecl::create(ctx, SMRange());
return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
OpConstraintDecl(loc, nameDecl);
}
Optional<StringRef> OpConstraintDecl::getName() const {
return getNameDecl()->getName();
}
//===----------------------------------------------------------------------===//
// TypeConstraintDecl
//===----------------------------------------------------------------------===//
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
SMRange loc) {
return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
TypeConstraintDecl(loc);
}
//===----------------------------------------------------------------------===//
// TypeRangeConstraintDecl
//===----------------------------------------------------------------------===//
TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
SMRange loc) {
return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
TypeRangeConstraintDecl(loc);
}
//===----------------------------------------------------------------------===//
// ValueConstraintDecl
//===----------------------------------------------------------------------===//
ValueConstraintDecl *
ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
ValueConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// ValueRangeConstraintDecl
//===----------------------------------------------------------------------===//
ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
SMRange loc,
Expr *typeExpr) {
return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
ValueRangeConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// UserConstraintDecl
//===----------------------------------------------------------------------===//
UserConstraintDecl *UserConstraintDecl::createImpl(
Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
const CompoundStmt *body, Type resultType) {
unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>(
inputs.size() + results.size());
void *rawData =
ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
if (codeBlock)
codeBlock = codeBlock->copy(ctx.getAllocator());
UserConstraintDecl *decl = new (rawData) UserConstraintDecl(
name, inputs.size(), results.size(), codeBlock, body, resultType);
std::uninitialized_copy(inputs.begin(), inputs.end(),
decl->getInputs().begin());
std::uninitialized_copy(results.begin(), results.end(),
decl->getResults().begin());
return decl;
}
//===----------------------------------------------------------------------===//
// NamedAttributeDecl
//===----------------------------------------------------------------------===//
NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
Expr *value) {
return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
NamedAttributeDecl(name, value);
}
//===----------------------------------------------------------------------===//
// OpNameDecl
//===----------------------------------------------------------------------===//
OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
}
OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
}
//===----------------------------------------------------------------------===//
// PatternDecl
//===----------------------------------------------------------------------===//
PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
const Name *name, Optional<uint16_t> benefit,
bool hasBoundedRecursion,
const CompoundStmt *body) {
return new (ctx.getAllocator().Allocate<PatternDecl>())
PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
}
//===----------------------------------------------------------------------===//
// UserRewriteDecl
//===----------------------------------------------------------------------===//
UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
ArrayRef<VariableDecl *> inputs,
ArrayRef<VariableDecl *> results,
Optional<StringRef> codeBlock,
const CompoundStmt *body,
Type resultType) {
unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
inputs.size() + results.size());
void *rawData =
ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
if (codeBlock)
codeBlock = codeBlock->copy(ctx.getAllocator());
UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
name, inputs.size(), results.size(), codeBlock, body, resultType);
std::uninitialized_copy(inputs.begin(), inputs.end(),
decl->getInputs().begin());
std::uninitialized_copy(results.begin(), results.end(),
decl->getResults().begin());
return decl;
}
//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
Expr *initExpr,
ArrayRef<ConstraintRef> constraints) {
unsigned allocSize =
VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
VariableDecl *varDecl =
new (rawData) VariableDecl(name, type, initExpr, constraints.size());
std::uninitialized_copy(constraints.begin(), constraints.end(),
varDecl->getConstraints().begin());
return varDecl;
}
//===----------------------------------------------------------------------===//
// Module
//===----------------------------------------------------------------------===//
Module *Module::create(Context &ctx, SMLoc loc,
ArrayRef<Decl *> children) {
unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
Module *module = new (rawData) Module(loc, children.size());
std::uninitialized_copy(children.begin(), children.end(),
module->getChildren().begin());
return module;
}