mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update stablehlo.custom_call call_target name based on design doc discussion.
PiperOrigin-RevId: 529281826
This commit is contained in:
parent
bffddf76cb
commit
8acbe1557c
@ -577,7 +577,7 @@ def emit_tf_embedded_graph_custom_call(
|
||||
concrete_function_flat_tf_name = (
|
||||
concrete_function_flat_tf.function_def.signature.name
|
||||
)
|
||||
call_target_name = "tf_function_custom_call"
|
||||
call_target_name = "tf.call_tf_function"
|
||||
tf_backend_config = {
|
||||
"caller_name": ir.StringAttr.get(concrete_function_flat_tf_name),
|
||||
}
|
||||
|
@ -1378,7 +1378,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
# Jit mode
|
||||
stablehlo_module = jax.jit(jax_f).lower(inputs).compiler_ir("stablehlo")
|
||||
self.assertIn(
|
||||
"stablehlo.custom_call @tf_function_custom_call",
|
||||
"stablehlo.custom_call @tf.call_tf_function",
|
||||
str(stablehlo_module),
|
||||
)
|
||||
self.assertIn("tf.backend_config", str(stablehlo_module))
|
||||
|
Loading…
x
Reference in New Issue
Block a user