Add a new test to cover multiple calls to same tf function when call_tf_graph = True.

PiperOrigin-RevId: 531578811
This commit is contained in:
John QiangZhang 2023-05-12 12:41:59 -07:00 committed by jax authors
parent 559b837ba5
commit 2c05fe996e

View File

@ -1358,6 +1358,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
def test_call_tf_graph_non_compilable(self, tf_f, output_shape_dtype):
inputs = jnp.ones([10], dtype=jnp.float32)
caller_name_list = []
xla_call_module_list = []
def _extract_info(op):
if op.operation.name != "stablehlo.custom_call":
@ -1398,7 +1399,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
)
func_def = restored_model.f.concrete_functions[0]
xla_call_module_list = []
for node_def in func_def.graph.as_graph_def().node:
if node_def.op == "XlaCallModule":
xla_call_module_list.append(node_def)
@ -1409,8 +1409,34 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
xla_call_module = xla_call_module_list[0]
self.assertEqual(xla_call_module.attr["version"].i, 5)
self.assertIn("function_list", str(xla_call_module.attr))
xla_call_module_list.clear()
caller_name_list.clear()
# If JAX calls same tensorflow function by `jax2tf.call_tf` twice,
# it should return two different tf concrete functions.
def jax_f_2(x):
res1 = jax2tf.call_tf(
tf_f,
call_tf_graph=True,
output_shape_dtype=output_shape_dtype,
)(x)
res2 = jax2tf.call_tf(
tf_f,
call_tf_graph=True,
output_shape_dtype=output_shape_dtype,
)(x)
return res1, res2
stablehlo_module = jax.jit(jax_f_2).lower(inputs).compiler_ir("stablehlo")
self._walk_stablehlo_operations(stablehlo_module, _extract_info)
self.assertLen(caller_name_list, 2)
self.assertNotEqual(caller_name_list[0], caller_name_list[1])
logging.info("caller_name_list = %s", caller_name_list)
xla_call_module_list.clear()
caller_name_list.clear()
def test_b279454591(self):
"""Test case when tensorflow function returns `StatefulPartitionedCall` op."""
inputs = jnp.ones([10], dtype=jnp.float32)
# With one or more outputs, it is okay.