mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[jax2tf] Remove non-native serialization test from jax_to_ir_test
PiperOrigin-RevId: 683124315
This commit is contained in:
parent
95631a7d92
commit
5fabd34e7e
@ -160,7 +160,7 @@ def jax_to_ir(fn, input_shapes, *, constants=None, format):
|
||||
raise ValueError(
|
||||
'Conversion to TF graph requires TensorFlow to be installed.')
|
||||
|
||||
f = jax2tf.convert(ordered_wrapper, native_serialization=False)
|
||||
f = jax2tf.convert(ordered_wrapper)
|
||||
f = tf_wrap_with_input_names(f, input_shapes)
|
||||
f = tf.function(f, autograph=False)
|
||||
g = f.get_concrete_function(*args).graph.as_graph_def()
|
||||
|
@ -81,10 +81,6 @@ class JaxToIRTest(absltest.TestCase):
|
||||
jax_to_ir.parse_shape_str('foo[]')
|
||||
|
||||
@unittest.skipIf(tf is None, 'TensorFlow not installed.')
|
||||
@jtu.ignore_warning(
|
||||
category=UserWarning,
|
||||
message='jax2tf.convert with native_serialization=False is deprecated.'
|
||||
)
|
||||
def test_jax_to_tf_axpy(self):
|
||||
tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [
|
||||
('y', jax_to_ir.parse_shape_str('f32[128]')),
|
||||
@ -92,11 +88,6 @@ class JaxToIRTest(absltest.TestCase):
|
||||
('x', jax_to_ir.parse_shape_str('f32[128,2]')),
|
||||
])
|
||||
|
||||
# Check that tf debug txt contains a broadcast, add, and multiply.
|
||||
self.assertIn('BroadcastTo', tf_text)
|
||||
self.assertIn('AddV2', tf_text)
|
||||
self.assertIn('Mul', tf_text)
|
||||
|
||||
# Check that we can re-import our graphdef.
|
||||
gdef = tf.compat.v1.GraphDef()
|
||||
gdef.ParseFromString(tf_proto)
|
||||
|
Loading…
x
Reference in New Issue
Block a user