[jax2tf] Start using jit_compile instead of the deprecated experimental_compile

George Necula 2021-01-18 14:41:42 +02:00
parent 449c2bc635
commit 6d2b976fab
6 changed files with 8 additions and 8 deletions
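The change is mechanical: the tf.function argument experimental_compile was deprecated in favor of jit_compile, which has the same semantics. A minimal sketch of the pattern applied throughout (illustrative only, assuming a TensorFlow version that accepts jit_compile):

  import tensorflow as tf

  def f(x):
    return tf.sin(x) * 2.0

  # Before: the deprecated spelling of the XLA-compilation flag.
  f_old = tf.function(f, autograph=False, experimental_compile=True)

  # After: the same behavior under the stable argument name.
  f_new = tf.function(f, autograph=False, jit_compile=True)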

View File

@@ -89,7 +89,7 @@ def convert_and_save_model(
     enable_xla: whether the jax2tf converter is allowed to use TFXLA ops. If
       False, the conversion tries harder to use purely TF ops and raises an
       exception if it is not possible. (default: True)
-    compile_model: use TensorFlow experimental_compiler on the SavedModel. This
+    compile_model: use TensorFlow jit_compiler on the SavedModel. This
       is needed if the SavedModel will be used for TensorFlow serving.
     save_model_options: options to pass to savedmodel.save.
   """
@@ -116,7 +116,7 @@ def convert_and_save_model(
       params)
   tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
                          autograph=False,
-                         experimental_compile=compile_model)
+                         jit_compile=compile_model)
   signatures = {}
   # This signature is needed for TensorFlow Serving use.

View File

@@ -56,7 +56,7 @@ flags.DEFINE_boolean(
     "Train and save a new model. Otherwise, use an existing SavedModel.")
 flags.DEFINE_boolean(
     "compile_model", True,
-    "Enable TensorFlow experimental_compiler for the SavedModel. This is "
+    "Enable TensorFlow jit_compiler for the SavedModel. This is "
     "necessary if you want to use the model for TensorFlow serving.")
 flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
 flags.DEFINE_boolean(

View File

@@ -1826,7 +1826,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes):
   # Here we could use tf.slice. Similarly, for lax.gather we can sometimes use
   # tf.gather. But those have different semantics for index-out-of-bounds than
   # JAX (and XLA). We have tried to force compilation, by wrapping into
-  # tf.xla.experimental.compile, or tf.function(experimental_compile=True), but
+  # tf.xla.experimental.compile, or tf.function(jit_compile=True), but
   # those solutions are brittle because they do not work when nested into an
   # outer compilation (see b/162814494 and b/163006262). They also do not
   # survive well being put in a SavedModel. Hence, we now use TFXLA slicing
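The index-out-of-bounds difference referenced in this comment is easy to reproduce. A small sketch, assuming recent jax and tensorflow installs:

  import jax
  import tensorflow as tf

  x = jax.numpy.arange(4.0)
  # JAX/XLA clamp an out-of-bounds start index so the slice stays in
  # range: this returns [2., 3.] instead of failing.
  print(jax.lax.dynamic_slice(x, (3,), (2,)))

  # tf.slice rejects the same out-of-bounds request outright.
  try:
    tf.slice(tf.range(4.0), begin=[3], size=[2])
  except tf.errors.InvalidArgumentError:
    print("tf.slice raised for the out-of-bounds slice")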

View File

@@ -977,7 +977,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   @classmethod
   def qr(cls, harness: primitive_harness.Harness):
     # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
-    # # experimental_compile=True breaks for complex types.
+    # # jit_compile=True breaks for complex types.
     # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
     # - for now, the performance of the HLO QR implementation called when
     #   compiling with TF is expected to have worse performance than the

View File

@@ -476,7 +476,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
     # If we get_concrete_function we trace once
     f_tf = tf.function(jax2tf.convert(f_jax, in_shapes=["(2 * batch, d)"]),
                        autograph=False,
-                       experimental_compile=True).get_concrete_function(tf.TensorSpec([None, None], tf.float32))
+                       jit_compile=True).get_concrete_function(tf.TensorSpec([None, None], tf.float32))
     self.assertTrue(traced)
     traced = False
     self.assertAllClose(res_jax, f_tf(x))
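For reference, get_concrete_function triggers exactly one trace for the given input signature; later calls with compatible shapes reuse it. A standalone sketch of that tracing behavior (function and variable names here are hypothetical, not from the test):

  import tensorflow as tf

  traced = 0

  def f(x):
    global traced
    traced += 1  # runs only while tracing, not on later calls
    return x * 2.0

  cf = tf.function(f, autograph=False, jit_compile=True) \
      .get_concrete_function(tf.TensorSpec([None], tf.float32))
  print(traced)                  # 1: traced once at get_concrete_function
  cf(tf.constant([1.0, 2.0]))
  print(traced)                  # still 1: no re-trace for the same spec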

View File

@@ -61,7 +61,7 @@ def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
     return tf.function(
         func_tf,
         autograph=False,
-        experimental_compile=True,
+        jit_compile=True,
         input_signature=_make_tf_input_signature(*tf_args))(
             *tf_args)  # COMPILED
   else:
@@ -114,7 +114,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
     It compares the result of JAX, TF ("eager" mode),
     TF with tf.function ("graph" mode), and TF with
-    tf.function(experimental_compile=True) ("compiled" mode). In each mode,
+    tf.function(jit_compile=True) ("compiled" mode). In each mode,
     either we expect to encounter a known limitation, or the value should
     match the value from the JAX execution.
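Concretely, the three TF modes compared by this helper correspond to the following invocation patterns (a sketch; f stands in for any converted function):

  import tensorflow as tf

  def f(x):
    return tf.sin(x) + 1.0

  x = tf.constant(0.5)
  res_eager = f(x)                                # "eager" mode
  res_graph = tf.function(f, autograph=False)(x)  # "graph" mode
  res_compiled = tf.function(f, autograph=False,
                             jit_compile=True)(x) # "compiled" (XLA) mode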