2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2020 The JAX Authors.
|
2020-10-31 08:09:08 +02:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""Demonstrates reuse of a jax2tf model in Keras.
|
|
|
|
|
|
|
|
Includes the flags from saved_model_main.py.
|
|
|
|
|
|
|
|
See README.md.
|
|
|
|
"""
|
|
|
|
import logging
|
|
|
|
from absl import app
|
|
|
|
from absl import flags
|
2021-04-15 15:48:00 +03:00
|
|
|
from jax.experimental.jax2tf.examples import mnist_lib # type: ignore
|
|
|
|
from jax.experimental.jax2tf.examples import saved_model_main # type: ignore
|
2020-10-31 08:09:08 +02:00
|
|
|
import tensorflow as tf # type: ignore
|
|
|
|
import tensorflow_datasets as tfds # type: ignore
|
|
|
|
import tensorflow_hub as hub # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
# Global absl flag registry. This script reads and overrides flags defined in
# saved_model_main (e.g. `model_classifier_layer`, `num_epochs`).
FLAGS = flags.FLAGS
|
|
|
|
|
|
|
|
|
|
|
|
def _build_and_compile_model(feature_model_dir):
  """Builds and compiles a Keras classifier on top of a reused SavedModel.

  Intended to be called inside a `tf.distribute` strategy scope, so that the
  model's variables are created under that strategy.

  Args:
    feature_model_dir: directory of the SavedModel containing the feature
      extractor saved by `saved_model_main.train_and_save()`.

  Returns:
    A compiled `tf.keras.Model` mapping input images to a 10-way softmax.
  """
  images = tf.keras.layers.Input(
      mnist_lib.input_shape, batch_size=mnist_lib.train_batch_size)
  # trainable=True lets Keras fine-tune the reused feature extractor together
  # with the newly added classifier head.
  keras_feature_extractor = hub.KerasLayer(feature_model_dir, trainable=True)
  features = keras_feature_extractor(images)
  predictor = tf.keras.layers.Dense(10, activation="softmax")
  predictions = predictor(features)
  keras_model = tf.keras.Model(images, predictions)

  keras_model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
      metrics=["accuracy"])
  return keras_model


def main(_):
  """Trains a jax2tf feature extractor, then reuses it inside a Keras model.

  First trains the JAX model via `saved_model_main` and saves only its
  feature-extractor part as a SavedModel. Then loads that SavedModel into a
  Keras model with a new softmax head, trains it on the MNIST train split
  with the test split as validation data, and optionally plots predictions.
  """
  FLAGS.model_classifier_layer = False  # We only need the features

  # Train the model and save the feature extractor
  saved_model_main.train_and_save()

  tf_accelerator, _ = saved_model_main.tf_accelerator_and_tolerances()
  feature_model_dir = saved_model_main.savedmodel_dir()

  # With Keras, we use the tf.distribute.OneDeviceStrategy as the high-level
  # analogue of explicit tf.device(...) placement.
  # It works on CPU, GPU and TPU.
  # Actual high-performance training would use the appropriately replicated
  # TF Distribution Strategy.
  strategy = tf.distribute.OneDeviceStrategy(tf_accelerator)
  with strategy.scope():
    keras_model = _build_and_compile_model(feature_model_dir)

    # Model.summary() prints the table and returns None, so passing its result
    # to logging.info would just log "None". Capture the lines and log them.
    summary_lines = []
    keras_model.summary(print_fn=lambda line, **_: summary_lines.append(line))
    logging.info("\n".join(summary_lines))

    train_ds = mnist_lib.load_mnist(
        tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
    test_ds = mnist_lib.load_mnist(
        tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
    keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)

    if saved_model_main.SHOW_IMAGES.value:
      mnist_lib.plot_images(
          test_ds,
          1,
          5,
          f"Keras inference with reuse of {saved_model_main.model_description()}",
          inference_fn=lambda batch: keras_model(tf.convert_to_tensor(batch)))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
  # Hand control to absl: it parses the command-line flags, then calls main.
  app.run(main)
|