mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Instead of tf.Graph protobuf, we switch to tf.Saved_model
for back_compat_tf_test.
PiperOrigin-RevId: 555500398
This commit is contained in:
parent
1ddc340a1a
commit
cf026ce745
@ -109,7 +109,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
cpu_schur_lapack_gees.data_2023_07_16,
|
||||
cpu_svd_lapack_gesdd.data_2023_06_19,
|
||||
cpu_triangular_solve_blas_trsm.data_2023_07_16,
|
||||
tf_call_tf_function.data_2023_06_02, # This is tested in back_compat_tf_test.py
|
||||
tf_call_tf_function.data_2023_07_29, # This is tested in back_compat_tf_test.py
|
||||
tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
|
||||
tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
|
||||
tpu_ApproxTopK.data_2023_05_16,
|
||||
|
File diff suppressed because one or more lines are too long
@ -17,28 +17,48 @@ See the back_compat_test_util module docstring for how to setup and update
|
||||
these tests.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from collections.abc import Sequence
|
||||
import io
|
||||
import os
|
||||
import tarfile
|
||||
from typing import Callable, Optional
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import back_compat_test_util as bctu
|
||||
|
||||
from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.core.framework import graph_pb2 # type: ignore[import]
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def serialize_directory(directory_path):
|
||||
"""Seriliaze the directory as a string."""
|
||||
tar_buffer = io.BytesIO()
|
||||
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
|
||||
tar.add(directory_path, arcname=os.path.basename(directory_path))
|
||||
|
||||
# Convert the binary data to a base64-encoded string
|
||||
serialized_string = base64.b64encode(tar_buffer.getvalue())
|
||||
return serialized_string
|
||||
|
||||
|
||||
def deserialize_directory(serialized_string, output_directory):
|
||||
"""Deserialize the string to the diretory."""
|
||||
# Convert the base64-encoded string back to binary data
|
||||
tar_data = base64.b64decode(serialized_string)
|
||||
|
||||
# Extract the tar archive to the output directory
|
||||
with tarfile.open(fileobj=io.BytesIO(tar_data), mode="r") as tar:
|
||||
tar.extractall(output_directory)
|
||||
|
||||
|
||||
class CompatTensoflowTest(bctu.CompatTestBase):
|
||||
"""Compatibility tests that use TF.
|
||||
|
||||
@ -48,9 +68,8 @@ class CompatTensoflowTest(bctu.CompatTestBase):
|
||||
"""
|
||||
|
||||
def run_current(self, func: Callable, data: bctu.CompatTestData):
|
||||
# Is there a better way to serialize/deserialize TF functions? I thought
|
||||
# about using tf.saved_model, but then we have to zip/unzip a whole
|
||||
# directory.
|
||||
# Here we use tf.saved_model and provide string serialize/deserialize methods
|
||||
# for the whole directory.
|
||||
@tf.function(autograph=False, jit_compile=True)
|
||||
def tf_func(the_input): # Use recognizeable names for input and result
|
||||
res = jax2tf.convert(func, native_serialization=True)(the_input)
|
||||
@ -59,9 +78,13 @@ class CompatTensoflowTest(bctu.CompatTestBase):
|
||||
self.tf_func = tf_func
|
||||
return tf_func(*data.inputs) # type: ignore
|
||||
|
||||
def serialize(self, func: Callable, data: bctu.CompatTestData,
|
||||
polymorphic_shapes: Optional[Sequence[str]] = None,
|
||||
allow_additional_custom_call_targets: Sequence[str] = ()):
|
||||
def serialize(
|
||||
self,
|
||||
func: Callable,
|
||||
data: bctu.CompatTestData,
|
||||
polymorphic_shapes: Optional[Sequence[str]] = None,
|
||||
allow_additional_custom_call_targets: Sequence[str] = (),
|
||||
):
|
||||
# We serialize as a tf.Graph
|
||||
assert len(data.inputs) == 1 # We only support a single input now
|
||||
tf_graph = self.tf_func.get_concrete_function(*data.inputs).graph
|
||||
@ -69,7 +92,8 @@ class CompatTensoflowTest(bctu.CompatTestBase):
|
||||
if op.type == "XlaCallModule":
|
||||
serialized_module = op.get_attr("module")
|
||||
module_str = xla_extension.mlir.deserialize_portable_artifact(
|
||||
serialized_module)
|
||||
serialized_module
|
||||
)
|
||||
module_version = op.get_attr("version")
|
||||
break
|
||||
else:
|
||||
@ -77,36 +101,49 @@ class CompatTensoflowTest(bctu.CompatTestBase):
|
||||
tf_graph_def = tf_graph.as_graph_def()
|
||||
# module_str is just for human readability, add both the MLIR module
|
||||
# and the tf.Graph
|
||||
module_str = ("# First the MLIR module:\n" + module_str +
|
||||
"\n# Then the tf.Graph:\n" + str(tf_graph_def))
|
||||
serialized = tf_graph_def.SerializeToString()
|
||||
module_str = (
|
||||
"# First the MLIR module:\n"
|
||||
+ module_str
|
||||
+ "\n# Then the tf.Graph:\n"
|
||||
+ str(tf_graph_def)
|
||||
)
|
||||
# serialized = tf_graph_def.SerializeToString()
|
||||
module = tf.Module()
|
||||
module.call = self.tf_func.get_concrete_function(*data.inputs)
|
||||
root_dir = self.create_tempdir()
|
||||
saved_model_dir = os.path.join(root_dir, "saved_model")
|
||||
os.mkdir(saved_model_dir)
|
||||
tf.saved_model.save(
|
||||
module,
|
||||
saved_model_dir,
|
||||
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
|
||||
)
|
||||
serialized = serialize_directory(saved_model_dir)
|
||||
return serialized, module_str, module_version
|
||||
|
||||
def run_serialized(self, data: bctu.CompatTestData,
|
||||
polymorphic_shapes: Optional[Sequence[str]] = None):
|
||||
loaded_f_tf_graph = graph_pb2.GraphDef()
|
||||
loaded_f_tf_graph.ParseFromString(data.mlir_module_serialized)
|
||||
|
||||
@tf.function(autograph=False)
|
||||
def loaded_fun(x):
|
||||
result = tf.import_graph_def(loaded_f_tf_graph,
|
||||
input_map={"the_input": x},
|
||||
return_elements=["the_result:0"])
|
||||
return result[0]
|
||||
|
||||
return (loaded_fun(*data.inputs).numpy(),)
|
||||
def run_serialized(
|
||||
self,
|
||||
data: bctu.CompatTestData,
|
||||
polymorphic_shapes: Optional[Sequence[str]] = None,
|
||||
):
|
||||
root_dir = self.create_tempdir()
|
||||
deserialize_directory(data.mlir_module_serialized, root_dir)
|
||||
saved_model_dir = os.path.join(root_dir, "saved_model")
|
||||
loaded_model = tf.saved_model.load(saved_model_dir)
|
||||
return (loaded_model.call(*data.inputs).numpy(),)
|
||||
|
||||
def test_tf_call_tf_function(self):
|
||||
self.skipTest("b/286409830: brittle on function naming.")
|
||||
# A custom call tf.call_tf_function is generated when we lower call_tf
|
||||
# with the call_tf_graph=True option.
|
||||
def func(x):
|
||||
def func_tf(x):
|
||||
return tf.math.sin(x)
|
||||
return jnp.cos(jax2tf.call_tf(func_tf, output_shape_dtype=x,
|
||||
call_tf_graph=True)(x))
|
||||
|
||||
data = self.load_testdata(tf_call_tf_function.data_2023_06_02)
|
||||
return jnp.cos(
|
||||
jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
|
||||
)
|
||||
|
||||
data = self.load_testdata(tf_call_tf_function.data_2023_07_29)
|
||||
self.run_one_test(func, data)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user