Flip native serialization strict_check to True.

PiperOrigin-RevId: 537399539
This commit is contained in:
John QiangZhang 2023-06-02 13:44:29 -07:00 committed by jax authors
parent 9a76bfb02e
commit 277e461046
2 changed files with 1 additions and 5 deletions

View File

@ -186,7 +186,7 @@ def poly_specs(
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
Returns: a pytree of jax.ShapeDTypeStruct mathcing `args`.
Returns: a pytree of jax.ShapeDTypeStruct matching `args`.
"""
args_flat, args_tree = tree_util.tree_flatten(args)

View File

@ -1408,7 +1408,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
tf_f_rt = jax2tf.convert(
jax_f,
native_serialization=True,
native_serialization_strict_checks=False,
with_gradient=False,
)
_, restored_model = tf_test_util.SaveAndLoadFunction(
@ -1472,7 +1471,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
tf_f_rt = jax2tf.convert(
jax_f,
native_serialization=True,
native_serialization_strict_checks=False,
with_gradient=False,
)
_, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt, input_args=[inputs])
@ -1487,7 +1485,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
tf_f_rt_2 = jax2tf.convert(
jax_f_2,
native_serialization=True,
native_serialization_strict_checks=False,
with_gradient=False,
)
_, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[])
@ -1552,7 +1549,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
f_tf = jax2tf.convert(
f_jax,
native_serialization=True,
native_serialization_strict_checks=False,
with_gradient=False,
)
_, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])