[jax2tf] test: fix jax serialization version tests

This commit is contained in:
Jake VanderPlas 2024-06-04 11:03:06 -07:00
parent 2333d5c7c3
commit ca784a09a3

View File

@ -1154,7 +1154,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
if version != version_override:
self.addCleanup(partial(jax.config.update,
"jax_serialization_version",
version_override))
version))
jax.config.update("jax_serialization_version", version_override)
logging.info(
"Using JAX serialization version %s",