George Necula 235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
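
The change described above can be illustrated with a minimal sketch (not part of the commit; the function f and its operands are hypothetical). With enable_xla=True, a dot_general that requests a wider accumulator via preferred_element_type can be converted directly, because jax2tf now emits the XLA op instead of a combination of plain TF ops:

    import jax.numpy as jnp
    from jax import lax
    from jax.experimental import jax2tf

    def f(lhs, rhs):
      # Contract int8 operands, accumulating in int32.
      return lax.dot_general(lhs, rhs,
                             dimension_numbers=(((1,), (0,)), ((), ())),
                             preferred_element_type=jnp.int32)

    tf_f = jax2tf.convert(f, enable_xla=True)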

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from absl import app
from absl import flags

from jax.experimental import jax2tf
from jax.experimental.jax2tf.examples import mnist_lib

import numpy as np
import tensorflow as tf  # type: ignore[import]
import tensorflow_datasets as tfds  # type: ignore[import]

flags.DEFINE_string('tflite_file_path',
                    '/tmp/mnist.tflite',
                    'Path where to save the TensorFlow Lite file.')
flags.DEFINE_integer('serving_batch_size', 4,
                     'Batch size for which to prepare the serving signature.')
flags.DEFINE_integer('num_epochs', 10, 'For how many epochs to train.')

FLAGS = flags.FLAGS
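
# Example invocation (the script name and flag values below are illustrative):
#   python mnist.py --tflite_file_path=/tmp/mnist.tflite --num_epochs=2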

# A helper function to evaluate the TF Lite model using the "test" dataset.
def evaluate_tflite_model(tflite_model, test_ds):
  # Initialize the TFLite interpreter using the model.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_tensor_index = interpreter.get_input_details()[0]['index']
  output = interpreter.tensor(interpreter.get_output_details()[0]['index'])
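  # Note: interpreter.tensor() returns a callable rather than an array;
  # calling output() after each invoke() yields a fresh numpy view of the
  # current contents of the output buffer.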

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  labels = []
  for image, one_hot_label in test_ds:
    interpreter.set_tensor(input_tensor_index, image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: for each image in the batch, find the digit with the
    # highest probability.
    digits = np.argmax(output(), axis=1)
    prediction_digits.extend(digits)
    labels.extend(np.argmax(one_hot_label, axis=1))

  # Compare the predictions with the ground-truth labels to compute accuracy.
  accurate_count = 0
  for index in range(len(prediction_digits)):
    if prediction_digits[index] == labels[index]:
      accurate_count += 1
  accuracy = accurate_count * 1.0 / len(prediction_digits)

  return accuracy


def main(_):
  logging.info('Loading the MNIST TensorFlow dataset')
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)

  (flax_predict,
   flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds, FLAGS.num_epochs)

  def predict(image):
    return flax_predict(flax_params, image)

  # Convert Flax model to TF function.
  tf_predict = tf.function(
      jax2tf.convert(predict, enable_xla=False),
      input_signature=[
          tf.TensorSpec(
              shape=[FLAGS.serving_batch_size, 28, 28, 1],
              dtype=tf.float32,
              name='input')
      ],
      autograph=False)
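
  # Note: enable_xla=False above asks jax2tf to avoid XLA-specific TF ops, so
  # that the resulting graph contains only ops that converters such as the
  # TF Lite converter can handle.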

  # Convert TF function to TF Lite format.
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()])
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
  ]
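  # SELECT_TF_OPS lets any op without a TF Lite builtin fall back to the
  # regular TensorFlow kernels (the "Flex" delegate), at the cost of a larger
  # runtime.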
  tflite_float_model = converter.convert()

  # Show the model size in KB.
  float_model_size = len(tflite_float_model) / 1024
  print('Float model size = %dKB.' % float_model_size)

  # Re-convert the model to TF Lite, this time using quantization.
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_quantized_model = converter.convert()

  # Show the quantized model size in KB.
  quantized_model_size = len(tflite_quantized_model) / 1024
  print('Quantized model size = %dKB,' % quantized_model_size)
  print('which is about %d%% of the float model size.' %
        (quantized_model_size * 100 / float_model_size))
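
  # With tf.lite.Optimize.DEFAULT and no representative dataset, the converter
  # applies dynamic-range quantization: weights are stored as 8-bit integers
  # instead of 32-bit floats, which is why the quantized model is roughly a
  # quarter of the float model's size.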

  # Evaluate the TF Lite float model. You'll find that its accuracy is
  # identical to the original Flax model's, because they are essentially the
  # same model stored in different formats.
  float_accuracy = evaluate_tflite_model(tflite_float_model, test_ds)
  print('Float model accuracy = %.4f' % float_accuracy)

  # Evaluate the TF Lite quantized model.
  # Don't be surprised if the quantized model's accuracy turns out higher than
  # the float model's. It happens sometimes :)
  quantized_accuracy = evaluate_tflite_model(tflite_quantized_model, test_ds)
  print('Quantized model accuracy = %.4f' % quantized_accuracy)
  print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))

  # Save the quantized model to the requested file.
  with open(FLAGS.tflite_file_path, 'wb') as f:
    f.write(tflite_quantized_model)
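
  # A sketch of how the saved model could be loaded back for serving; the
  # model_path here simply mirrors the flag above:
  #   interpreter = tf.lite.Interpreter(model_path=FLAGS.tflite_file_path)
  #   interpreter.allocate_tensors()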


if __name__ == '__main__':
  app.run(main)