Allow SymbolUserOpInterface operators to be used in RemoveDeadValues Pass (#117405)

This change removes the restriction on `SymbolUserOpInterface` operators
so they can be used with operators that implement `SymbolOpInterface`,
example:

`memref.global` implements `SymbolOpInterface` so it can be used with
`memref.get_global` which implements `SymbolUserOpInterface`

```
// Define a global constant array
memref.global "private" constant @global_array : memref<10xi32> = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi32>

// Access this global constant within a function
func @use_global() {
  %0 = memref.get_global @global_array : memref<10xi32>
}
```

Reference: https://github.com/llvm/llvm-project/pull/116519 and
https://discourse.llvm.org/t/question-on-criteria-for-acceptable-ir-in-removedeadvaluespass/83131

---------

Co-authored-by: Zeeshan Siddiqui <mzs@ntdev.microsoft.com>
This commit is contained in:
M. Zeeshan Siddiqui 2024-11-23 10:37:29 -08:00 committed by GitHub
parent 14b9ca3f38
commit 5f9db0876a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 6 deletions

View File

@ -577,10 +577,8 @@ void RemoveDeadValues::runOnOperation() {
WalkResult acceptableIR = module->walk([&](Operation *op) {
if (op == module)
return WalkResult::advance();
if (isa<BranchOpInterface>(op) ||
(isa<SymbolUserOpInterface>(op) && !isa<CallOpInterface>(op))) {
op->emitError() << "cannot optimize an IR with "
"non-call symbol user ops or branch ops\n";
if (isa<BranchOpInterface>(op)) {
op->emitError() << "cannot optimize an IR with branch ops\n";
return WalkResult::interrupt();
}
return WalkResult::advance();

View File

@ -3,9 +3,12 @@
// The IR is updated regardless of memref.global private constant
//
module {
memref.global "private" constant @__something_global : memref<i32> = dense<0>
// CHECK: memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64}
memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64}
func.func @main(%arg0: i32) -> i32 {
%0 = tensor.empty() : tensor<10xbf16>
// CHECK-NOT: memref.get_global
%1 = memref.get_global @__constant_4xi32 : memref<4xi32>
// CHECK-NOT: tensor.empty
return %arg0 : i32
}
@ -29,7 +32,7 @@ module @named_module_acceptable {
//
func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
%non_live = arith.constant 0 : i32
// expected-error @+1 {{cannot optimize an IR with non-call symbol user ops or branch ops}}
// expected-error @+1 {{cannot optimize an IR with branch ops}}
cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
^bb1(%non_live_0 : i32):
cf.br ^bb3