Revert "[MLIR][python bindings] implement replace_all_uses_with on PyValue"

This reverts commit 3bab7cb089d92cc7025ebc57ef3a74d3ce94ecd8 because it breaks sanitizers.

Differential Revision: https://reviews.llvm.org/D149188
This commit is contained in:
max 2023-04-25 15:32:14 -05:00
parent 20d0f80dd3
commit fd527ceff1
5 changed files with 74 additions and 127 deletions

View File

@ -755,12 +755,6 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
/// operand if there are no uses.
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
/// Replace all uses of 'of' value with the 'with' value, updating anything in
/// the IR that uses 'of' to use the other value instead. When this returns
/// there are zero uses of 'of'.
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
MlirValue with);
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//

View File

@ -13,9 +13,11 @@
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
//#include "mlir-c/Registration.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@ -152,11 +154,6 @@ position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
static const char kValueReplaceAllUsesWithDocstring[] =
R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
@ -3319,18 +3316,10 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join();
},
kValueDunderStrDocstring)
.def_property_readonly("type",
[](PyValue &self) {
return PyType(
self.getParentOperation()->getContext(),
mlirValueGetType(self.get()));
})
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring);
.def_property_readonly("type", [](PyValue &self) {
return PyType(self.getParentOperation()->getContext(),
mlirValueGetType(self.get()));
});
PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);

View File

@ -751,10 +751,6 @@ MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
return wrap(opOperand);
}
void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
}
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//

View File

