From 880149e6380a52edb089d59752a8fd4ea669e400 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 25 Jan 2018 13:22:47 +0100 Subject: Refactor: Move state to helper --- .../tensorflow/BatchNormImportTestCase.java | 12 ++++----- .../tensorflow/MnistSoftmaxImportTestCase.java | 31 +++++++--------------- .../tensorflow/TensorFlowImportTester.java | 18 ++++++++++--- 3 files changed, 29 insertions(+), 32 deletions(-) (limited to 'searchlib/src/test/java') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java index 770ba168f19..3d028b0775e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -15,18 +15,16 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { - TensorFlowImportTester tester = new TensorFlowImportTester(); - String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); - TensorFlowModel.Signature signature = result.signature("serving_default"); + TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/batch_norm/saved"); + TensorFlowModel.Signature signature = tester.result().signature("serving_default"); - assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size()); + assertEquals("Has skipped outputs", + 0, tester.result().signature("serving_default").skippedOutputs().size()); RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - tester.assertEqualResult(model, result, "X", output.getName()); + tester.assertEqualResult("X", output.getName()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 11063690e2a..044a6917e00 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -2,17 +2,9 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; - -import java.nio.FloatBuffer; -import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -24,29 +16,26 @@ public class MnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - TensorFlowImportTester tester = new TensorFlowImportTester(); - String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); + TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, result.constants().size()); + assertEquals(2, tester.result().constants().size()); - Tensor constant0 = result.constants().get("Variable"); + Tensor constant0 = tester.result().constants().get("Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = result.constants().get("Variable_1"); + Tensor constant1 = tester.result().constants().get("Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d0", 10).build(), constant1.type()); assertEquals(10, constant1.size()); // Check signatures - assertEquals(1, result.signatures().size()); - TensorFlowModel.Signature signature = result.signatures().get("serving_default"); + assertEquals(1, tester.result().signatures().size()); + TensorFlowModel.Signature signature = tester.result().signatures().get("serving_default"); assertNotNull(signature); // ... signature inputs @@ -64,10 +53,10 @@ public class MnistSoftmaxImportTestCase { output.getRoot().toString()); // Test execution - tester.assertEqualResult(model, result, "Placeholder", "Variable/read"); - tester.assertEqualResult(model, result, "Placeholder", "Variable_1/read"); - tester.assertEqualResult(model, result, "Placeholder", "MatMul"); - tester.assertEqualResult(model, result, "Placeholder", "add"); + tester.assertEqualResult("Placeholder", "Variable/read"); + tester.assertEqualResult("Placeholder", "Variable_1/read"); + tester.assertEqualResult("Placeholder", "MatMul"); + tester.assertEqualResult("Placeholder", "add"); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java index 5e5b474e445..c6623296d04 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java @@ -14,18 +14,28 @@ import java.util.List; import static org.junit.Assert.assertEquals; /** - * Helper for TensorFlow import tests. - * This currently assumes the TensorFlow model takes a single input named Placeholder, of type tensor(d0[1],d1[784]) + * Helper for TensorFlow import tests: Imports a model and provides asserts on it. + * This currently assumes the TensorFlow model takes a single input of type tensor(d0[1],d1[784]) * * @author bratseth */ public class TensorFlowImportTester { - // Sizes of the "Placeholder" vector + private SavedModelBundle model; + private TensorFlowModel result; + + // Sizes of the input vector private final int d0Size = 1; private final int d1Size = 784; - public void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) { + public TensorFlowImportTester(String modelDir) { + model = SavedModelBundle.load(modelDir, "serve"); + result = new TensorFlowImporter().importModel(model); + } + + public TensorFlowModel result() { return result; } + + public void assertEqualResult(String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(model, inputName, operationName); Context context = contextFrom(result); Tensor placeholder = placeholderArgument(); -- cgit v1.2.3