Instead of tf.Graph protobuf, we switch to tf.Saved_model for back_compat_tf_test.

PiperOrigin-RevId: 555500398
This commit is contained in:
John QiangZhang 2023-08-10 08:21:50 -07:00 committed by jax authors
parent 1ddc340a1a
commit cf026ce745
3 changed files with 98 additions and 53 deletions

View File

@ -109,7 +109,7 @@ class CompatTest(bctu.CompatTestBase):
cpu_schur_lapack_gees.data_2023_07_16,
cpu_svd_lapack_gesdd.data_2023_06_19,
cpu_triangular_solve_blas_trsm.data_2023_07_16,
tf_call_tf_function.data_2023_06_02, # This is tested in back_compat_tf_test.py
tf_call_tf_function.data_2023_07_29, # This is tested in back_compat_tf_test.py
tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
tpu_ApproxTopK.data_2023_05_16,

File diff suppressed because one or more lines are too long

View File

@ -17,28 +17,48 @@ See the back_compat_test_util module docstring for how to setup and update
these tests.
"""
import base64
from collections.abc import Sequence
import io
import os
import tarfile
from typing import Callable, Optional
from absl.testing import absltest
from jax import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import back_compat_test_util as bctu
from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
import jax.numpy as jnp
from jax._src.lib import xla_extension
from jax._src import test_util as jtu
import tensorflow as tf
from tensorflow.core.framework import graph_pb2 # type: ignore[import]
config.parse_flags_with_absl()
def serialize_directory(directory_path):
"""Seriliaze the directory as a string."""
tar_buffer = io.BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
tar.add(directory_path, arcname=os.path.basename(directory_path))
# Convert the binary data to a base64-encoded string
serialized_string = base64.b64encode(tar_buffer.getvalue())
return serialized_string
def deserialize_directory(serialized_string, output_directory):
"""Deserialize the string to the diretory."""
# Convert the base64-encoded string back to binary data
tar_data = base64.b64decode(serialized_string)
# Extract the tar archive to the output directory
with tarfile.open(fileobj=io.BytesIO(tar_data), mode="r") as tar:
tar.extractall(output_directory)
class CompatTensoflowTest(bctu.CompatTestBase):
"""Compatibility tests that use TF.
@ -48,9 +68,8 @@ class CompatTensoflowTest(bctu.CompatTestBase):
"""
def run_current(self, func: Callable, data: bctu.CompatTestData):
# Is there a better way to serialize/deserialize TF functions? I thought
# about using tf.saved_model, but then we have to zip/unzip a whole
# directory.
# Here we use tf.saved_model and provide string serialize/deserialize methods
# for the whole directory.
@tf.function(autograph=False, jit_compile=True)
def tf_func(the_input): # Use recognizeable names for input and result
res = jax2tf.convert(func, native_serialization=True)(the_input)
@ -59,9 +78,13 @@ class CompatTensoflowTest(bctu.CompatTestBase):
self.tf_func = tf_func
return tf_func(*data.inputs) # type: ignore
def serialize(self, func: Callable, data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_additional_custom_call_targets: Sequence[str] = ()):
def serialize(
self,
func: Callable,
data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_additional_custom_call_targets: Sequence[str] = (),
):
# We serialize as a tf.Graph
assert len(data.inputs) == 1 # We only support a single input now
tf_graph = self.tf_func.get_concrete_function(*data.inputs).graph
@ -69,7 +92,8 @@ class CompatTensoflowTest(bctu.CompatTestBase):
if op.type == "XlaCallModule":
serialized_module = op.get_attr("module")
module_str = xla_extension.mlir.deserialize_portable_artifact(
serialized_module)
serialized_module
)
module_version = op.get_attr("version")
break
else:
@ -77,36 +101,49 @@ class CompatTensoflowTest(bctu.CompatTestBase):
tf_graph_def = tf_graph.as_graph_def()
# module_str is just for human readability, add both the MLIR module
# and the tf.Graph
module_str = ("# First the MLIR module:\n" + module_str +
"\n# Then the tf.Graph:\n" + str(tf_graph_def))
serialized = tf_graph_def.SerializeToString()
module_str = (
"# First the MLIR module:\n"
+ module_str
+ "\n# Then the tf.Graph:\n"
+ str(tf_graph_def)
)
# serialized = tf_graph_def.SerializeToString()
module = tf.Module()
module.call = self.tf_func.get_concrete_function(*data.inputs)
root_dir = self.create_tempdir()
saved_model_dir = os.path.join(root_dir, "saved_model")
os.mkdir(saved_model_dir)
tf.saved_model.save(
module,
saved_model_dir,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
)
serialized = serialize_directory(saved_model_dir)
return serialized, module_str, module_version
def run_serialized(self, data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None):
loaded_f_tf_graph = graph_pb2.GraphDef()
loaded_f_tf_graph.ParseFromString(data.mlir_module_serialized)
@tf.function(autograph=False)
def loaded_fun(x):
result = tf.import_graph_def(loaded_f_tf_graph,
input_map={"the_input": x},
return_elements=["the_result:0"])
return result[0]
return (loaded_fun(*data.inputs).numpy(),)
def run_serialized(
self,
data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None,
):
root_dir = self.create_tempdir()
deserialize_directory(data.mlir_module_serialized, root_dir)
saved_model_dir = os.path.join(root_dir, "saved_model")
loaded_model = tf.saved_model.load(saved_model_dir)
return (loaded_model.call(*data.inputs).numpy(),)
def test_tf_call_tf_function(self):
self.skipTest("b/286409830: brittle on function naming.")
# A custom call tf.call_tf_function is generated when we lower call_tf
# with the call_tf_graph=True option.
def func(x):
def func_tf(x):
return tf.math.sin(x)
return jnp.cos(jax2tf.call_tf(func_tf, output_shape_dtype=x,
call_tf_graph=True)(x))
data = self.load_testdata(tf_call_tf_function.data_2023_06_02)
return jnp.cos(
jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
)
data = self.load_testdata(tf_call_tf_function.data_2023_07_29)
self.run_one_test(func, data)