1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Merge pull request from gnecula:export_fix_symtab

PiperOrigin-RevId: 629125592
This commit is contained in:
jax authors 2024-04-29 11:14:04 -07:00
commit 0e62c4cfcc
4 changed files with 29 additions and 5 deletions
jax
_src/interpreters
experimental
tests

@ -2224,7 +2224,8 @@ def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
def merge_mlir_modules(dst_module: ir.Module,
sym_name: str,
src_module: ir.Module) -> str:
src_module: ir.Module,
dst_symtab: ir.SymbolTable | None = None) -> str:
"""
Args:
dst_module: the module into which the contents of src_module should be
@ -2235,6 +2236,7 @@ def merge_mlir_modules(dst_module: ir.Module,
src_module: the module whose contents are to be alpha-renamed, set to
private visibility, and merged into dst_module. src_module must contain
exactly one symbol named "main".
dst_symtab: the symbol table of `dst_module`
Functions in src_module will be renamed such that they do not collide with
functions in dst_module.
@ -2248,7 +2250,7 @@ def merge_mlir_modules(dst_module: ir.Module,
assert dst_module.context == src_module.context
src_symtab = ir.SymbolTable(src_module.operation)
dst_symtab = ir.SymbolTable(dst_module.operation)
dst_symtab = dst_symtab or ir.SymbolTable(dst_module.operation)
used_names = set()
# Rename all symbols in src_module that clash with names in dst_module, or
@ -2286,6 +2288,7 @@ def merge_mlir_modules(dst_module: ir.Module,
for op in src_module.body.operations:
dst_module.body.append(op)
dst_symtab.insert(op)
return renamings["main"]
@ -2313,7 +2316,8 @@ def xla_fallback_lowering(prim: core.Primitive):
ctx.avals_out, **params)
xla_module = xla_computation_to_mlir_module(xla_computation)
callee_name = merge_mlir_modules(
module_ctx.module, f"xla_fallback_{prim.name}", xla_module)
module_ctx.module, f"xla_fallback_{prim.name}", xla_module,
dst_symtab=module_ctx.symbol_table)
output_types = map(aval_to_ir_types, ctx.avals_out)
flat_output_types = util.flatten(output_types)
output_type = (ir.TupleType.get_tuple(flat_output_types)

@ -1179,7 +1179,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
# TODO: maybe cache multiple calls
fn = mlir.merge_mlir_modules(ctx.module_context.module,
f"call_exported_{exported.fun_name}",
submodule)
submodule,
dst_symtab=ctx.module_context.symbol_table)
submodule_args = []
# All the platforms for the current lowering must be among the platforms

@ -571,7 +571,8 @@ def _call_tf_lowering(
callee_result_types = symtab["main"].type.results
fn = mlir.merge_mlir_modules(ctx.module_context.module,
f"call_tf_{function_flat_tf.name}",
submodule)
submodule,
dst_symtab=ctx.module_context.symbol_table)
call = func_dialect.CallOp(callee_result_types,
ir.FlatSymbolRefAttr.get(fn),
tuple(args_op) + captured_ops)

@ -248,6 +248,24 @@ class JaxExportTest(jtu.JaxTestCase):
f1 = export.call_exported(exp_f)
self.assertAllClose(f(x), f1(x))
def test_call_name_conflict(self):
@jax.jit
def inner(x):
# The lowering will contain a _where private function
return jnp.where(x > 0, jnp.ones_like(x), jnp.zeros_like(x))
x = jnp.arange(-20, 20, dtype=np.int32)
exp_inner = export.export(inner)(x)
self.assertIn("@_where(", str(exp_inner.mlir_module()))
@jax.jit
def outer(x):
# There should be no conflict on _where
x = export.call(exp_inner)(x)
return inner(x)
export.export(outer)(x)
def test_call_twice_exported(self):
def f(x): return jnp.sin(x)
x = np.arange(4, dtype=np.float32)