[jax2tf] Add more documentation about saving models with custom gradients
parent c169ee3934
commit 2888e7ca81
@@ -95,7 +95,8 @@ is trivial:
 my_model = tf.Module()
 # Save a function that can take scalar inputs.
 my_model.f = tf.function(jax2tf.convert(f_jax), input_signature=[tf.TensorSpec([], tf.float32)])
-tf.saved_model.save(my_model, '/some/directory')
+tf.saved_model.save(my_model, '/some/directory',
+                    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
 
 # Restoring (note: the restored model does *not* require JAX to run, just XLA).
 restored_model = tf.saved_model.load('/some/directory')
@@ -113,7 +114,8 @@ SavedModel multiple versions of a function for different input shapes, by
 my_model.f = tf.function(jax2tf.convert(f_jax), autograph=False)
 my_model.f(tf.ones([1, 28, 28]))  # a batch size of 1
 my_model.f(tf.ones([16, 28, 28]))  # a batch size of 16
-tf.saved_model.save(my_model, '/some/directory')
+tf.saved_model.save(my_model, '/some/directory',
+                    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
 ```
 
 For examples of how to save a Flax model as a SavedModel see the
@@ -144,6 +146,32 @@ options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
 tf.saved_model.save(model, path, options=options)
 ```
 
+If you use `with_gradient=True` but forget to pass the `experimental_custom_gradients=True` option
+to `tf.saved_model.save`, then when you later load the saved model you will see a warning:
+
+```
+WARNING:absl:Importing a function (__inference_converted_fun_25) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
+```
+
+and if you then attempt to take a gradient of the loaded model, you may get an error:
+
+```
+TypeError: An op outside of the function building code is being passed
+a "Graph" tensor. It is possible to have Graph tensors
+leak out of the function building context by including a
+tf.init_scope in your function building code.
+For example, the following function will fail:
+@tf.function
+def has_init_scope():
+  my_constant = tf.constant(1.)
+  with tf.init_scope():
+    added = my_constant * 2
+The graph tensor has name: args_0:0
+```
+
+(We are working with the TF team to give a more explicit error in this case.)
+
+
 ## Shape-polymorphic conversion
 
 **The shape polymorphism support is work in progress. It is meant to be sound,
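To make the recommended flow concrete, here is a minimal end-to-end sketch of what the documentation above describes (the function `f_jax` and the `/tmp/my_model` path are illustrative, not part of this diff): convert with `with_gradient=True`, save with `experimental_custom_gradients=True`, then take a gradient of the restored model.

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Illustrative JAX function; any differentiable function would do here.
f_jax = lambda x: jnp.sin(x) * x

my_model = tf.Module()
my_model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
                         autograph=False,
                         input_signature=[tf.TensorSpec([], tf.float32)])
# Without experimental_custom_gradients=True, the custom gradients attached by
# jax2tf are dropped, which triggers the warning and TypeError shown above.
tf.saved_model.save(my_model, '/tmp/my_model',
                    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

restored_model = tf.saved_model.load('/tmp/my_model')
xv = tf.Variable(0.7, dtype=tf.float32)
with tf.GradientTape() as tape:
  y = restored_model.f(xv)
print(tape.gradient(y, xv))  # gradient of f_jax, computed through the SavedModel
```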
@@ -41,7 +41,7 @@ def convert_and_save_model(
     with_gradient: bool = False,
     enable_xla: bool = True,
     compile_model: bool = True,
-    save_model_options: Optional[tf.saved_model.SaveOptions] = None):
+    saved_model_options: Optional[tf.saved_model.SaveOptions] = None):
   """Convert a JAX function and saves a SavedModel.
 
   This is an example, for serious uses you will likely want to copy and
@@ -89,7 +89,7 @@ def convert_and_save_model(
       `polymorphic_shapes` argument to jax2tf.convert for the second parameter of
       `jax_fn`. In this case, a single `input_signatures` is supported, and
       should have `None` in the polymorphic dimensions.
-    save_model_options: options to pass to savedmodel.save.
+    saved_model_options: options to pass to savedmodel.save.
   """
   if not input_signatures:
     raise ValueError("At least one input_signature must be given")
@@ -124,8 +124,13 @@ def convert_and_save_model(
     # If there are more signatures, trace and cache a TF function for each one
     tf_graph.get_concrete_function(input_signature)
   wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
+  if with_gradient:
+    if not saved_model_options:
+      saved_model_options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
+    else:
+      saved_model_options.experimental_custom_gradients = True
   tf.saved_model.save(wrapper, model_dir, signatures=signatures,
-                      options=save_model_options)
+                      options=saved_model_options)
 
 
 class _ReusableSavedModelWrapper(tf.train.Checkpoint):
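Usage note for this change: when `with_gradient=True`, `convert_and_save_model` now forces `experimental_custom_gradients=True` on the save options, so callers no longer need to pass `saved_model_options` themselves. A rough sketch, assuming the `saved_model_lib` example module is importable from the jax2tf examples directory; the model function, parameters, and path below are made up for illustration.

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental.jax2tf.examples import saved_model_lib  # assumed import path

# Made-up model: params is a single weight matrix, inputs are flattened images.
def model_fn(params, inputs):
  return jnp.dot(inputs, params)

params = np.zeros((28 * 28, 10), dtype=np.float32)

# with_gradient=True now implies SaveOptions(experimental_custom_gradients=True),
# so the converted function's custom gradients survive the round trip to disk.
saved_model_lib.convert_and_save_model(
    model_fn, params, '/tmp/converted_model',
    input_signatures=[tf.TensorSpec([None, 28 * 28], tf.float32)],
    with_gradient=True)
```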
@@ -415,6 +415,48 @@ class CallTfTest(jtu.JaxTestCase):
     self.assertAllClose(g(x), g_rt(x))
     self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))
 
+  def test_round_trip_without_gradient_saved_model(self):
+    # Explicitly with_gradient=False
+    f_jax = jnp.sum
+
+    x = np.array([0.7, 0.8], dtype=np.float32)
+    f_tf = tf_test_util.SaveAndLoadFunction(
+        jax2tf.convert(f_jax, with_gradient=False),
+        [tf.TensorSpec(x.shape, dtype=x.dtype)])
+    f_rt = jax2tf.call_tf(f_tf)
+
+    self.assertAllClose(f_jax(x), f_rt(x))
+    with self.assertRaisesRegex(Exception,
+                                "Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
+      jax.grad(f_rt)(x)
+
+  def test_round_trip_saved_model_no_gradients(self):
+    # Save without gradients
+    f_jax = jnp.sum
+
+    x = np.array([0.7, 0.8], dtype=np.float32)
+    f_tf = tf_test_util.SaveAndLoadFunction(
+        jax2tf.convert(f_jax, with_gradient=True),
+        [tf.TensorSpec(x.shape, dtype=x.dtype)],
+        save_gradients=False)
+    f_rt = jax2tf.call_tf(f_tf)
+
+    self.assertAllClose(f_jax(x), f_rt(x))
+    # TODO: clean this up b/191117111: it should fail with a clear error
+    # The following results in a confusing error:
+    # TypeError: An op outside of the function building code is being passed
+    # a "Graph" tensor. It is possible to have Graph tensors
+    # leak out of the function building context by including a
+    # tf.init_scope in your function building code.
+    # For example, the following function will fail:
+    # @tf.function
+    # def has_init_scope():
+    #   my_constant = tf.constant(1.)
+    #   with tf.init_scope():
+    #     added = my_constant * 2
+    # The graph tensor has name: args_0:0
+    # g = jax.grad(f_rt)(x)
+
   def test_module_documentation(self):
     def cos_tf(x):
       return tf.math.cos(x)
@@ -42,24 +42,6 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
     restored_model = tf_test_util.SaveAndLoadModel(model)
     self.assertAllClose(restored_model.f(x), f_jax(x))
 
-  def test_gradient_disabled(self):
-    f_jax = lambda x: x * x
-
-    model = tf.Module()
-    model.f = tf.function(jax2tf.convert(f_jax, with_gradient=False),
-                          autograph=False,
-                          input_signature=[tf.TensorSpec([], tf.float32)])
-    x = np.array(0.7, dtype=jnp.float32)
-    self.assertAllClose(model.f(x), f_jax(x))
-    restored_model = tf_test_util.SaveAndLoadModel(model)
-    xv = tf.Variable(0.7, dtype=jnp.float32)
-    self.assertAllClose(restored_model.f(x), f_jax(x))
-
-    with self.assertRaisesRegex(LookupError,
-                                "Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
-      with tf.GradientTape():
-        _ = restored_model.f(xv)
-
   def test_gradient(self):
     """Save and restore the custom gradient."""
     @jax.custom_jvp
@@ -82,12 +64,95 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
     x = np.array(0.7, dtype=jnp.float32)
     self.assertAllClose(model.f(x), f_jax(x))
     restored_model = tf_test_util.SaveAndLoadModel(model)
-    xv = tf.Variable(0.7, dtype=jnp.float32)
+    xv = tf.Variable(x)
     self.assertAllClose(restored_model.f(x), f_jax(x))
     with tf.GradientTape() as tape:
       y = restored_model.f(xv)
     self.assertAllClose(tape.gradient(y, xv).numpy(),
-                        jax.grad(f_jax)(x).astype(np.float32))
+                        jax.grad(f_jax)(x))
 
+  def test_gradient_nested(self):
+    """Save and restore the custom gradient, when combined with other TF code."""
+    @jax.custom_jvp
+    def f_jax(x):
+      return x * x
+
+    @f_jax.defjvp
+    def f_jax_jvp(primals, tangents):
+      # 3 * x * x_t
+      x, = primals
+      x_dot, = tangents
+      primal_out = f_jax(x)
+      tangent_out = x * x_dot * 3.
+      return primal_out, tangent_out
+
+    model = tf.Module()
+    # After conversion, we wrap with some pure TF code
+    model.f = tf.function(lambda x: tf.math.sin(jax2tf.convert(f_jax, with_gradient=True)(x)),
+                          autograph=False,
+                          input_signature=[tf.TensorSpec([], tf.float32)])
+    f_jax_equiv = lambda x: jnp.sin(f_jax(x))
+    x = np.array(0.7, dtype=jnp.float32)
+    self.assertAllClose(model.f(x), f_jax_equiv(x))
+    restored_model = tf_test_util.SaveAndLoadModel(model)
+    xv = tf.Variable(x)
+    self.assertAllClose(restored_model.f(x), f_jax_equiv(x))
+    with tf.GradientTape() as tape:
+      y = restored_model.f(xv)
+    self.assertAllClose(tape.gradient(y, xv).numpy(),
+                        jax.grad(f_jax_equiv)(x))
+
+  def test_gradient_disabled(self):
+    f_jax = lambda x: x * x
+
+    model = tf.Module()
+    model.f = tf.function(jax2tf.convert(f_jax, with_gradient=False),
+                          autograph=False,
+                          input_signature=[tf.TensorSpec([], tf.float32)])
+    x = np.array(0.7, dtype=jnp.float32)
+    self.assertAllClose(model.f(x), f_jax(x))
+    restored_model = tf_test_util.SaveAndLoadModel(model)
+    xv = tf.Variable(0.7, dtype=jnp.float32)
+    self.assertAllClose(restored_model.f(x), f_jax(x))
+
+    with self.assertRaisesRegex(LookupError,
+                                "Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
+      with tf.GradientTape():
+        _ = restored_model.f(xv)
+
+  def test_save_without_gradients(self):
+    f_jax = lambda x: x * x
+
+    x = np.array(0.7, dtype=jnp.float32)
+    model = tf.Module()
+    model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
+                          autograph=False,
+                          input_signature=[tf.TensorSpec(x.shape, x.dtype)])
+
+    self.assertAllClose(model.f(x), f_jax(x))
+    restored_model = tf_test_util.SaveAndLoadModel(model,
+                                                   save_gradients=False)
+    self.assertAllClose(restored_model.f(x), f_jax(x))
+
+    xv = tf.Variable(x)
+    with tf.GradientTape():
+      _ = restored_model.f(xv)
+    # TODO: clean this up b/191117111: it should fail with a clear error
+    # The following results in a confusing error:
+    # TypeError: An op outside of the function building code is being passed
+    # a "Graph" tensor. It is possible to have Graph tensors
+    # leak out of the function building context by including a
+    # tf.init_scope in your function building code.
+    # For example, the following function will fail:
+    # @tf.function
+    # def has_init_scope():
+    #   my_constant = tf.constant(1.)
+    #   with tf.init_scope():
+    #     added = my_constant * 2
+    # The graph tensor has name: args_0:0
+    # g = tape.gradient(res, xv)
+    #self.assertAllClose(g.numpy(), jax.grad(f_jax)(x))
+
+
   def _compare_with_saved_model(self, f_jax, *args):
     # Certain ops are converted to ensure an XLA context, e.g.,
@@ -75,23 +75,25 @@ class OpMetadataGraph:
   source_line: str
 
 
-def SaveAndLoadModel(model: tf.Module) -> tf.Module:
+def SaveAndLoadModel(model: tf.Module,
+                     save_gradients=True) -> tf.Module:
   # Roundtrip through saved model on disk.
   model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
   tf.saved_model.save(
       model, model_dir,
-      options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
+      options=tf.saved_model.SaveOptions(experimental_custom_gradients=save_gradients))
   restored_model = tf.saved_model.load(model_dir)
   return restored_model
 
 def SaveAndLoadFunction(f_tf: Callable,
-                        input_signature: Sequence[tf.TensorSpec]) -> Callable:
+                        input_signature: Sequence[tf.TensorSpec],
+                        save_gradients=True) -> Callable:
   # Roundtrip through saved model on disk
   model = tf.Module()
   model.f = tf.function(f_tf,
                         autograph=False,
                         input_signature=input_signature)
-  restored = SaveAndLoadModel(model)
+  restored = SaveAndLoadModel(model, save_gradients=save_gradients)
   return restored.f
 
 
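For orientation, this is how the new `save_gradients` flag is exercised by the tests above. A small sketch, assuming the `tf_test_util` test module is importable from the jax2tf tests directory; the function and data are made up for illustration.

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util  # test helper shown above

f_jax = jnp.sum
x = np.array([0.7, 0.8], dtype=np.float32)

# save_gradients=False saves the model without experimental_custom_gradients,
# reproducing the "unsaved custom gradients" situation documented in the README.
f_tf = tf_test_util.SaveAndLoadFunction(
    jax2tf.convert(f_jax, with_gradient=True),
    [tf.TensorSpec(x.shape, dtype=x.dtype)],
    save_gradients=False)
```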