mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Expanded the SavedModel library to allow the compilation of models.
This is important for use with TensorFlow serving. Also removed the servo_main.py (only applies to OSS TF Serving, which does not yet support XLA) Co-authored-by: Benjamin Chetioui <3920784+bchetioui@users.noreply.github.com>
This commit is contained in:
parent
cc8fe15e46
commit
4ee7a296c3
@ -1,14 +1,18 @@
|
||||
Getting started with jax2tf
|
||||
===========================
|
||||
|
||||
This directory contains a number of examples of using the jax2tf converter to:
|
||||
|
||||
* save SavedModel from trained MNIST models, using both pure JAX and Flax.
|
||||
* load the SavedModel into the TensorFlow model server and use it for
|
||||
inference. We show also how to use a batch-polymorphic saved model.
|
||||
* reuse the feature-extractor part of the trained MNIST model in
|
||||
TensorFlow Hub, and in a larger TensorFlow Keras model.
|
||||
|
||||
Preparing the model for jax2tf
|
||||
==============================
|
||||
It is also possible to use jax2tf-generated SavedModel with TensorFlow serving.
|
||||
At the moment, the open-source TensorFlow model server is missing XLA support,
|
||||
but the Google version can be used, as described in README_serving.md.
|
||||
|
||||
# Preparing the model for jax2tf
|
||||
|
||||
|
||||
The most important detail for using jax2tf with SavedModel is to express
|
||||
the trained model as a pair:
|
||||
@ -16,17 +20,18 @@ the trained model as a pair:
|
||||
* `func`: a two-argument function with signature:
|
||||
`(params: Parameters, inputs: Inputs) -> Outputs`.
|
||||
Both arguments must be a `numpy.ndarray` or
|
||||
nested tuple/list/dictionaries thereof.
|
||||
nested tuple/list/dictionaries thereof. In particular, it is important
|
||||
for models that take multiple inputs to be adapted to take their inputs
|
||||
as a single tuple/list/dictionary of `numpy.ndarray`.
|
||||
* `params`: of type `Parameters`, the model parameters, to be used as the
|
||||
first input for `func`.
|
||||
|
||||
You can see in [mnist_lib.py](mnist_lib.py) how this can be done for two
|
||||
You can see in [mnist_lib.py](mnist_lib.py) how this can be done for two
|
||||
implementations of MNIST, one using pure JAX (`PureJaxMNIST`) and a CNN
|
||||
one using Flax (`FlaxMNIST`). Other Flax models can be arranged similarly,
|
||||
and the same strategy should work for other neural-network libraries for JAX.
|
||||
|
||||
Generating TensorFlow SavedModel
|
||||
=====================
|
||||
# Generating TensorFlow SavedModel
|
||||
|
||||
Once you have the model in this form, you can use the
|
||||
`saved_model_lib.save_model` function [saved_model_lib.py](saved_model_lib.py)
|
||||
@ -36,19 +41,24 @@ into functions that behave as if they had been written with TensorFlow.
|
||||
Therefore, if you are familiar with how to generate SavedModel, you can most
|
||||
likely just use your own code for this.
|
||||
|
||||
The file `saved_model_main.py` is an executable that shows who to perform the following
|
||||
sequence of steps:
|
||||
The file `saved_model_main.py` is an executable that shows how to perform the
|
||||
following sequence of steps:
|
||||
|
||||
* train an MNIST model, and obtain a pair of an inference function and the parameters.
|
||||
* train an MNIST model, and obtain a pair of an inference function and the
|
||||
parameters.
|
||||
* convert the inference function with jax2tf, for one of more batch sizes.
|
||||
* save a SavedModel and dump its contents.
|
||||
* reload the SavedModel and run it with TensorFlow to test that the inference
|
||||
function produces the same results as the JAX inference function.
|
||||
* optionally plot images with the training digits and the inference results.
|
||||
|
||||
|
||||
|
||||
There are a number of flags to select the Flax model (`--model=mnist_flag`),
|
||||
to skip the training and just test a previously loaded
|
||||
SavedModel (`--nogenerate_model`), to choose the saving path, etc.
|
||||
The default saving location is `/tmp/jax2tf/saved_models/1`.
|
||||
|
||||
|
||||
By default, this example will convert the inference function for three separate
|
||||
batch sizes: 1, 16, 128. You can see this in the dumped SavedModel. If you
|
||||
@ -58,42 +68,8 @@ be done using jax2tf's
|
||||
As a result, the inference function will be traced only once and the SavedModel
|
||||
will contain a single batch-polymorphic TensorFlow graph.
|
||||
|
||||
Using the TensorFlow model server
|
||||
=================================
|
||||
|
||||
The executable `servo_main.py` extends the `saved_model_main.py` with code to
|
||||
show how to use the SavedModel with TensorFlow model server. All the flags
|
||||
of `saved_model_main.py` also apply to `servo_main.py`. In particular, you
|
||||
can select the Flax MNIST model: `--model=mnist_flax`,
|
||||
batch-polymorphic conversion to TensorFlow `--serving_batch_size=-1`,
|
||||
skip the training, reuse a previously trained model: `--nogenerate_model`.
|
||||
|
||||
If you want to start your own model server, you should pass the
|
||||
`--nostart_model_server` flag and also `--serving_url` to point to the
|
||||
HTTP REST API end point of your model server. You can see the path of the
|
||||
trained and saved model in the output.
|
||||
|
||||
Open-source model server
|
||||
------------------------
|
||||
|
||||
We have tried this example with the OSS model server, using a
|
||||
[TensorFlow Serving with Docker](https://www.tensorflow.org/tfx/serving/docker).
|
||||
Specifically, you need to install Docker and run
|
||||
|
||||
```
|
||||
docker pull tensorflow/serving
|
||||
```
|
||||
|
||||
The actual starting of the model server is done by the `servo_main.py` script.
|
||||
The script also does a sanity check that the model server is serving a model
|
||||
with the right batch size and image shapes, and makes a few requests to the
|
||||
server to perform inference.
|
||||
|
||||
Instructions for using the Google internal version of the model server: TBA.
|
||||
|
||||
|
||||
Reusing models with TensorFlow Hub and TensorFlow Keras
|
||||
=======================================================
|
||||
# Reusing models with TensorFlow Hub and TensorFlow Keras
|
||||
|
||||
The SavedModel produced by the example in `saved_model_main.py` already
|
||||
implements the [reusable saved models interface](https://www.tensorflow.org/hub/reusable_saved_models).
|
||||
|
@ -31,40 +31,64 @@ def save_model(jax_fn: Callable,
|
||||
*,
|
||||
input_signatures: Sequence[tf.TensorSpec],
|
||||
shape_polymorphic_input_spec: Optional[str] = None,
|
||||
with_gradient: bool = False):
|
||||
with_gradient: bool = False,
|
||||
compile_model: bool = True):
|
||||
"""Saves the SavedModel for a function.
|
||||
|
||||
In order to use this wrapper you must first convert your model to a function
|
||||
with two arguments: the parameters and the input on which you want to do
|
||||
inference. Both arguments may be tuples/lists/dictionaries of np.ndarray.
|
||||
inference. Both arguments may be np.ndarray or
|
||||
(nested) tuples/lists/dictionaries thereof.
|
||||
|
||||
If you want to save the model for a function with multiple parameters and
|
||||
multiple inputs, you have to collect the parameters and the inputs into
|
||||
one argument, e.g., adding a tuple or dictionary at top-level.
|
||||
|
||||
```
|
||||
def jax_fn_multi(param1, param2, input1, input2):
|
||||
# JAX model with multiple parameters and multiple inputs. They all can
|
||||
# be (nested) tuple/list/dictionaries of np.ndarray.
|
||||
...
|
||||
|
||||
def jax_fn_for_save_model(params, inputs):
|
||||
# JAX model with parameters and inputs collected in a tuple each. We can
|
||||
# use dictionaries also (in which case the keys would appear as the names
|
||||
# of the inputs)
|
||||
param1, param2 = params
|
||||
input1, input2 = inputs
|
||||
return jax_fn_multi(param1, param2, input1, input2)
|
||||
save_model(jax_fn_for_save_model, (param1, param2), ...)
|
||||
```
|
||||
|
||||
See examples in mnist_lib.py and saved_model.py.
|
||||
|
||||
Args:
|
||||
jax_fn: a JAX function taking two arguments, the parameters and the inputs.
|
||||
Both arguments may be tuples/lists/dictionaries of np.ndarray.
|
||||
Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
|
||||
params: the parameters, to be used as first argument for `jax_fn`. These
|
||||
must be tuples/lists/dictionaries of np.ndarray, and will be saved as the
|
||||
variables of the SavedModel.
|
||||
must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
|
||||
saved as the variables of the SavedModel.
|
||||
model_dir: the directory where the model should be saved.
|
||||
input_signatures: the input signatures for the second argument of `jax_fn`
|
||||
(the input). A signature must be a `tensorflow.TensorSpec` instance, or a
|
||||
tuple/list/dictionary thereof with a structure matching the second
|
||||
argument of `jax_fn`. The first input_signature will be saved as the
|
||||
default serving signature. The additional signatures will be used only to
|
||||
ensure that the `jax_fn` is traced and converted to TF for the
|
||||
(nested) tuple/list/dictionary thereof with a structure matching the
|
||||
second argument of `jax_fn`. The first input_signature will be saved as
|
||||
the default serving signature. The additional signatures will be used
|
||||
only to ensure that the `jax_fn` is traced and converted to TF for the
|
||||
corresponding input shapes.
|
||||
shape_polymorphic_input_spec: if given then it will be used as the
|
||||
`in_shapes` argument to jax2tf.convert for the second parameter of
|
||||
`jax_fn`. In this case, a single `input_signatures` is supported, and
|
||||
should have `None` in the polymorphic dimensions. Should be a string, or a
|
||||
tuple/list/dictionary thereof with a structure matching the second
|
||||
argument of `jax_fn`.
|
||||
(nesteD) tuple/list/dictionary thereof with a structure matching the
|
||||
second argument of `jax_fn`.
|
||||
with_gradient: whether the SavedModel should support gradients. If True,
|
||||
then a custom gradient is saved. If False, then a
|
||||
tf.raw_ops.PreventGradient is saved to error if a gradient is attempted.
|
||||
(At the moment due to a bug in SavedModel, custom gradients are not
|
||||
supported.)
|
||||
compile_model: use TensorFlow experimental_compiler on the SavedModel. This
|
||||
is needed if the SavedModel will be used for TensorFlow serving.
|
||||
"""
|
||||
if not input_signatures:
|
||||
raise ValueError("At least one input_signature must be given")
|
||||
@ -77,49 +101,46 @@ def save_model(jax_fn: Callable,
|
||||
with_gradient=with_gradient,
|
||||
in_shapes=[None, shape_polymorphic_input_spec])
|
||||
|
||||
wrapper = _ExportWrapper(tf_fn, params, params_trainable=with_gradient)
|
||||
# Create tf.Variables for the parameters.
|
||||
param_vars = tf.nest.map_structure(
|
||||
# If with_gradient=False, we mark the variables behind as non-trainable,
|
||||
# to ensure that users of the SavedModel will not try to fine tune them.
|
||||
lambda param: tf.Variable(param, trainable=with_gradient),
|
||||
params)
|
||||
tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
|
||||
autograph=False,
|
||||
experimental_compile=compile_model)
|
||||
|
||||
signatures = {}
|
||||
signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
|
||||
wrapper.__call__.get_concrete_function(input_signatures[0])
|
||||
tf_graph.get_concrete_function(input_signatures[0])
|
||||
for input_signature in input_signatures[1:]:
|
||||
# If there are more signatures, trace and cache a TF function for each one
|
||||
wrapper.__call__.get_concrete_function(input_signature)
|
||||
tf_graph.get_concrete_function(input_signature)
|
||||
wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
|
||||
tf.saved_model.save(wrapper, model_dir, signatures=signatures)
|
||||
|
||||
|
||||
class _ExportWrapper(tf.train.Checkpoint):
|
||||
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
|
||||
"""Wraps a function and its parameters for saving to a SavedModel.
|
||||
|
||||
Implements the interface described at
|
||||
https://www.tensorflow.org/hub/reusable_saved_models.
|
||||
"""
|
||||
|
||||
def __init__(self, tf_fn: Callable, params, params_trainable: bool):
|
||||
def __init__(self, tf_graph, param_vars):
|
||||
"""Args:
|
||||
|
||||
tf_fn: a TF function taking two arguments, the parameters and the inputs.
|
||||
Both arguments may be tuples/lists/dictionaries of np.ndarray of tensors.
|
||||
params: the parameters, to be used as first argument for `tf_fn`. These
|
||||
must be tuples/lists/dictionaries of np.ndarray, and will be saved as the
|
||||
variables of the SavedModel.
|
||||
tf_fn: a tf.function taking one argument (the inputs), which can be
|
||||
be tuples/lists/dictionaries of np.ndarray or tensors.
|
||||
params: the parameters, as tuples/lists/dictionaries of tf.Variable,
|
||||
and will be saved as the variables of the SavedModel.
|
||||
"""
|
||||
super().__init__()
|
||||
self._fn = tf_fn
|
||||
# Create tf.Variables for the parameters.
|
||||
self._params = tf.nest.map_structure(
|
||||
# If with_gradient=False, we mark the variables behind as non-trainable,
|
||||
# to ensure that users of the SavedModel will not try to fine tune them.
|
||||
lambda param: tf.Variable(param, trainable=params_trainable),
|
||||
params)
|
||||
|
||||
# Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
|
||||
self.variables = tf.nest.flatten(self._params)
|
||||
self.variables = tf.nest.flatten(param_vars)
|
||||
self.trainable_variables = [v for v in self.variables if v.trainable]
|
||||
# If you intend to prescribe regularization terms for users of the model,
|
||||
# add them as @tf.functions with no inputs to this list. Else drop this.
|
||||
self.regularization_losses = []
|
||||
|
||||
@tf.function(autograph=False)
|
||||
def __call__(self, inputs):
|
||||
outputs = self._fn(self._params, inputs)
|
||||
return outputs
|
||||
self.__call__ = tf_graph
|
||||
|
@ -55,6 +55,10 @@ flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.")
|
||||
flags.DEFINE_boolean(
|
||||
"generate_model", True,
|
||||
"Train and save a new model. Otherwise, use an existing SavedModel.")
|
||||
flags.DEFINE_boolean(
|
||||
"compile_model", True,
|
||||
"Enable TensorFlow experimental_compiler for the SavedModel. This is "
|
||||
"necessary if you want to use the model for TensorFlow serving.")
|
||||
flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
|
||||
flags.DEFINE_boolean(
|
||||
"show_images", False,
|
||||
@ -111,7 +115,8 @@ def train_and_save():
|
||||
predict_params,
|
||||
model_dir,
|
||||
input_signatures=input_signatures,
|
||||
shape_polymorphic_input_spec=shape_polymorphic_input_spec)
|
||||
shape_polymorphic_input_spec=shape_polymorphic_input_spec,
|
||||
compile_model=FLAGS.compile_model)
|
||||
|
||||
if FLAGS.test_savedmodel:
|
||||
tf_accelerator, tolerances = tf_accelerator_and_tolerances()
|
||||
@ -134,9 +139,7 @@ def train_and_save():
|
||||
pure_restored_model(tf.convert_to_tensor(test_input)),
|
||||
predict_fn(predict_params, test_input), **tolerances)
|
||||
|
||||
assert os.path.isdir(model_dir)
|
||||
if FLAGS.show_model:
|
||||
|
||||
def print_model(model_dir: str):
|
||||
cmd = f"saved_model_cli show --all --dir {model_dir}"
|
||||
print(cmd)
|
||||
@ -165,14 +168,9 @@ def model_description() -> str:
|
||||
|
||||
def savedmodel_dir(with_version: bool = True) -> str:
|
||||
"""The directory where we save the SavedModel."""
|
||||
if FLAGS.model == "mnist_pure_jax":
|
||||
model_class = mnist_lib.PureJaxMNIST
|
||||
elif FLAGS.model == "mnist_flax":
|
||||
model_class = mnist_lib.FlaxMNIST
|
||||
|
||||
model_dir = os.path.join(
|
||||
FLAGS.model_path,
|
||||
f"{model_class.name}{'' if FLAGS.model_classifier_layer else '_features'}"
|
||||
f"{'mnist' if FLAGS.model_classifier_layer else 'mnist_features'}"
|
||||
)
|
||||
if with_version:
|
||||
model_dir = os.path.join(model_dir, str(FLAGS.model_version))
|
||||
|
@ -1,171 +0,0 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Demonstrates use of a jax2tf model in TensorFlow model server.
|
||||
|
||||
Includes the flags from saved_model_main.py.
|
||||
|
||||
If you want to start your own model server, you should pass the
|
||||
`--nostart_model_server` flag and also `--serving_url` to point to the
|
||||
HTTP REST API end point of your model server. You can see the path of the
|
||||
trained and saved model in the output.
|
||||
|
||||
See README.md.
|
||||
"""
|
||||
import atexit
|
||||
import logging
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from jax.experimental.jax2tf.examples import mnist_lib
|
||||
from jax.experimental.jax2tf.examples import saved_model_main
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
import tensorflow as tf # type: ignore
|
||||
import tensorflow_datasets as tfds # type: ignore
|
||||
|
||||
|
||||
flags.DEFINE_integer("count_images", 10, "How many images to test")
|
||||
flags.DEFINE_bool("start_model_server", True,
|
||||
"Whether to start/stop the model server.")
|
||||
flags.DEFINE_string("serving_url", "http://localhost:8501/v1/models/jax_model",
|
||||
"The HTTP endpoint for the model server")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def mnist_predict_request(serving_url: str, images):
|
||||
"""Predicts using the model server.
|
||||
|
||||
Args:
|
||||
serving_url: The URL for the model server.
|
||||
images: A batch of images of shape F32[B, 28, 28, 1]
|
||||
|
||||
Returns:
|
||||
a batch of one-hot predictions, of shape F32[B, 10]
|
||||
"""
|
||||
request = {"inputs": images.tolist()}
|
||||
response = requests.post(f"{serving_url}:predict", json=request)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError("Model server error: " + response_json["error"])
|
||||
|
||||
predictions = np.array(response_json["outputs"])
|
||||
return predictions
|
||||
|
||||
|
||||
def main(_):
|
||||
if FLAGS.count_images % FLAGS.serving_batch_size != 0:
|
||||
raise ValueError("count_images must be a multiple of serving_batch_size")
|
||||
|
||||
saved_model_main.train_and_save()
|
||||
# Strip the version number from the model directory
|
||||
servo_model_dir = saved_model_main.savedmodel_dir(with_version=False)
|
||||
|
||||
if FLAGS.start_model_server:
|
||||
model_server_proc = _start_localhost_model_server(servo_model_dir)
|
||||
|
||||
try:
|
||||
_mnist_sanity_check(FLAGS.serving_url, FLAGS.serving_batch_size)
|
||||
test_ds = mnist_lib.load_mnist(
|
||||
tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)
|
||||
images_and_labels = tfds.as_numpy(
|
||||
test_ds.take(FLAGS.count_images // FLAGS.serving_batch_size))
|
||||
|
||||
for (images, labels) in images_and_labels:
|
||||
predictions_one_hot = mnist_predict_request(FLAGS.serving_url, images)
|
||||
predictions_digit = np.argmax(predictions_one_hot, axis=1)
|
||||
label_digit = np.argmax(labels, axis=1)
|
||||
logging.info(
|
||||
f" predicted = {predictions_digit} labelled digit {label_digit}")
|
||||
finally:
|
||||
if FLAGS.start_model_server:
|
||||
model_server_proc.kill()
|
||||
model_server_proc.communicate()
|
||||
|
||||
|
||||
def _mnist_sanity_check(serving_url: str, serving_batch_size: int):
|
||||
"""Checks that we can reach a model server with a model that matches MNIST."""
|
||||
logging.info("Checking that model server serves a compatible model.")
|
||||
response = requests.get(f"{serving_url}/metadata")
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
raise IOError("Model server error: " + response_json["error"])
|
||||
|
||||
try:
|
||||
signature_def = response_json["metadata"]["signature_def"]["signature_def"]
|
||||
serving_default = signature_def[
|
||||
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
assert serving_default["method_name"] == "tensorflow/serving/predict"
|
||||
inputs, = serving_default["inputs"].values()
|
||||
b, w, h, c = [int(d["size"]) for d in inputs["tensor_shape"]["dim"]]
|
||||
assert b == -1 or b == serving_batch_size, (
|
||||
f"Found input batch size {b}. Expecting {serving_batch_size}")
|
||||
assert (w, h, c) == mnist_lib.input_shape
|
||||
except Exception as e:
|
||||
raise IOError(
|
||||
f"Unexpected response from model server: {response_json}") from e
|
||||
|
||||
|
||||
def _start_localhost_model_server(model_dir):
|
||||
"""Starts the model server on localhost, using docker.
|
||||
|
||||
Ignore this if you have a different way to start the model server.
|
||||
"""
|
||||
cmd = ("docker run -p 8501:8501 --mount "
|
||||
f"type=bind,source={model_dir}/,target=/models/jax_model "
|
||||
"-e MODEL_NAME=jax_model -t --rm --name=serving tensorflow/serving")
|
||||
cmd_args = cmd.split(" ")
|
||||
logging.info("Starting model server")
|
||||
logging.info(f"Running {cmd}")
|
||||
proc = subprocess.Popen(
|
||||
cmd_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
model_server_ready = False
|
||||
|
||||
def model_server_output_reader():
|
||||
for line in iter(proc.stdout.readline, b""):
|
||||
line_str = line.decode("utf-8").strip()
|
||||
if "Exporting HTTP/REST API at:localhost:8501" in line_str:
|
||||
nonlocal model_server_ready
|
||||
model_server_ready = True
|
||||
logging.info(f"Model server: {line_str}")
|
||||
|
||||
output_thread = threading.Thread(target=model_server_output_reader, args=())
|
||||
output_thread.start()
|
||||
|
||||
def _stop_model_server():
|
||||
logging.info("Stopping the model server")
|
||||
subprocess.run("docker container stop serving".split(" "),
|
||||
check=True)
|
||||
|
||||
atexit.register(_stop_model_server)
|
||||
|
||||
wait_iteration_sec = 2
|
||||
wait_remaining_sec = 10
|
||||
while not model_server_ready and wait_remaining_sec > 0:
|
||||
logging.info("Waiting for the model server to be ready...")
|
||||
time.sleep(wait_iteration_sec)
|
||||
wait_remaining_sec -= wait_iteration_sec
|
||||
|
||||
if wait_remaining_sec <= 0:
|
||||
raise IOError("Model server failed to start properly")
|
||||
return proc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
Loading…
x
Reference in New Issue
Block a user