diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-12 12:16:56 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-12 12:16:56 +0200 |
commit | 0a886d74d4c9ffde41eef1f7e3c186b60b9f3726 (patch) | |
tree | e142b94341563b28a2b4a0e26fe77458749d2ed9 /config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py | |
parent | 8de8ff4f87295d812d4e660f0216953726200c92 (diff) |
Import Tensorflow models vis ONNX conversion
Diffstat (limited to 'config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py')
-rw-r--r-- | config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py | 93 |
1 files changed, 0 insertions, 93 deletions
diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py deleted file mode 100644 index 3f4f794d2ac..00000000000 --- a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""A very simple MNIST classifier. - -See extensive documentation at -https://www.tensorflow.org/get_started/mnist/beginners -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import sys - -from tensorflow.examples.tutorials.mnist import input_data - -import tensorflow as tf - -FLAGS = None - - -def main(_): - # Import data - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) - - # Create the model - x = tf.placeholder(tf.float32, [None, 784]) - - with tf.name_scope("layer"): - W = tf.Variable(tf.zeros([784, 10])) - b = tf.Variable(tf.zeros([10])) - y = tf.matmul(x, W) + b - - - # Define loss and optimizer - y_ = tf.placeholder(tf.float32, [None, 10]) - - # The raw formulation of cross-entropy, - # - # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), - # reduction_indices=[1])) - # - # can be numerically unstable. - # - # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw - # outputs of 'y', and then average across the batch. - cross_entropy = tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) - train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) - - sess = tf.InteractiveSession() - tf.global_variables_initializer().run() - # Train - for _ in range(1000): - batch_xs, batch_ys = mnist.train.next_batch(100) - sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) - - # Test trained model - correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - print(sess.run(accuracy, feed_dict={x: mnist.test.images, - y_: mnist.test.labels})) - - # Save the model - export_path = "saved" - print('Exporting trained model to ', export_path) - builder = tf.saved_model.builder.SavedModelBuilder(export_path) - signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y}) - builder.add_meta_graph_and_variables(sess, - [tf.saved_model.tag_constants.SERVING], - signature_def_map={'serving_default':signature}) - builder.save(as_text=True) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) |