[jax2tf] Remove non-native serialization test from jax_to_ir_test

PiperOrigin-RevId: 683124315
This commit is contained in:
George Necula 2024-10-07 04:20:55 -07:00 committed by jax authors
parent 95631a7d92
commit 5fabd34e7e
2 changed files with 1 additions and 10 deletions

View File

@ -160,7 +160,7 @@ def jax_to_ir(fn, input_shapes, *, constants=None, format):
raise ValueError(
'Conversion to TF graph requires TensorFlow to be installed.')
f = jax2tf.convert(ordered_wrapper, native_serialization=False)
f = jax2tf.convert(ordered_wrapper)
f = tf_wrap_with_input_names(f, input_shapes)
f = tf.function(f, autograph=False)
g = f.get_concrete_function(*args).graph.as_graph_def()

View File

@ -81,10 +81,6 @@ class JaxToIRTest(absltest.TestCase):
jax_to_ir.parse_shape_str('foo[]')
@unittest.skipIf(tf is None, 'TensorFlow not installed.')
@jtu.ignore_warning(
category=UserWarning,
message='jax2tf.convert with native_serialization=False is deprecated.'
)
def test_jax_to_tf_axpy(self):
tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [
('y', jax_to_ir.parse_shape_str('f32[128]')),
@ -92,11 +88,6 @@ class JaxToIRTest(absltest.TestCase):
('x', jax_to_ir.parse_shape_str('f32[128,2]')),
])
# Check that tf debug txt contains a broadcast, add, and multiply.
self.assertIn('BroadcastTo', tf_text)
self.assertIn('AddV2', tf_text)
self.assertIn('Mul', tf_text)
# Check that we can re-import our graphdef.
gdef = tf.compat.v1.GraphDef()
gdef.ParseFromString(tf_proto)