diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td index af73955caee5..1872c00b74f1 100644 --- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td @@ -448,6 +448,18 @@ class QuantifierOp : SMTOp>:$patterns); let results = (outs BoolType:$result); + let builders = [ + OpBuilder<(ins + "TypeRange":$boundVarTypes, + "function_ref":$bodyBuilder, + CArg<"std::optional>", "std::nullopt">:$boundVarNames, + CArg<"function_ref", + "{}">:$patternBuilder, + CArg<"uint32_t", "0">:$weight, + CArg<"bool", "false">:$noPattern)> + ]; + let skipDefaultBuilders = true; + let assemblyFormat = [{ ($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)? attr-dict-with-keyword $body (`patterns` $patterns^)? diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp index 604dd26da198..8977a3abc125 100644 --- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp +++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp @@ -432,6 +432,16 @@ LogicalResult ForallOp::verifyRegions() { return verifyQuantifierRegions(*this); } +void ForallOp::build( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + buildQuantifier(odsBuilder, odsState, boundVarTypes, bodyBuilder, + boundVarNames, patternBuilder, weight, noPattern); +} + //===----------------------------------------------------------------------===// // ExistsOp //===----------------------------------------------------------------------===// @@ -448,5 +458,15 @@ LogicalResult ExistsOp::verifyRegions() { return verifyQuantifierRegions(*this); } +void ExistsOp::build( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + buildQuantifier(odsBuilder, odsState, boundVarTypes, bodyBuilder, + boundVarNames, patternBuilder, weight, noPattern); +} + #define GET_OP_CLASSES #include "mlir/Dialect/SMT/IR/SMT.cpp.inc" diff --git a/mlir/unittests/Dialect/SMT/CMakeLists.txt b/mlir/unittests/Dialect/SMT/CMakeLists.txt index 86e16d6194ea..a1331467feba 100644 --- a/mlir/unittests/Dialect/SMT/CMakeLists.txt +++ b/mlir/unittests/Dialect/SMT/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRSMTTests AttributeTest.cpp + QuantifierTest.cpp TypeTest.cpp ) diff --git a/mlir/unittests/Dialect/SMT/QuantifierTest.cpp b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp new file mode 100644 index 000000000000..d7c57f0acbbe --- /dev/null +++ b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp @@ -0,0 +1,187 @@ +//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SMT/IR/SMTOps.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace smt; + +namespace { + +//===----------------------------------------------------------------------===// +// Test custom builders of ExistsOp +//===----------------------------------------------------------------------===// + +TEST(QuantifierTest, ExistsBuilderWithPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + OwningOpRef existsOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + std::nullopt, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return boundVars; + }, + /*weight=*/2); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + existsOp->print(stream); + + ASSERT_STREQ( + stream.str().str().c_str(), + "%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : " + "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n " + "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ExistsBuilderNoPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + OwningOpRef existsOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + existsOp->print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: " + "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n " + "smt.yield %0 : !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ExistsBuilderDefault) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + OwningOpRef existsOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + existsOp->print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield " + "%0 : !smt.bool\n}\n"); +} + +//===----------------------------------------------------------------------===// +// Test custom builders of ForallOp +//===----------------------------------------------------------------------===// + +TEST(QuantifierTest, ForallBuilderWithPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + OwningOpRef forallOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return boundVars; + }, + /*weight=*/2); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + forallOp->print(stream); + + ASSERT_STREQ( + stream.str().str().c_str(), + "%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : " + "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n " + "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ForallBuilderNoPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + OwningOpRef forallOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + forallOp->print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: " + "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n " + "smt.yield %0 : !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ForallBuilderDefault) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + OwningOpRef forallOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + std::nullopt); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + forallOp->print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.forall {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield " + "%0 : !smt.bool\n}\n"); +} + +} // namespace