mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Flip native serialization strict_check to True.
PiperOrigin-RevId: 537399539
This commit is contained in:
parent
9a76bfb02e
commit
277e461046
@ -186,7 +186,7 @@ def poly_specs(
|
||||
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
||||
for more details.
|
||||
|
||||
Returns: a pytree of jax.ShapeDTypeStruct mathcing `args`.
|
||||
Returns: a pytree of jax.ShapeDTypeStruct matching `args`.
|
||||
"""
|
||||
args_flat, args_tree = tree_util.tree_flatten(args)
|
||||
|
||||
|
@ -1408,7 +1408,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
tf_f_rt = jax2tf.convert(
|
||||
jax_f,
|
||||
native_serialization=True,
|
||||
native_serialization_strict_checks=False,
|
||||
with_gradient=False,
|
||||
)
|
||||
_, restored_model = tf_test_util.SaveAndLoadFunction(
|
||||
@ -1472,7 +1471,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
tf_f_rt = jax2tf.convert(
|
||||
jax_f,
|
||||
native_serialization=True,
|
||||
native_serialization_strict_checks=False,
|
||||
with_gradient=False,
|
||||
)
|
||||
_, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt, input_args=[inputs])
|
||||
@ -1487,7 +1485,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
tf_f_rt_2 = jax2tf.convert(
|
||||
jax_f_2,
|
||||
native_serialization=True,
|
||||
native_serialization_strict_checks=False,
|
||||
with_gradient=False,
|
||||
)
|
||||
_, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[])
|
||||
@ -1552,7 +1549,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_tf = jax2tf.convert(
|
||||
f_jax,
|
||||
native_serialization=True,
|
||||
native_serialization_strict_checks=False,
|
||||
with_gradient=False,
|
||||
)
|
||||
_, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
|
||||
|
Loading…
x
Reference in New Issue
Block a user