[jax2tf] Expanded the SavedModel library to allow the compilation of models.

This is important for use with TensorFlow serving.
Also removed servo_main.py (it only applies to OSS TF Serving, which
does not yet support XLA).

Co-authored-by: Benjamin Chetioui <3920784+bchetioui@users.noreply.github.com>
George Necula 2020-11-03 09:31:10 +02:00
parent cc8fe15e46
commit 4ee7a296c3
4 changed files with 85 additions and 261 deletions

jax/experimental/jax2tf/examples/README.md

@ -1,14 +1,18 @@
Getting started with jax2tf
===========================
This directory contains a number of examples of using the jax2tf converter to:
* save SavedModel from trained MNIST models, using both pure JAX and Flax.
* load the SavedModel into the TensorFlow model server and use it for
inference. We also show how to use a batch-polymorphic saved model.
* reuse the feature-extractor part of the trained MNIST model in
TensorFlow Hub, and in a larger TensorFlow Keras model.
It is also possible to use a jax2tf-generated SavedModel with TensorFlow Serving.
At the moment, the open-source TensorFlow model server is missing XLA support,
but the Google version can be used, as described in README_serving.md.
# Preparing the model for jax2tf
The most important detail for using jax2tf with SavedModel is to express
the trained model as a pair:
@ -16,17 +20,18 @@ the trained model as a pair:
* `func`: a two-argument function with signature:
`(params: Parameters, inputs: Inputs) -> Outputs`.
Both arguments must be a `numpy.ndarray` or a
nested tuple/list/dictionary thereof. In particular, it is important
for models that take multiple inputs to be adapted to take their inputs
as a single tuple/list/dictionary of `numpy.ndarray`.
* `params`: of type `Parameters`, the model parameters, to be used as the
first input for `func`.
You can see in [mnist_lib.py](mnist_lib.py) how this can be done for two
implementations of MNIST: one using pure JAX (`PureJaxMNIST`) and a CNN
model using Flax (`FlaxMNIST`). Other Flax models can be arranged similarly,
and the same strategy should work for other neural-network libraries for JAX.
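For concreteness, here is a minimal sketch (not taken from `mnist_lib.py`; the
names and the one-layer model are purely illustrative) of a model expressed in
this form:

```
import jax
import jax.numpy as jnp
import numpy as np

# A one-layer MNIST classifier expressed as the (params, func) pair
# described above.
rng = np.random.RandomState(0)
params = (rng.normal(scale=0.01, size=(28 * 28, 10)).astype(np.float32),  # weights
          np.zeros((10,), np.float32))                                    # biases

def predict(params, inputs):
  # `inputs` is a single batched array of shape F32[B, 28, 28, 1]; a model
  # with several inputs would receive them packed into one tuple/list/dict.
  w, b = params
  flat = inputs.reshape((inputs.shape[0], -1))
  return jax.nn.softmax(jnp.dot(flat, w) + b)
```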
# Generating TensorFlow SavedModel
Once you have the model in this form, you can use the
`saved_model_lib.save_model` function from [saved_model_lib.py](saved_model_lib.py)
@ -36,19 +41,24 @@ into functions that behave as if they had been written with TensorFlow.
Therefore, if you are familiar with how to generate SavedModel, you can most
likely just use your own code for this.
The file `saved_model_main.py` is an executable that shows how to perform the
following sequence of steps:
* train an MNIST model, and obtain a pair of an inference function and the
parameters.
* convert the inference function with jax2tf, for one or more batch sizes.
* save a SavedModel and dump its contents.
* reload the SavedModel and run it with TensorFlow to test that the inference
function produces the same results as the JAX inference function.
* optionally plot images with the training digits and the inference results.
There are a number of flags to select the Flax model (`--model=mnist_flax`),
to skip the training and just test a previously saved
SavedModel (`--nogenerate_model`), to choose the saving path, etc.
The default saving location is `/tmp/jax2tf/saved_models/1`.
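Once saved, the model can be reloaded and called directly from TensorFlow. A
sketch, using the default path above (the input shape is the MNIST one used
throughout these examples):

```
import numpy as np
import tensorflow as tf

restored_model = tf.saved_model.load("/tmp/jax2tf/saved_models/1")

# Call the restored model on a batch matching one of the saved input
# signatures (here: MNIST-shaped images with batch size 1).
images = np.zeros((1, 28, 28, 1), dtype=np.float32)
predictions = restored_model(tf.convert_to_tensor(images))

# The default serving signature is also available, e.g. for serving tools.
serving_fn = restored_model.signatures["serving_default"]
```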
By default, this example will convert the inference function for three separate
batch sizes: 1, 16, 128. You can see this in the dumped SavedModel. If you
@ -58,42 +68,8 @@ be done using jax2tf's
As a result, the inference function will be traced only once and the SavedModel
will contain a single batch-polymorphic TensorFlow graph.
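As a rough sketch of this path, reusing the illustrative `predict` and `params`
from the earlier snippet and assuming the `in_shapes` spelling that
`saved_model_lib.py` uses at this point (the exact shape-spec string syntax is
described in the jax2tf documentation):

```
import tensorflow as tf
from jax.experimental import jax2tf

# Convert once, with a symbolic batch dimension `b`, instead of once per
# concrete batch size.
tf_fn = jax2tf.convert(predict, in_shapes=[None, "(b, 28, 28, 1)"])

# A single batch-polymorphic TensorFlow graph is traced for a signature
# whose batch dimension is None.
tf_graph = tf.function(lambda inputs: tf_fn(params, inputs), autograph=False)
concrete_fn = tf_graph.get_concrete_function(
    tf.TensorSpec((None, 28, 28, 1), tf.float32))
```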
Using the TensorFlow model server
=================================
The executable `servo_main.py` extends `saved_model_main.py` with code to
show how to use the SavedModel with the TensorFlow model server. All the flags
of `saved_model_main.py` also apply to `servo_main.py`. In particular, you
can select the Flax MNIST model (`--model=mnist_flax`), request
batch-polymorphic conversion to TensorFlow (`--serving_batch_size=-1`), or
skip the training and reuse a previously trained model (`--nogenerate_model`).
If you want to start your own model server, you should pass the
`--nostart_model_server` flag and also `--serving_url` to point to the
HTTP REST API end point of your model server. You can see the path of the
trained and saved model in the output.
Open-source model server
------------------------
We have tried this example with the OSS model server, following the
[TensorFlow Serving with Docker](https://www.tensorflow.org/tfx/serving/docker) instructions.
Specifically, you need to install Docker and run
```
docker pull tensorflow/serving
```
The actual starting of the model server is done by the `servo_main.py` script.
The script also does a sanity check that the model server is serving a model
with the right batch size and image shapes, and makes a few requests to the
server to perform inference.
Instructions for using the Google internal version of the model server: TBA.
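For reference, the predict request that `servo_main.py` (the file removed by
this commit, shown in full below) sends to the REST endpoint boils down to the
following sketch; the URL and batch shape are illustrative:

```
import numpy as np
import requests

# One POST per batch of images F32[B, 28, 28, 1]; the server returns the
# one-hot predictions F32[B, 10] under "outputs".
images = np.zeros((1, 28, 28, 1), dtype=np.float32)
response = requests.post(
    "http://localhost:8501/v1/models/jax_model:predict",
    json={"inputs": images.tolist()})
predictions = np.array(response.json()["outputs"])
```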
# Reusing models with TensorFlow Hub and TensorFlow Keras
The SavedModel produced by the example in `saved_model_main.py` already
implements the [reusable saved models interface](https://www.tensorflow.org/hub/reusable_saved_models).
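For example, here is a sketch of consuming such a SavedModel as a Keras layer
via TensorFlow Hub. The path is illustrative (the default full-model location
from `saved_model_main.py`; the feature-extractor variant is saved under its
own directory):

```
import tensorflow as tf
import tensorflow_hub as hub

# Wrap the jax2tf SavedModel as a (non-trainable) Keras layer and embed it
# in a larger Keras model.
mnist_layer = hub.KerasLayer("/tmp/jax2tf/saved_models/1", trainable=False)

inputs = tf.keras.Input(shape=(28, 28, 1))
outputs = mnist_layer(inputs)
model = tf.keras.Model(inputs, outputs)
model.summary()
```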

jax/experimental/jax2tf/examples/saved_model_lib.py

@ -31,40 +31,64 @@ def save_model(jax_fn: Callable,
*,
input_signatures: Sequence[tf.TensorSpec],
shape_polymorphic_input_spec: Optional[str] = None,
with_gradient: bool = False,
compile_model: bool = True):
"""Saves the SavedModel for a function.
In order to use this wrapper you must first convert your model to a function
with two arguments: the parameters and the input on which you want to do
inference. Both arguments may be np.ndarray or
(nested) tuples/lists/dictionaries thereof.
If you want to save the model for a function with multiple parameters and
multiple inputs, you have to collect the parameters and the inputs each into
one argument, e.g., by using a tuple or dictionary at the top level.
```
def jax_fn_multi(param1, param2, input1, input2):
  # JAX model with multiple parameters and multiple inputs. They all can
  # be (nested) tuples/lists/dictionaries of np.ndarray.
  ...

def jax_fn_for_save_model(params, inputs):
  # JAX model with parameters and inputs collected in a tuple each. We could
  # also use dictionaries (in which case the keys would appear as the names
  # of the inputs).
  param1, param2 = params
  input1, input2 = inputs
  return jax_fn_multi(param1, param2, input1, input2)

save_model(jax_fn_for_save_model, (param1, param2), ...)
```
See examples in mnist_lib.py and saved_model_main.py.
Args:
jax_fn: a JAX function taking two arguments, the parameters and the inputs.
Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
params: the parameters, to be used as first argument for `jax_fn`. These
must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
saved as the variables of the SavedModel.
model_dir: the directory where the model should be saved.
input_signatures: the input signatures for the second argument of `jax_fn`
(the input). A signature must be a `tensorflow.TensorSpec` instance, or a
(nested) tuple/list/dictionary thereof with a structure matching the
second argument of `jax_fn`. The first input_signature will be saved as
the default serving signature. The additional signatures will be used
only to ensure that the `jax_fn` is traced and converted to TF for the
corresponding input shapes.
shape_polymorphic_input_spec: if given then it will be used as the
`in_shapes` argument to jax2tf.convert for the second parameter of
`jax_fn`. In this case, a single `input_signatures` is supported, and
should have `None` in the polymorphic dimensions. Should be a string, or a
(nested) tuple/list/dictionary thereof with a structure matching the
second argument of `jax_fn`.
with_gradient: whether the SavedModel should support gradients. If True,
then a custom gradient is saved. If False, then a
tf.raw_ops.PreventGradient is saved to error if a gradient is attempted.
(At the moment due to a bug in SavedModel, custom gradients are not
supported.)
compile_model: use TensorFlow `experimental_compile` on the SavedModel. This
is needed if the SavedModel will be used for TensorFlow serving.
"""
if not input_signatures:
raise ValueError("At least one input_signature must be given")
@ -77,49 +101,46 @@ def save_model(jax_fn: Callable,
with_gradient=with_gradient,
in_shapes=[None, shape_polymorphic_input_spec])
# Create tf.Variables for the parameters.
param_vars = tf.nest.map_structure(
# If with_gradient=False, we mark the variables behind as non-trainable,
# to ensure that users of the SavedModel will not try to fine tune them.
lambda param: tf.Variable(param, trainable=with_gradient),
params)
tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
autograph=False,
experimental_compile=compile_model)
signatures = {}
signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
tf_graph.get_concrete_function(input_signatures[0])
for input_signature in input_signatures[1:]:
# If there are more signatures, trace and cache a TF function for each one
tf_graph.get_concrete_function(input_signature)
wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
tf.saved_model.save(wrapper, model_dir, signatures=signatures)
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
"""Wraps a function and its parameters for saving to a SavedModel.
Implements the interface described at
https://www.tensorflow.org/hub/reusable_saved_models.
"""
def __init__(self, tf_graph, param_vars):
"""Args:
tf_graph: a tf.function taking one argument (the inputs), which can be
tuples/lists/dictionaries of np.ndarray or tensors.
param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
which will be saved as the variables of the SavedModel.
"""
super().__init__()
# Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
self.variables = tf.nest.flatten(param_vars)
self.trainable_variables = [v for v in self.variables if v.trainable]
# If you intend to prescribe regularization terms for users of the model,
# add them as @tf.functions with no inputs to this list. Else drop this.
self.regularization_losses = []
self.__call__ = tf_graph
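For reference, a sketch of calling `save_model` as the examples do. The model,
paths, and batch sizes below are illustrative stand-ins; the real, flag-driven
call is in `saved_model_main.py` (shown next):

```
import numpy as np
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental.jax2tf.examples import saved_model_lib

# Illustrative stand-ins for a real model: a (params, inputs) -> outputs
# function and its parameters, as described in the README.
params = (np.zeros((28 * 28, 10), np.float32), np.zeros((10,), np.float32))

def predict(params, inputs):
  w, b = params
  return jnp.dot(inputs.reshape((inputs.shape[0], -1)), w) + b

# Trace the model for three concrete batch sizes; the first signature
# becomes the default serving signature.
input_signatures = [
    tf.TensorSpec((batch_size, 28, 28, 1), tf.float32)
    for batch_size in (1, 16, 128)
]
saved_model_lib.save_model(
    predict,
    params,
    "/tmp/jax2tf/saved_models/1",
    input_signatures=input_signatures,
    with_gradient=False,
    compile_model=True)
```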

jax/experimental/jax2tf/examples/saved_model_main.py

@ -55,6 +55,10 @@ flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.")
flags.DEFINE_boolean(
"generate_model", True,
"Train and save a new model. Otherwise, use an existing SavedModel.")
flags.DEFINE_boolean(
"compile_model", True,
"Enable TensorFlow experimental_compiler for the SavedModel. This is "
"necessary if you want to use the model for TensorFlow serving.")
flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
flags.DEFINE_boolean(
"show_images", False,
@ -111,7 +115,8 @@ def train_and_save():
predict_params,
model_dir,
input_signatures=input_signatures,
shape_polymorphic_input_spec=shape_polymorphic_input_spec,
compile_model=FLAGS.compile_model)
if FLAGS.test_savedmodel:
tf_accelerator, tolerances = tf_accelerator_and_tolerances()
@ -134,9 +139,7 @@ def train_and_save():
pure_restored_model(tf.convert_to_tensor(test_input)),
predict_fn(predict_params, test_input), **tolerances)
assert os.path.isdir(model_dir)
if FLAGS.show_model:
def print_model(model_dir: str):
cmd = f"saved_model_cli show --all --dir {model_dir}"
print(cmd)
@ -165,14 +168,9 @@ def model_description() -> str:
def savedmodel_dir(with_version: bool = True) -> str:
"""The directory where we save the SavedModel."""
if FLAGS.model == "mnist_pure_jax":
model_class = mnist_lib.PureJaxMNIST
elif FLAGS.model == "mnist_flax":
model_class = mnist_lib.FlaxMNIST
model_dir = os.path.join(
FLAGS.model_path,
f"{model_class.name}{'' if FLAGS.model_classifier_layer else '_features'}"
f"{'mnist' if FLAGS.model_classifier_layer else 'mnist_features'}"
)
if with_version:
model_dir = os.path.join(model_dir, str(FLAGS.model_version))

jax/experimental/jax2tf/examples/servo_main.py (deleted)

@ -1,171 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demonstrates use of a jax2tf model in TensorFlow model server.
Includes the flags from saved_model_main.py.
If you want to start your own model server, you should pass the
`--nostart_model_server` flag and also `--serving_url` to point to the
HTTP REST API end point of your model server. You can see the path of the
trained and saved model in the output.
See README.md.
"""
import atexit
import logging
import subprocess
import threading
import time
from absl import app
from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib
from jax.experimental.jax2tf.examples import saved_model_main
import numpy as np
import requests
import tensorflow as tf # type: ignore
import tensorflow_datasets as tfds # type: ignore
flags.DEFINE_integer("count_images", 10, "How many images to test")
flags.DEFINE_bool("start_model_server", True,
"Whether to start/stop the model server.")
flags.DEFINE_string("serving_url", "http://localhost:8501/v1/models/jax_model",
"The HTTP endpoint for the model server")
FLAGS = flags.FLAGS
def mnist_predict_request(serving_url: str, images):
"""Predicts using the model server.
Args:
serving_url: The URL for the model server.
images: A batch of images of shape F32[B, 28, 28, 1]
Returns:
a batch of one-hot predictions, of shape F32[B, 10]
"""
request = {"inputs": images.tolist()}
response = requests.post(f"{serving_url}:predict", json=request)
response_json = response.json()
if response.status_code != 200:
raise RuntimeError("Model server error: " + response_json["error"])
predictions = np.array(response_json["outputs"])
return predictions
def main(_):
if FLAGS.count_images % FLAGS.serving_batch_size != 0:
raise ValueError("count_images must be a multiple of serving_batch_size")
saved_model_main.train_and_save()
# Strip the version number from the model directory
servo_model_dir = saved_model_main.savedmodel_dir(with_version=False)
if FLAGS.start_model_server:
model_server_proc = _start_localhost_model_server(servo_model_dir)
try:
_mnist_sanity_check(FLAGS.serving_url, FLAGS.serving_batch_size)
test_ds = mnist_lib.load_mnist(
tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)
images_and_labels = tfds.as_numpy(
test_ds.take(FLAGS.count_images // FLAGS.serving_batch_size))
for (images, labels) in images_and_labels:
predictions_one_hot = mnist_predict_request(FLAGS.serving_url, images)
predictions_digit = np.argmax(predictions_one_hot, axis=1)
label_digit = np.argmax(labels, axis=1)
logging.info(
f" predicted = {predictions_digit} labelled digit {label_digit}")
finally:
if FLAGS.start_model_server:
model_server_proc.kill()
model_server_proc.communicate()
def _mnist_sanity_check(serving_url: str, serving_batch_size: int):
"""Checks that we can reach a model server with a model that matches MNIST."""
logging.info("Checking that model server serves a compatible model.")
response = requests.get(f"{serving_url}/metadata")
response_json = response.json()
if response.status_code != 200:
raise IOError("Model server error: " + response_json["error"])
try:
signature_def = response_json["metadata"]["signature_def"]["signature_def"]
serving_default = signature_def[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
assert serving_default["method_name"] == "tensorflow/serving/predict"
inputs, = serving_default["inputs"].values()
b, w, h, c = [int(d["size"]) for d in inputs["tensor_shape"]["dim"]]
assert b == -1 or b == serving_batch_size, (
f"Found input batch size {b}. Expecting {serving_batch_size}")
assert (w, h, c) == mnist_lib.input_shape
except Exception as e:
raise IOError(
f"Unexpected response from model server: {response_json}") from e
def _start_localhost_model_server(model_dir):
"""Starts the model server on localhost, using docker.
Ignore this if you have a different way to start the model server.
"""
cmd = ("docker run -p 8501:8501 --mount "
f"type=bind,source={model_dir}/,target=/models/jax_model "
"-e MODEL_NAME=jax_model -t --rm --name=serving tensorflow/serving")
cmd_args = cmd.split(" ")
logging.info("Starting model server")
logging.info(f"Running {cmd}")
proc = subprocess.Popen(
cmd_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
model_server_ready = False
def model_server_output_reader():
for line in iter(proc.stdout.readline, b""):
line_str = line.decode("utf-8").strip()
if "Exporting HTTP/REST API at:localhost:8501" in line_str:
nonlocal model_server_ready
model_server_ready = True
logging.info(f"Model server: {line_str}")
output_thread = threading.Thread(target=model_server_output_reader, args=())
output_thread.start()
def _stop_model_server():
logging.info("Stopping the model server")
subprocess.run("docker container stop serving".split(" "),
check=True)
atexit.register(_stop_model_server)
wait_iteration_sec = 2
wait_remaining_sec = 10
while not model_server_ready and wait_remaining_sec > 0:
logging.info("Waiting for the model server to be ready...")
time.sleep(wait_iteration_sec)
wait_remaining_sec -= wait_iteration_sec
if wait_remaining_sec <= 0:
raise IOError("Model server failed to start properly")
return proc
if __name__ == "__main__":
app.run(main)