
Change flags to use the newer definition style, where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than being read as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API. Move a number of flags into the files that consume them; there is no reason to define every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`; changing those is left for a future PR.

PiperOrigin-RevId: 551604974
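For illustration, a minimal sketch of the two styles (the flag name `num_widgets` is hypothetical, not part of this change):

    from absl import flags

    # Old style: define the flag, then read it as an untyped attribute of the
    # global FLAGS object; the type checker cannot see that it is an int.
    flags.DEFINE_integer('num_widgets', 10, 'How many widgets to make.')
    n = flags.FLAGS.num_widgets  # read inside main(), after flags are parsed

    # New style: DEFINE_integer returns a typed FlagHolder, and reading
    # .value gives a value the type checker knows to be an int.
    _NUM_WIDGETS = flags.DEFINE_integer('num_widgets', 10,
                                        'How many widgets to make.')
    n = _NUM_WIDGETS.value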
142 lines
4.8 KiB
Python
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from absl import app
from absl import flags

from jax.experimental import jax2tf
from jax.experimental.jax2tf.examples import mnist_lib

import numpy as np

import tensorflow as tf  # type: ignore[import]
import tensorflow_datasets as tfds  # type: ignore[import]

_TFLITE_FILE_PATH = flags.DEFINE_string(
    'tflite_file_path',
    '/tmp/mnist.tflite',
    'Path where to save the TensorFlow Lite file.',
)
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
    'serving_batch_size',
    4,
    'For what batch size to prepare the serving signature.',
)
_NUM_EPOCHS = flags.DEFINE_integer(
    'num_epochs', 10, 'For how many epochs to train.'
)
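
# A hypothetical invocation overriding these flags on the command line (the
# script filename below is a placeholder, not taken from this file):
#
#   python this_script.py --num_epochs=2 --serving_batch_size=8 \
#       --tflite_file_path=/tmp/mnist_quantized.tflite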


# A helper function to evaluate the TF Lite model using the "test" dataset.
def evaluate_tflite_model(tflite_model, test_ds):
  # Initialize the TFLite interpreter using the model.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_tensor_index = interpreter.get_input_details()[0]['index']
  output = interpreter.tensor(interpreter.get_output_details()[0]['index'])

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  labels = []
  for image, one_hot_label in test_ds:
    interpreter.set_tensor(input_tensor_index, image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: for each batch element, find the digit with the
    # highest probability.
    digits = np.argmax(output(), axis=1)
    prediction_digits.extend(digits)
    labels.extend(np.argmax(one_hot_label, axis=1))

  # Compare prediction results with ground-truth labels to calculate accuracy.
  accurate_count = 0
  for index in range(len(prediction_digits)):
    if prediction_digits[index] == labels[index]:
      accurate_count += 1
  accuracy = accurate_count * 1.0 / len(prediction_digits)
  return accuracy


def main(_):
  logging.info('Loading the MNIST TensorFlow dataset')
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=_SERVING_BATCH_SIZE.value)

  (flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
      train_ds, test_ds, _NUM_EPOCHS.value
  )

  def predict(image):
    return flax_predict(flax_params, image)

  # Convert the Flax model to a TF function.
  tf_predict = tf.function(
      jax2tf.convert(predict, enable_xla=False),
      input_signature=[
          tf.TensorSpec(
              shape=[_SERVING_BATCH_SIZE.value, 28, 28, 1],
              dtype=tf.float32,
              name='input')
      ],
      autograph=False)

  # Convert the TF function to the TF Lite format.
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()], tf_predict)
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
  ]
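  # With enable_xla=False above, jax2tf emits plain TF ops, and not every such
  # op has a TFLite builtin equivalent; SELECT_TF_OPS lets the converter fall
  # back to running those ops through TensorFlow (the Flex delegate).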

  tflite_float_model = converter.convert()

  # Show the model size in KB.
  float_model_size = len(tflite_float_model) / 1024
  print('Float model size = %d KB.' % float_model_size)

  # Re-convert the model to TF Lite using quantization.
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_quantized_model = converter.convert()

  # Show the model size in KB.
  quantized_model_size = len(tflite_quantized_model) / 1024
  print('Quantized model size = %d KB,' % quantized_model_size)
  print('which is about %d%% of the float model size.' %
        (quantized_model_size * 100 / float_model_size))

  # Evaluate the TF Lite float model. You'll find that its accuracy is
  # identical to that of the original Flax model, because they are essentially
  # the same model stored in different formats.
  float_accuracy = evaluate_tflite_model(tflite_float_model, test_ds)
  print('Float model accuracy = %.4f' % float_accuracy)

  # Evaluate the TF Lite quantized model.
  # Don't be surprised if you see that the quantized model's accuracy is
  # higher than the original float model's. It happens sometimes :)
  quantized_accuracy = evaluate_tflite_model(tflite_quantized_model, test_ds)
  print('Quantized model accuracy = %.4f' % quantized_accuracy)
  print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))

  # Write the quantized model to the output file.
  with open(_TFLITE_FILE_PATH.value, 'wb') as f:
    f.write(tflite_quantized_model)


if __name__ == '__main__':
  app.run(main)