@ -28,25 +28,6 @@
#include <stdlib.h>
#include <string.h>
MlirValue makeConstantLiteral(MlirContext ctx, const char *literalStr,
const char *typeStr) {
MlirLocation loc = mlirLocationUnknownGet(ctx);
char attrStr[50];
sprintf(attrStr, "%s : %s", literalStr, typeStr);
MlirAttribute literal =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(attrStr));
MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), literal);
MlirOperationState constState = mlirOperationStateGet(
mlirStringRefCreateFromCString("arith.constant"), loc);
MlirType type =
mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(typeStr));
mlirOperationStateAddResults(&constState, 1, &type);
mlirOperationStateAddAttributes(&constState, 1, &valueAttr);
MlirOperation constOp = mlirOperationCreate(&constState);
return mlirOperationGetResult(constOp, 0);
}
static void registerAllUpstreamDialects(MlirContext ctx) {
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterAllDialects(registry);
@ -134,17 +115,26 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
MlirOperation func = mlirOperationCreate(&funcState);
mlirBlockInsertOwnedOperation(moduleBody, 0, func);
MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
MlirType indexType =
mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
MlirAttribute indexZeroLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
mlirStringRefCreateFromCString("arith.constant"), location);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
MlirOperation constZero = mlirOperationCreate(&constZeroState);
mlirBlockAppendOwnedOperation(funcBody, constZero);
MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
MlirValue dimOperands[] = {funcArg0, constZeroValue};
MlirOperationState dimState = mlirOperationStateGet(
mlirStringRefCreateFromCString("memref.dim"), location);
mlirOperationStateAddOperands(&dimState, 2, dimOperands);
MlirType indexType =
mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
mlirOperationStateAddResults(&dimState, 1, &indexType);
MlirOperation dim = mlirOperationCreate(&dimState);
mlirBlockAppendOwnedOperation(funcBody, dim);
@ -163,11 +153,11 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
mlirStringRefCreateFromCString("arith.constant"), location);
mlirOperationStateAddResults(&constOneState, 1, &indexType);
mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
MlirOperation constOne = mlirOpResultGetOwner(constOneValue);
MlirOperation constOne = mlirOperationCreate(&constOneState);
mlirBlockAppendOwnedOperation(funcBody, constOne);
MlirValue dimValue = mlirOperationGetResult(dim, 0);
MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
MlirOperationState loopState = mlirOperationStateGet(
mlirStringRefCreateFromCString("scf.for"), location);
@ -830,6 +820,11 @@ static int printBuiltinTypes(MlirContext ctx) {
return 0;
}
void callbackSetFixedLengthString(const char *data, intptr_t len,
void *userData) {
strncpy(userData, data, len);
}
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
if (strlen(lhs) != rhs.length) {
return false;
@ -1799,10 +1794,32 @@ int testOperands(void) {
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirType indexType = mlirIndexTypeGet(ctx);
// Create some constants to use as operands.
MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
MlirAttribute indexZeroLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
mlirStringRefCreateFromCString("arith.constant"), loc);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
MlirOperation constZero = mlirOperationCreate(&constZeroState);
MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
MlirAttribute indexOneLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
indexOneLiteral);
MlirOperationState constOneState = mlirOperationStateGet(
mlirStringRefCreateFromCString("arith.constant"), loc);
mlirOperationStateAddResults(&constOneState, 1, &indexType);
mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
MlirOperation constOne = mlirOperationCreate(&constOneState);
MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
// Create the operation under test.
mlirContextSetAllowUnregisteredDialects(ctx, true);
@ -1856,50 +1873,9 @@ int testOperands(void) {
return 3;
}
MlirOperationState op2State =
mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
MlirValue initialOperands2[] = {constOneValue};
mlirOperationStateAddOperands(&op2State, 1, initialOperands2);
MlirOperation op2 = mlirOperationCreate(&op2State);
MlirOpOperand use3 = mlirValueGetFirstUse(constOneValue);
fprintf(stderr, "First use owner: ");
mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: First use owner: "dummy.op2"
use3 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constOneValue));
fprintf(stderr, "Second use owner: ");
mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Second use owner: "dummy.op"
MlirValue constTwoValue = makeConstantLiteral(ctx, "2", "index");
mlirValueReplaceAllUsesOfWith(constOneValue, constTwoValue);
use3 = mlirValueGetFirstUse(constOneValue);
if (!mlirOpOperandIsNull(use3)) {
fprintf(stderr, "ERROR: Use should be null\n");
return 4;
}
MlirOpOperand use4 = mlirValueGetFirstUse(constTwoValue);
fprintf(stderr, "First replacement use owner: ");
mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: First replacement use owner: "dummy.op"
use4 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constTwoValue));
fprintf(stderr, "Second replacement use owner: ");
mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Second replacement use owner: "dummy.op2"
mlirOperationDestroy(op);
mlirOperationDestroy(op2);
mlirOperationDestroy(mlirOpResultGetOwner(constZeroValue));
mlirOperationDestroy(mlirOpResultGetOwner(constOneValue));
mlirOperationDestroy(mlirOpResultGetOwner(constTwoValue));
mlirOperationDestroy(constZero);
mlirOperationDestroy(constOne);
mlirContextDestroy(ctx);
return 0;
@ -1914,10 +1890,19 @@ int testClone(void) {
registerAllUpstreamDialects(ctx);
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirType indexType = mlirIndexTypeGet(ctx);
MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
MlirAttribute indexZeroLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
mlirStringRefCreateFromCString("arith.constant"), loc);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
MlirOperation constZero = mlirOperationCreate(&constZeroState);
MlirAttribute indexOneLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
@ -1995,10 +1980,19 @@ int testTypeID(MlirContext ctx) {
}
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirType indexType = mlirIndexTypeGet(ctx);
MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
// Create a registered operation, which should have a type id.
MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
MlirAttribute indexZeroLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
mlirStringRefCreateFromCString("arith.constant"), loc);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
MlirOperation constZero = mlirOperationCreate(&constZeroState);
if (!mlirOperationVerify(constZero)) {
fprintf(stderr, "ERROR: Expected operation to verify correctly\n");

View File

@ -111,29 +111,3 @@ def testValueUses():
assert use.owner in [op1, op2]
print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run
def testValueReplaceAllUsesWith():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
i32 = IntegerType.get_signless(32)
module = Module.create()
with InsertionPoint(module.body):
value = Operation.create("custom.op1", results=[i32]).results[0]
op1 = Operation.create("custom.op2", operands=[value])
op2 = Operation.create("custom.op2", operands=[value])
value2 = Operation.create("custom.op3", results=[i32]).results[0]
value.replace_all_uses_with(value2)
assert len(list(value.uses)) == 0
# CHECK: Use owner: "custom.op2"
# CHECK: Use operand_number: 0
# CHECK: Use owner: "custom.op2"
# CHECK: Use operand_number: 0
for use in value2.uses:
assert use.owner in [op1, op2]
print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}")