summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py')
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py95
1 files changed, 95 insertions, 0 deletions
diff --git a/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py
new file mode 100644
index 00000000000..bc6ea13ebc1
--- /dev/null
+++ b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py
@@ -0,0 +1,95 @@
+# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import tensorflow as tf
+
+from functools import partial
+from tensorflow.examples.tutorials.mnist import input_data
+from datetime import datetime
+
+now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
+root_logdir = "tf_logs"
+logdir = "{}/run-{}/".format(root_logdir, now)
+
+mnist = input_data.read_data_sets("/tmp/data/")
+X_train = mnist.train.images
+X_test = mnist.test.images
+y_train = mnist.train.labels.astype("int")
+y_test = mnist.test.labels.astype("int")
+
+n_inputs = 28*28 # MNIST
+n_hidden1 = 300
+n_hidden2 = 100
+n_hidden3 = 40
+n_outputs = 10
+
+learning_rate = 0.01
+n_epochs = 20
+batch_size = 200
+batch_norm_momentum = 0.9
+
+X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
+y = tf.placeholder(tf.int64, shape=(None), name="y")
+training = tf.placeholder_with_default(False, shape=(), name='training')
+
+def leaky_relu(z, name=None):
+ return tf.maximum(0.01 * z, z, name=name)
+
+with tf.name_scope("dnn"):
+ he_init = tf.contrib.layers.variance_scaling_initializer()
+
+ batch_norm_layer = partial(tf.layers.batch_normalization, training=training, momentum=batch_norm_momentum)
+ dense_layer = partial(tf.layers.dense, kernel_initializer=he_init)
+
+ hidden1 = dense_layer(X, n_hidden1, name="hidden1", activation=leaky_relu)
+ bn1 = tf.nn.elu(batch_norm_layer(hidden1))
+ hidden2 = dense_layer(bn1, n_hidden2, name="hidden2", activation=tf.nn.elu)
+ bn2 = tf.nn.elu(batch_norm_layer(hidden2))
+ logits_before_bn = dense_layer(bn2, n_outputs, name="outputs", activation=tf.nn.selu)
+ logits = batch_norm_layer(logits_before_bn)
+
+with tf.name_scope("loss"):
+ xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
+ loss = tf.reduce_mean(xentropy, name="loss")
+
+with tf.name_scope("train"):
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+ training_op = optimizer.minimize(loss)
+
+with tf.name_scope("eval"):
+ correct = tf.nn.in_top_k(logits, y, 1)
+ accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
+
+init = tf.global_variables_initializer()
+accuracy_summary = tf.summary.scalar('Accuracy', accuracy)
+file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())
+extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+
+with tf.Session() as sess:
+ init.run()
+ for epoch in range(n_epochs):
+ for iteration in range(mnist.train.num_examples // batch_size):
+ X_batch, y_batch = mnist.train.next_batch(batch_size)
+ sess.run([training_op, extra_update_ops],
+ feed_dict={training: True, X: X_batch, y: y_batch})
+
+ accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,
+ y: mnist.test.labels})
+ print(epoch, "Test accuracy:", accuracy_val)
+
+ # Save summary for tensorboard
+ summary_str = accuracy_summary.eval(feed_dict={X: mnist.validation.images,
+ y: mnist.validation.labels})
+ file_writer.add_summary(summary_str, epoch)
+
+ export_path = "saved"
+ print('Exporting trained model to ', export_path)
+ builder = tf.saved_model.builder.SavedModelBuilder(export_path)
+ signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':X}, outputs = {'y':logits})
+ builder.add_meta_graph_and_variables(sess,
+ [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={'serving_default':signature})
+ builder.save(as_text=True)
+
+file_writer.close()
+
+