diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index 3076f3693..904ce509a 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -160,7 +160,7 @@ def jax_to_ir(fn, input_shapes, *, constants=None, format): raise ValueError( 'Conversion to TF graph requires TensorFlow to be installed.') - f = jax2tf.convert(ordered_wrapper, native_serialization=False) + f = jax2tf.convert(ordered_wrapper) f = tf_wrap_with_input_names(f, input_shapes) f = tf.function(f, autograph=False) g = f.get_concrete_function(*args).graph.as_graph_def() diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index 91780e178..f600a08f5 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -81,10 +81,6 @@ class JaxToIRTest(absltest.TestCase): jax_to_ir.parse_shape_str('foo[]') @unittest.skipIf(tf is None, 'TensorFlow not installed.') - @jtu.ignore_warning( - category=UserWarning, - message='jax2tf.convert with native_serialization=False is deprecated.' - ) def test_jax_to_tf_axpy(self): tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [ ('y', jax_to_ir.parse_shape_str('f32[128]')), @@ -92,11 +88,6 @@ class JaxToIRTest(absltest.TestCase): ('x', jax_to_ir.parse_shape_str('f32[128,2]')), ]) - # Check that tf debug txt contains a broadcast, add, and multiply. - self.assertIn('BroadcastTo', tf_text) - self.assertIn('AddV2', tf_text) - self.assertIn('Mul', tf_text) - # Check that we can re-import our graphdef. gdef = tf.compat.v1.GraphDef() gdef.ParseFromString(tf_proto)