[MLIR][python bindings] Add PyValue.print_as_operand (Value::printAsOperand)

Useful for easier debugging (no need to regex out all of the stuff around the id).

Differential Revision: https://reviews.llvm.org/D149902
This commit is contained in:
max 2023-05-07 18:19:46 -05:00
parent 8dacfdf033
commit 81233c70cb
6 changed files with 150 additions and 5 deletions

View File

@ -776,6 +776,12 @@ MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value);
MLIR_CAPI_EXPORTED void MLIR_CAPI_EXPORTED void
mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
/// Prints a value as an operand (i.e., the ValueID).
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
MlirOpPrintingFlags flags,
MlirStringCallback callback,
void *userData);
/// Returns an op operand representing the first use of the value, or a null op /// Returns an op operand representing the first use of the value, or a null op
/// operand if there are no uses. /// operand if there are no uses.
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);

View File

@ -226,6 +226,7 @@ public:
/// Print this value as if it were an operand. /// Print this value as if it were an operand.
void printAsOperand(raw_ostream &os, AsmState &state); void printAsOperand(raw_ostream &os, AsmState &state);
void printAsOperand(raw_ostream &os, const OpPrintingFlags &flags);
/// Methods for supporting PointerLikeTypeTraits. /// Methods for supporting PointerLikeTypeTraits.
void *getAsOpaquePointer() const { return impl; } void *getAsOpaquePointer() const { return impl; }

View File

@ -156,6 +156,10 @@ position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it. equivalent to printing the operation that produced it.
)"; )";
static const char kGetNameAsOperand[] =
R"(Returns the string form of value as an operand (i.e., the ValueID).
)";
static const char kValueReplaceAllUsesWithDocstring[] = static const char kValueReplaceAllUsesWithDocstring[] =
R"(Replace all uses of value with the new value, updating anything in R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead. the IR that uses 'self' to use the other value instead.
@ -3336,6 +3340,19 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join(); return printAccum.join();
}, },
kValueDunderStrDocstring) kValueDunderStrDocstring)
.def(
"get_name",
[](PyValue &self, bool useLocalScope) {
PyPrintAccumulator printAccum;
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(),
printAccum.getUserData());
mlirOpPrintingFlagsDestroy(flags);
return printAccum.join();
},
py::arg("use_local_scope") = false, kGetNameAsOperand)
.def_property_readonly("type", .def_property_readonly("type",
[](PyValue &self) { [](PyValue &self) {
return PyType( return PyType(

View File

@ -20,6 +20,7 @@
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h" #include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h" #include "mlir/Parser/Parser.h"
@ -767,6 +768,13 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
unwrap(value).print(stream); unwrap(value).print(stream);
} }
void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags,
MlirStringCallback callback, void *userData) {
detail::CallbackOstream stream(callback, userData);
Value cppValue = unwrap(value);
cppValue.printAsOperand(stream, *unwrap(flags));
}
MlirOpOperand mlirValueGetFirstUse(MlirValue value) { MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
Value cppValue = unwrap(value); Value cppValue = unwrap(value);
if (cppValue.use_empty()) if (cppValue.use_empty())

View File

@ -44,8 +44,8 @@
#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/Threading.h" #include "llvm/Support/Threading.h"
#include <tuple>
#include <optional> #include <optional>
#include <tuple>
using namespace mlir; using namespace mlir;
using namespace mlir::detail; using namespace mlir::detail;
@ -3673,10 +3673,7 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) {
os); os);
} }
void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
// Find the operation to number from based upon the provided flags.
Operation *op = this;
bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
do { do {
// If we are printing local scope, stop at the first operation that is // If we are printing local scope, stop at the first operation that is
// isolated from above. // isolated from above.
@ -3689,7 +3686,28 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
break; break;
op = parentOp; op = parentOp;
} while (true); } while (true);
return op;
}
void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
Operation *op;
if (auto result = dyn_cast<OpResult>()) {
op = result.getOwner();
} else {
op = cast<BlockArgument>().getOwner()->getParentOp();
if (!op) {
os << "<<UNKNOWN SSA VALUE>>";
return;
}
}
op = findParent(op, flags.shouldUseLocalScope());
AsmState state(op, flags);
printAsOperand(os, state);
}
void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
// Find the operation to number from based upon the provided flags.
Operation *op = findParent(this, printerFlags.shouldUseLocalScope());
AsmState state(op, printerFlags); AsmState state(op, printerFlags);
print(os, state); print(os, state);
} }

