From ce34b6dd37afdce666e3b0b058c524ef9ebb5ef6 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 22 Jan 2018 13:25:32 +0100 Subject: Add batch normalization test case --- .../tensorflow/TensorflowImportTestCase.java | 91 ++-------------------- 1 file changed, 6 insertions(+), 85 deletions(-) (limited to 'searchlib/src/test/java/com/yahoo') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java index e22e4a36bab..cf4e64c74a1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java @@ -71,99 +71,20 @@ public class TensorflowImportTestCase { } @Test - public void test3LayerMnistImport() { - String modelDir = "src/test/files/integration/tensorflow/3_layer_mnist/saved"; + public void testBatchNormImport() { + String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved"; SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); TensorFlowModel result = new TensorFlowImporter().importModel(model); + TensorFlowModel.Signature signature = result.signature("serving_default"); - // Check constants - assertEquals(8, result.constants().size()); - - Tensor outputBias = result.constants().get("dnn/outputs/bias"); - assertNotNull(outputBias); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), outputBias.type()); - assertEquals(10, outputBias.size()); - - Tensor outputWeights = result.constants().get("dnn/outputs/weights"); - assertNotNull(outputWeights); - assertEquals(new TensorType.Builder().indexed("d0", 40).indexed("d1", 10).build(), outputWeights.type()); - assertEquals(400, outputWeights.size()); - - Tensor hidden3Bias = result.constants().get("dnn/hidden3/bias"); - assertNotNull(hidden3Bias); - assertEquals(new TensorType.Builder().indexed("d0", 40).build(), hidden3Bias.type()); - assertEquals(40, hidden3Bias.size()); - - Tensor hidden3Weights = result.constants().get("dnn/hidden3/weights"); - assertNotNull(hidden3Weights); - assertEquals(new TensorType.Builder().indexed("d0", 100).indexed("d1", 40).build(), hidden3Weights.type()); - assertEquals(4000, hidden3Weights.size()); - - Tensor hidden2Bias = result.constants().get("dnn/hidden2/bias"); - assertNotNull(hidden2Bias); - assertEquals(new TensorType.Builder().indexed("d0", 100).build(), hidden2Bias.type()); - assertEquals(100, hidden2Bias.size()); - - Tensor hidden2Weights = result.constants().get("dnn/hidden2/weights"); - assertNotNull(hidden2Weights); - assertEquals(new TensorType.Builder().indexed("d0", 300).indexed("d1", 100).build(), hidden2Weights.type()); - assertEquals(30000, hidden2Weights.size()); - - Tensor hidden1Bias = result.constants().get("dnn/hidden1/bias"); - assertNotNull(hidden1Bias); - assertEquals(new TensorType.Builder().indexed("d0", 300).build(), hidden1Bias.type()); - assertEquals(300, hidden1Bias.size()); - - Tensor hidden1Weights = result.constants().get("dnn/hidden1/weights"); - assertNotNull(hidden1Weights); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 300).build(), hidden1Weights.type()); - assertEquals(235200, hidden1Weights.size()); - - // Check signatures - assertEquals(1, result.signatures().size()); - TensorFlowModel.Signature signature = result.signatures().get("serving_default"); - assertNotNull(signature); + assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size()); - // ... signature inputs - assertEquals(1, signature.inputs().size()); - TensorType argument0 = signature.inputArgument("x"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - - // ... signature outputs - assertEquals(1, signature.outputs().size()); RankingExpression output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/outputs/add", output.getName()); - assertEquals("" + - "join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(X, rename(constant('dnn/hidden1/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden1/bias'), d0, d1), f(a,b)(a + b)), f(a)(if (a < 0, exp(a) - 1, a))), rename(constant('dnn/hidden2/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden2/bias'), d0, d1), f(a,b)(a + b)), f(a)(max(0,a))), rename(constant('dnn/hidden3/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden3/bias'), d0, d1), f(a,b)(a + b)), f(a)(1 / (1 + exp(-a)))), rename(constant('dnn/outputs/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/outputs/bias'), d0, d1), f(a,b)(a + b))", - toNonPrimitiveString(output)); - - // Test constants - assertEqualResult(model, result, "X", "dnn/hidden1/weights/read"); - assertEqualResult(model, result, "X", "dnn/hidden1/bias/read"); - assertEqualResult(model, result, "X", "dnn/hidden2/weights/read"); - assertEqualResult(model, result, "X", "dnn/hidden2/bias/read"); - assertEqualResult(model, result, "X", "dnn/hidden3/weights/read"); - assertEqualResult(model, result, "X", "dnn/hidden3/bias/read"); - assertEqualResult(model, result, "X", "dnn/outputs/weights/read"); - assertEqualResult(model, result, "X", "dnn/outputs/bias/read"); - - // Test execution - assertEqualResult(model, result, "X", "dnn/hidden1/MatMul"); - assertEqualResult(model, result, "X", "dnn/hidden1/add"); - assertEqualResult(model, result, "X", "dnn/hidden1/Elu"); - assertEqualResult(model, result, "X", "dnn/hidden2/MatMul"); - assertEqualResult(model, result, "X", "dnn/hidden2/add"); - assertEqualResult(model, result, "X", "dnn/hidden2/Relu"); - assertEqualResult(model, result, "X", "dnn/hidden3/MatMul"); - assertEqualResult(model, result, "X", "dnn/hidden3/add"); - assertEqualResult(model, result, "X", "dnn/hidden3/Sigmoid"); - assertEqualResult(model, result, "X", "dnn/outputs/MatMul"); - assertEqualResult(model, result, "X", "dnn/outputs/add"); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); + assertEqualResult(model, result, "X", output.getName()); } - private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(model, inputName, operationName); Context context = contextFrom(result); -- cgit v1.2.3