Update stablehlo.custom_call call_target name based on design doc discussion.

PiperOrigin-RevId: 529281826
This commit is contained in:
John QiangZhang 2023-05-03 21:22:42 -07:00 committed by jax authors
parent bffddf76cb
commit 8acbe1557c
2 changed files with 2 additions and 2 deletions

View File

@ -577,7 +577,7 @@ def emit_tf_embedded_graph_custom_call(
concrete_function_flat_tf_name = (
concrete_function_flat_tf.function_def.signature.name
)
call_target_name = "tf_function_custom_call"
call_target_name = "tf.call_tf_function"
tf_backend_config = {
"caller_name": ir.StringAttr.get(concrete_function_flat_tf_name),
}

View File

@ -1378,7 +1378,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
# Jit mode
stablehlo_module = jax.jit(jax_f).lower(inputs).compiler_ir("stablehlo")
self.assertIn(
"stablehlo.custom_call @tf_function_custom_call",
"stablehlo.custom_call @tf.call_tf_function",
str(stablehlo_module),
)
self.assertIn("tf.backend_config", str(stablehlo_module))