mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-13 08:46:22 +00:00
[mlir][SMT] restore custom builder for forall/exists (#135470)
This reverts commit 54e70ac7650f1c22f687937d1a082e4152f97b22 which
itself fixed an [asan
leak](https://lab.llvm.org/buildbot/#/builders/55/builds/9761) from the
original upstreaming commit. The leak was due to op allocations not
being `free`ed.
~~The necessary change was to explicitly `->destroy()` the ops at the
end of the tests. I believe this is because the rewriter used in the
tests doesn't actually insert them into a module and so without an
explicit `->destroy()` no bookkeeping process is able to take care of
them.~~
The necessary change was to use `OwningOpRef` which calls `op->erase()`
in its [own
destructor](89cfae41ec/mlir/include/mlir/IR/OwningOpRef.h (L39)
).
This commit is contained in:
parent
33e5305c59
commit
c6a892e0ed
@ -448,6 +448,18 @@ class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
|
||||
VariadicRegion<SizedRegion<1>>:$patterns);
|
||||
let results = (outs BoolType:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins
|
||||
"TypeRange":$boundVarTypes,
|
||||
"function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
|
||||
CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
|
||||
CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
|
||||
"{}">:$patternBuilder,
|
||||
CArg<"uint32_t", "0">:$weight,
|
||||
CArg<"bool", "false">:$noPattern)>
|
||||
];
|
||||
let skipDefaultBuilders = true;
|
||||
|
||||
let assemblyFormat = [{
|
||||
($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
|
||||
attr-dict-with-keyword $body (`patterns` $patterns^)?
|
||||
|
@ -432,6 +432,16 @@ LogicalResult ForallOp::verifyRegions() {
|
||||
return verifyQuantifierRegions(*this);
|
||||
}
|
||||
|
||||
void ForallOp::build(
|
||||
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
|
||||
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
|
||||
std::optional<ArrayRef<StringRef>> boundVarNames,
|
||||
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
|
||||
uint32_t weight, bool noPattern) {
|
||||
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
|
||||
boundVarNames, patternBuilder, weight, noPattern);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExistsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -448,5 +458,15 @@ LogicalResult ExistsOp::verifyRegions() {
|
||||
return verifyQuantifierRegions(*this);
|
||||
}
|
||||
|
||||
void ExistsOp::build(
|
||||
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
|
||||
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
|
||||
std::optional<ArrayRef<StringRef>> boundVarNames,
|
||||
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
|
||||
uint32_t weight, bool noPattern) {
|
||||
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
|
||||
boundVarNames, patternBuilder, weight, noPattern);
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
|
||||
|
@ -1,5 +1,6 @@
|
||||
add_mlir_unittest(MLIRSMTTests
|
||||
AttributeTest.cpp
|
||||
QuantifierTest.cpp
|
||||
TypeTest.cpp
|
||||
)
|
||||
|
||||
|
187
mlir/unittests/Dialect/SMT/QuantifierTest.cpp
Normal file
187
mlir/unittests/Dialect/SMT/QuantifierTest.cpp
Normal file
@ -0,0 +1,187 @@
|
||||
//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SMT/IR/SMTOps.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace smt;
|
||||
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test custom builders of ExistsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TEST(QuantifierTest, ExistsBuilderWithPattern) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<SMTDialect>();
|
||||
Location loc(UnknownLoc::get(&context));
|
||||
|
||||
OpBuilder builder(&context);
|
||||
auto boolTy = BoolType::get(&context);
|
||||
|
||||
OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
|
||||
loc, TypeRange{boolTy, boolTy},
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return builder.create<AndOp>(loc, boundVars);
|
||||
},
|
||||
std::nullopt,
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return boundVars;
|
||||
},
|
||||
/*weight=*/2);
|
||||
|
||||
SmallVector<char, 1024> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
existsOp->print(stream);
|
||||
|
||||
ASSERT_STREQ(
|
||||
stream.str().str().c_str(),
|
||||
"%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
|
||||
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
|
||||
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
|
||||
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
|
||||
}
|
||||
|
||||
TEST(QuantifierTest, ExistsBuilderNoPattern) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<SMTDialect>();
|
||||
Location loc(UnknownLoc::get(&context));
|
||||
|
||||
OpBuilder builder(&context);
|
||||
auto boolTy = BoolType::get(&context);
|
||||
|
||||
OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
|
||||
loc, TypeRange{boolTy, boolTy},
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return builder.create<AndOp>(loc, boundVars);
|
||||
},
|
||||
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
|
||||
|
||||
SmallVector<char, 1024> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
existsOp->print(stream);
|
||||
|
||||
ASSERT_STREQ(stream.str().str().c_str(),
|
||||
"%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
|
||||
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
|
||||
"smt.yield %0 : !smt.bool\n}\n");
|
||||
}
|
||||
|
||||
TEST(QuantifierTest, ExistsBuilderDefault) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<SMTDialect>();
|
||||
Location loc(UnknownLoc::get(&context));
|
||||
|
||||
OpBuilder builder(&context);
|
||||
auto boolTy = BoolType::get(&context);
|
||||
|
||||
OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
|
||||
loc, TypeRange{boolTy, boolTy},
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return builder.create<AndOp>(loc, boundVars);
|
||||
},
|
||||
ArrayRef<StringRef>{"a", "b"});
|
||||
|
||||
SmallVector<char, 1024> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
existsOp->print(stream);
|
||||
|
||||
ASSERT_STREQ(stream.str().str().c_str(),
|
||||
"%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
|
||||
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
|
||||
"%0 : !smt.bool\n}\n");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test custom builders of ForallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TEST(QuantifierTest, ForallBuilderWithPattern) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<SMTDialect>();
|
||||
Location loc(UnknownLoc::get(&context));
|
||||
|
||||
OpBuilder builder(&context);
|
||||
auto boolTy = BoolType::get(&context);
|
||||
|
||||
OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
|
||||
loc, TypeRange{boolTy, boolTy},
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return builder.create<AndOp>(loc, boundVars);
|
||||
},
|
||||
ArrayRef<StringRef>{"a", "b"},
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return boundVars;
|
||||
},
|
||||
/*weight=*/2);
|
||||
|
||||
SmallVector<char, 1024> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
forallOp->print(stream);
|
||||
|
||||
ASSERT_STREQ(
|
||||
stream.str().str().c_str(),
|
||||
"%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
|
||||
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
|
||||
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
|
||||
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
|
||||
}
|
||||
|
||||
TEST(QuantifierTest, ForallBuilderNoPattern) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<SMTDialect>();
|
||||
Location loc(UnknownLoc::get(&context));
|
||||
|
||||
OpBuilder builder(&context);
|
||||
auto boolTy = BoolType::get(&context);
|
||||
|
||||
OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
|
||||
loc, TypeRange{boolTy, boolTy},
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return builder.create<AndOp>(loc, boundVars);
|
||||
},
|
||||
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
|
||||
|
||||
SmallVector<char, 1024> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
forallOp->print(stream);
|
||||
|
||||
ASSERT_STREQ(stream.str().str().c_str(),
|
||||
"%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
|
||||
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
|
||||
"smt.yield %0 : !smt.bool\n}\n");
|
||||
}
|
||||
|
||||
TEST(QuantifierTest, ForallBuilderDefault) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<SMTDialect>();
|
||||
Location loc(UnknownLoc::get(&context));
|
||||
|
||||
OpBuilder builder(&context);
|
||||
auto boolTy = BoolType::get(&context);
|
||||
|
||||
OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
|
||||
loc, TypeRange{boolTy, boolTy},
|
||||
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
|
||||
return builder.create<AndOp>(loc, boundVars);
|
||||
},
|
||||
std::nullopt);
|
||||
|
||||
SmallVector<char, 1024> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
forallOp->print(stream);
|
||||
|
||||
ASSERT_STREQ(stream.str().str().c_str(),
|
||||
"%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
|
||||
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
|
||||
"%0 : !smt.bool\n}\n");
|
||||
}
|
||||
|
||||
} // namespace
|
Loading…
x
Reference in New Issue
Block a user