mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add a new test to cover multiple calls to same tf function when call_tf_graph = True.
PiperOrigin-RevId: 531578811
This commit is contained in:
parent
559b837ba5
commit
2c05fe996e
@ -1358,6 +1358,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
def test_call_tf_graph_non_compilable(self, tf_f, output_shape_dtype):
|
||||
inputs = jnp.ones([10], dtype=jnp.float32)
|
||||
caller_name_list = []
|
||||
xla_call_module_list = []
|
||||
|
||||
def _extract_info(op):
|
||||
if op.operation.name != "stablehlo.custom_call":
|
||||
@ -1398,7 +1399,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
)
|
||||
func_def = restored_model.f.concrete_functions[0]
|
||||
|
||||
xla_call_module_list = []
|
||||
for node_def in func_def.graph.as_graph_def().node:
|
||||
if node_def.op == "XlaCallModule":
|
||||
xla_call_module_list.append(node_def)
|
||||
@ -1409,8 +1409,34 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
xla_call_module = xla_call_module_list[0]
|
||||
self.assertEqual(xla_call_module.attr["version"].i, 5)
|
||||
self.assertIn("function_list", str(xla_call_module.attr))
|
||||
xla_call_module_list.clear()
|
||||
caller_name_list.clear()
|
||||
|
||||
# If JAX calls same tensorflow function by `jax2tf.call_tf` twice,
|
||||
# it should return two different tf concrete functions.
|
||||
def jax_f_2(x):
|
||||
res1 = jax2tf.call_tf(
|
||||
tf_f,
|
||||
call_tf_graph=True,
|
||||
output_shape_dtype=output_shape_dtype,
|
||||
)(x)
|
||||
res2 = jax2tf.call_tf(
|
||||
tf_f,
|
||||
call_tf_graph=True,
|
||||
output_shape_dtype=output_shape_dtype,
|
||||
)(x)
|
||||
return res1, res2
|
||||
|
||||
stablehlo_module = jax.jit(jax_f_2).lower(inputs).compiler_ir("stablehlo")
|
||||
self._walk_stablehlo_operations(stablehlo_module, _extract_info)
|
||||
self.assertLen(caller_name_list, 2)
|
||||
self.assertNotEqual(caller_name_list[0], caller_name_list[1])
|
||||
logging.info("caller_name_list = %s", caller_name_list)
|
||||
xla_call_module_list.clear()
|
||||
caller_name_list.clear()
|
||||
|
||||
def test_b279454591(self):
|
||||
"""Test case when tensorflow function returns `StatefulPartitionedCall` op."""
|
||||
inputs = jnp.ones([10], dtype=jnp.float32)
|
||||
|
||||
# With one or more outputs, it is okay.
|
||||
|
Loading…
x
Reference in New Issue
Block a user