View File

@ -2,6 +2,7 @@
import gc import gc
from mlir.ir import * from mlir.ir import *
from mlir.dialects import func
def run(f): def run(f):
@ -90,6 +91,7 @@ def testValueHash():
assert hash(block.arguments[0]) == hash(op.operands[0]) assert hash(block.arguments[0]) == hash(op.operands[0])
assert hash(op.result) == hash(ret.operands[0]) assert hash(op.result) == hash(ret.operands[0])
# CHECK-LABEL: TEST: testValueUses # CHECK-LABEL: TEST: testValueUses
@run @run
def testValueUses(): def testValueUses():
@ -112,6 +114,7 @@ def testValueUses():
print(f"Use owner: {use.owner}") print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}") print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValueReplaceAllUsesWith # CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run @run
def testValueReplaceAllUsesWith(): def testValueReplaceAllUsesWith():
@ -137,3 +140,95 @@ def testValueReplaceAllUsesWith():
assert use.owner in [op1, op2] assert use.owner in [op1, op2]
print(f"Use owner: {use.owner}") print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}") print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
i32 = IntegerType.get_signless(32)
module = Module.create()
with InsertionPoint(module.body):
value = Operation.create("custom.op1", results=[i32]).results[0]
# CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
print(value)
value2 = Operation.create("custom.op2", results=[i32]).results[0]
# CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
print(value2)
f = func.FuncOp("test", ([i32, i32], []))
entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
with InsertionPoint(entry_block1):
value3 = Operation.create("custom.op3", results=[i32]).results[0]
# CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
print(value3)
value4 = Operation.create("custom.op4", results=[i32]).results[0]
# CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
print(value4)
f = func.FuncOp("test", ([i32, i32], []))
entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
with InsertionPoint(entry_block2):
value5 = Operation.create("custom.op5", results=[i32]).results[0]
# CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
print(value5)
value6 = Operation.create("custom.op6", results=[i32]).results[0]
# CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
print(value6)
func.ReturnOp([])
func.ReturnOp([])
# CHECK: %[[VAL1]]
print(value.get_name())
# CHECK: %[[VAL2]]
print(value2.get_name())
# CHECK: %[[VAL3]]
print(value3.get_name())
# CHECK: %[[VAL4]]
print(value4.get_name())
# CHECK: %0
print(value3.get_name(use_local_scope=True))
# CHECK: %1
print(value4.get_name(use_local_scope=True))
# CHECK: %[[VAL5]]
print(value5.get_name())
# CHECK: %[[VAL6]]
print(value6.get_name())
# CHECK: %[[ARG0:.*]]
print(entry_block1.arguments[0].get_name())
# CHECK: %[[ARG1:.*]]
print(entry_block1.arguments[1].get_name())
# CHECK: %[[ARG2:.*]]
print(entry_block2.arguments[0].get_name())
# CHECK: %[[ARG3:.*]]
print(entry_block2.arguments[1].get_name())
# CHECK: module {
# CHECK: %[[VAL1]] = "custom.op1"() : () -> i32
# CHECK: %[[VAL2]] = "custom.op2"() : () -> i32
# CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
# CHECK: %[[VAL3]] = "custom.op3"() : () -> i32
# CHECK: %[[VAL4]] = "custom.op4"() : () -> i32
# CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
# CHECK: %[[VAL5]] = "custom.op5"() : () -> i32
# CHECK: %[[VAL6]] = "custom.op6"() : () -> i32
# CHECK: return
# CHECK: }
# CHECK: return
# CHECK: }
# CHECK: }
print(module)
value2.owner.detach_from_parent()
# CHECK: %0
print(value2.get_name())