mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Fix symbol collision when merging MLIR modules.
PiperOrigin-RevId: 542039479
This commit is contained in:
parent
eca3b97253
commit
e99ca460e1
@ -1613,35 +1613,69 @@ def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
|
||||
def merge_mlir_modules(dst_module: ir.Module,
|
||||
sym_name: str,
|
||||
src_module: ir.Module) -> str:
|
||||
"""Returns the name of src_module's main() function, after renaming."""
|
||||
callee_name = None
|
||||
assert dst_module.context == src_module.context
|
||||
dst_symtab = ir.SymbolTable(dst_module.operation)
|
||||
"""
|
||||
Args:
|
||||
dst_module: the module into which the contents of src_module should be
|
||||
moved. Nothing in dst_module will be renamed.
|
||||
sym_name: the desired name for the "main" function of src_module after
|
||||
merging. This is a hint: the true name may be different because of symbol
|
||||
uniquification, and the true name is returned by this function.
|
||||
src_module: the module whose contents are to be alpha-renamed, set to
|
||||
private visibility, and merged into dst_module. src_module must contain
|
||||
exactly one symbol named "main".
|
||||
|
||||
Functions in src_module will be renamed such that they do not collide with
|
||||
functions in dst_module.
|
||||
|
||||
This function mutates `src_module`. On return, `src_module` is left in an
|
||||
undefined state.
|
||||
|
||||
Returns:
|
||||
the name of src_module's main() function, after renaming.
|
||||
"""
|
||||
assert dst_module.context == src_module.context
|
||||
|
||||
src_symtab = ir.SymbolTable(src_module.operation)
|
||||
dst_symtab = ir.SymbolTable(dst_module.operation)
|
||||
used_names = set()
|
||||
|
||||
# Rename all symbols in src_module that clash with names in dst_module, or
|
||||
# are the "main" symbol.
|
||||
renamings = {}
|
||||
for op in src_module.body.operations:
|
||||
name = op.name.value
|
||||
should_rename = name in dst_symtab or name == "main"
|
||||
if should_rename:
|
||||
base_name = sym_name if name == "main" else name
|
||||
new_name = base_name
|
||||
i = 0
|
||||
# Replacements are chosen such that the new names are present in neither
|
||||
# src_module, dst_module, or the set of fresh names we've already used.
|
||||
# Since we rename names one at a time, if new names were in src_module,
|
||||
# they might themselves collide with a later renaming.
|
||||
while (new_name in src_symtab or new_name in dst_symtab or
|
||||
new_name in used_names):
|
||||
new_name = f"{base_name}_{i}"
|
||||
i += 1
|
||||
renamings[name] = new_name
|
||||
used_names.add(new_name)
|
||||
|
||||
# Apply the symbol renamings to symbol definitions.
|
||||
private = ir.StringAttr.get("private")
|
||||
for op in src_module.body.operations:
|
||||
if op.name.value in renamings:
|
||||
src_symtab.set_symbol_name(op, renamings[op.name.value])
|
||||
op.attributes["sym_visibility"] = private
|
||||
|
||||
# Apply the symbol renamings to symbol uses.
|
||||
for old_name, new_name in renamings.items():
|
||||
for op in src_module.body.operations:
|
||||
src_symtab.replace_all_symbol_uses(old_name, new_name, op)
|
||||
|
||||
n = len(dst_module.body.operations)
|
||||
for op in src_module.body.operations:
|
||||
dst_module.body.append(op)
|
||||
ops = list(dst_module.body.operations)[n:]
|
||||
|
||||
for op in ops:
|
||||
op = typing.cast(func_dialect.FuncOp, op)
|
||||
old_name = op.name.value
|
||||
if op.name.value == "main":
|
||||
dst_symtab.set_symbol_name(op, sym_name)
|
||||
op.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
||||
callee_name = ir.StringAttr(dst_symtab.insert(op)).value
|
||||
new_name = callee_name
|
||||
else:
|
||||
new_name = ir.StringAttr(dst_symtab.insert(op)).value
|
||||
|
||||
# Replace references to the symbol with the new name
|
||||
for other_op in ops:
|
||||
dst_symtab.replace_all_symbol_uses(
|
||||
old_name, new_name, other_op.operation)
|
||||
|
||||
|
||||
assert callee_name is not None
|
||||
return callee_name
|
||||
return renamings["main"]
|
||||
|
||||
|
||||
def xla_fallback_lowering(prim: core.Primitive):
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Tests for lowering of array origami ops into MLIR.
|
||||
# Tests for MLIR helpers.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
@ -44,10 +44,10 @@ def main(_):
|
||||
# Test merging modules
|
||||
# CHECK-LABEL: TEST: merge_modules
|
||||
# CHECK: module @jit_g
|
||||
# CHECK: func public @main
|
||||
# CHECK: func private @f
|
||||
# CHECK: func private @m2_main_renamed
|
||||
# CHECK: func private @f_0
|
||||
# CHECK: func public @main(
|
||||
# CHECK: func private @f(
|
||||
# CHECK: func private @m2_main_renamed(
|
||||
# CHECK: func private @f_0(
|
||||
def make_module(c):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -69,5 +69,34 @@ def main(_):
|
||||
print("\nTEST: merge_modules")
|
||||
print(str(m1))
|
||||
|
||||
|
||||
# Test symbol renaming when merging modules
|
||||
# CHECK-LABEL: TEST: merge_modules_2
|
||||
# CHECK: module @jit_f
|
||||
# CHECK: func public @main(
|
||||
# CHECK: call @f(
|
||||
# CHECK: func private @f(
|
||||
# CHECK: func private @f_0(
|
||||
# CHECK: call @f_1(
|
||||
# CHECK: func private @f_1(
|
||||
|
||||
with mlir.make_ir_context():
|
||||
m_str = """
|
||||
module @jit_f {
|
||||
func.func public @main(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
%0 = call @f(%arg0) : (tensor<i64>) -> tensor<i64>
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
func.func private @f(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
return %arg0 : tensor<i64>
|
||||
}
|
||||
}"""
|
||||
m1 = ir.Module.parse(m_str)
|
||||
m2 = ir.Module.parse(m_str)
|
||||
mlir.merge_mlir_modules(m1, "f", m2)
|
||||
print("\nTEST: merge_modules_2")
|
||||
print(str(m1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
Loading…
x
Reference in New Issue
Block a user