mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Start using jit_compile instead of the deprecated experimental_compile
This commit is contained in:
parent
449c2bc635
commit
6d2b976fab
@ -89,7 +89,7 @@ def convert_and_save_model(
|
||||
enable_xla: whether the jax2tf converter is allowed to use TFXLA ops. If
|
||||
False, the conversion tries harder to use purely TF ops and raises an
|
||||
exception if it is not possible. (default: True)
|
||||
compile_model: use TensorFlow experimental_compiler on the SavedModel. This
|
||||
compile_model: use TensorFlow jit_compiler on the SavedModel. This
|
||||
is needed if the SavedModel will be used for TensorFlow serving.
|
||||
save_model_options: options to pass to savedmodel.save.
|
||||
"""
|
||||
@ -116,7 +116,7 @@ def convert_and_save_model(
|
||||
params)
|
||||
tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
|
||||
autograph=False,
|
||||
experimental_compile=compile_model)
|
||||
jit_compile=compile_model)
|
||||
|
||||
signatures = {}
|
||||
# This signature is needed for TensorFlow Serving use.
|
||||
|
@ -56,7 +56,7 @@ flags.DEFINE_boolean(
|
||||
"Train and save a new model. Otherwise, use an existing SavedModel.")
|
||||
flags.DEFINE_boolean(
|
||||
"compile_model", True,
|
||||
"Enable TensorFlow experimental_compiler for the SavedModel. This is "
|
||||
"Enable TensorFlow jit_compiler for the SavedModel. This is "
|
||||
"necessary if you want to use the model for TensorFlow serving.")
|
||||
flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
|
||||
flags.DEFINE_boolean(
|
||||
|
@ -1826,7 +1826,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes):
|
||||
# Here we could use tf.slice. Similarly, for lax.gather we can sometimes use
|
||||
# tf.gather. But those have different semantics for index-out-of-bounds than
|
||||
# JAX (and XLA). We have tried to force compilation, by wrapping into
|
||||
# tf.xla.experimental.compile, or tf.function(experimental_compile=True), but
|
||||
# tf.xla.experimental.compile, or tf.function(jit_compile=True), but
|
||||
# those solutions are brittle because they do not work when nested into an
|
||||
# outer compilation (see b/162814494 and b/163006262). They also do not
|
||||
# survive well being put in a SavedModel. Hence, we now use TFXLA slicing
|
||||
|
@ -977,7 +977,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
@classmethod
|
||||
def qr(cls, harness: primitive_harness.Harness):
|
||||
# See https://github.com/google/jax/pull/3775#issuecomment-659407824;
|
||||
# # experimental_compile=True breaks for complex types.
|
||||
# # jit_compile=True breaks for complex types.
|
||||
# TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
|
||||
# - for now, the performance of the HLO QR implementation called when
|
||||
# compiling with TF is expected to have worse performance than the
|
||||
|
@ -476,7 +476,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
# If we get_concrete_function we trace once
|
||||
f_tf = tf.function(jax2tf.convert(f_jax, in_shapes=["(2 * batch, d)"]),
|
||||
autograph=False,
|
||||
experimental_compile=True).get_concrete_function(tf.TensorSpec([None, None], tf.float32))
|
||||
jit_compile=True).get_concrete_function(tf.TensorSpec([None, None], tf.float32))
|
||||
self.assertTrue(traced)
|
||||
traced = False
|
||||
self.assertAllClose(res_jax, f_tf(x))
|
||||
|
@ -61,7 +61,7 @@ def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
|
||||
return tf.function(
|
||||
func_tf,
|
||||
autograph=False,
|
||||
experimental_compile=True,
|
||||
jit_compile=True,
|
||||
input_signature=_make_tf_input_signature(*tf_args))(
|
||||
*tf_args) # COMPILED
|
||||
else:
|
||||
@ -114,7 +114,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
|
||||
It compares the result of JAX, TF ("eager" mode),
|
||||
TF with tf.function ("graph" mode), and TF with
|
||||
tf.function(experimental_compile=True) ("compiled" mode). In each mode,
|
||||
tf.function(jit_compile=True) ("compiled" mode). In each mode,
|
||||
either we expect to encounter a known limitation, or the value should
|
||||
match the value from the JAX execution.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user