mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Merge pull request #20976 from gnecula:export_fix_symtab
PiperOrigin-RevId: 629125592
This commit is contained in:
commit
0e62c4cfcc
@ -2224,7 +2224,8 @@ def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
|
||||
|
||||
def merge_mlir_modules(dst_module: ir.Module,
|
||||
sym_name: str,
|
||||
src_module: ir.Module) -> str:
|
||||
src_module: ir.Module,
|
||||
dst_symtab: ir.SymbolTable | None = None) -> str:
|
||||
"""
|
||||
Args:
|
||||
dst_module: the module into which the contents of src_module should be
|
||||
@ -2235,6 +2236,7 @@ def merge_mlir_modules(dst_module: ir.Module,
|
||||
src_module: the module whose contents are to be alpha-renamed, set to
|
||||
private visibility, and merged into dst_module. src_module must contain
|
||||
exactly one symbol named "main".
|
||||
dst_symtab: the symbol table of `dst_module`
|
||||
|
||||
Functions in src_module will be renamed such that they do not collide with
|
||||
functions in dst_module.
|
||||
@ -2248,7 +2250,7 @@ def merge_mlir_modules(dst_module: ir.Module,
|
||||
assert dst_module.context == src_module.context
|
||||
|
||||
src_symtab = ir.SymbolTable(src_module.operation)
|
||||
dst_symtab = ir.SymbolTable(dst_module.operation)
|
||||
dst_symtab = dst_symtab or ir.SymbolTable(dst_module.operation)
|
||||
used_names = set()
|
||||
|
||||
# Rename all symbols in src_module that clash with names in dst_module, or
|
||||
@ -2286,6 +2288,7 @@ def merge_mlir_modules(dst_module: ir.Module,
|
||||
|
||||
for op in src_module.body.operations:
|
||||
dst_module.body.append(op)
|
||||
dst_symtab.insert(op)
|
||||
|
||||
return renamings["main"]
|
||||
|
||||
@ -2313,7 +2316,8 @@ def xla_fallback_lowering(prim: core.Primitive):
|
||||
ctx.avals_out, **params)
|
||||
xla_module = xla_computation_to_mlir_module(xla_computation)
|
||||
callee_name = merge_mlir_modules(
|
||||
module_ctx.module, f"xla_fallback_{prim.name}", xla_module)
|
||||
module_ctx.module, f"xla_fallback_{prim.name}", xla_module,
|
||||
dst_symtab=module_ctx.symbol_table)
|
||||
output_types = map(aval_to_ir_types, ctx.avals_out)
|
||||
flat_output_types = util.flatten(output_types)
|
||||
output_type = (ir.TupleType.get_tuple(flat_output_types)
|
||||
|
@ -1179,7 +1179,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
# TODO: maybe cache multiple calls
|
||||
fn = mlir.merge_mlir_modules(ctx.module_context.module,
|
||||
f"call_exported_{exported.fun_name}",
|
||||
submodule)
|
||||
submodule,
|
||||
dst_symtab=ctx.module_context.symbol_table)
|
||||
|
||||
submodule_args = []
|
||||
# All the platforms for the current lowering must be among the platforms
|
||||
|
@ -571,7 +571,8 @@ def _call_tf_lowering(
|
||||
callee_result_types = symtab["main"].type.results
|
||||
fn = mlir.merge_mlir_modules(ctx.module_context.module,
|
||||
f"call_tf_{function_flat_tf.name}",
|
||||
submodule)
|
||||
submodule,
|
||||
dst_symtab=ctx.module_context.symbol_table)
|
||||
call = func_dialect.CallOp(callee_result_types,
|
||||
ir.FlatSymbolRefAttr.get(fn),
|
||||
tuple(args_op) + captured_ops)
|
||||
|
@ -248,6 +248,24 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
|
||||
def test_call_name_conflict(self):
|
||||
@jax.jit
|
||||
def inner(x):
|
||||
# The lowering will contain a _where private function
|
||||
return jnp.where(x > 0, jnp.ones_like(x), jnp.zeros_like(x))
|
||||
|
||||
x = jnp.arange(-20, 20, dtype=np.int32)
|
||||
exp_inner = export.export(inner)(x)
|
||||
self.assertIn("@_where(", str(exp_inner.mlir_module()))
|
||||
|
||||
@jax.jit
|
||||
def outer(x):
|
||||
# There should be no conflict on _where
|
||||
x = export.call(exp_inner)(x)
|
||||
return inner(x)
|
||||
|
||||
export.export(outer)(x)
|
||||
|
||||
def test_call_twice_exported(self):
|
||||
def f(x): return jnp.sin(x)
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user