aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-25 13:22:47 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-25 13:22:47 +0100
commit880149e6380a52edb089d59752a8fd4ea669e400 (patch)
treeb2910aa28821f064eecb4efcf977afc7f2753a4f /searchlib/src/test/java/com
parenta0f6d44333202731d07139ba6f0256dd4443da78 (diff)
Refactor: Move state to helper
Diffstat (limited to 'searchlib/src/test/java/com')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java12
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java31
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java18
3 files changed, 29 insertions, 32 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
index 770ba168f19..3d028b0775e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -15,18 +15,16 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
- TensorFlowImportTester tester = new TensorFlowImportTester();
- String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved";
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- TensorFlowModel result = new TensorFlowImporter().importModel(model);
- TensorFlowModel.Signature signature = result.signature("serving_default");
+ TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/batch_norm/saved");
+ TensorFlowModel.Signature signature = tester.result().signature("serving_default");
- assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size());
+ assertEquals("Has skipped outputs",
+ 0, tester.result().signature("serving_default").skippedOutputs().size());
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
- tester.assertEqualResult(model, result, "X", output.getName());
+ tester.assertEqualResult("X", output.getName());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 11063690e2a..044a6917e00 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -2,17 +2,9 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-
-import java.nio.FloatBuffer;
-import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -24,29 +16,26 @@ public class MnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- TensorFlowImportTester tester = new TensorFlowImportTester();
- String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- TensorFlowModel result = new TensorFlowImporter().importModel(model);
+ TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/mnist_softmax/saved");
// Check constants
- assertEquals(2, result.constants().size());
+ assertEquals(2, tester.result().constants().size());
- Tensor constant0 = result.constants().get("Variable");
+ Tensor constant0 = tester.result().constants().get("Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = result.constants().get("Variable_1");
+ Tensor constant1 = tester.result().constants().get("Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
constant1.type());
assertEquals(10, constant1.size());
// Check signatures
- assertEquals(1, result.signatures().size());
- TensorFlowModel.Signature signature = result.signatures().get("serving_default");
+ assertEquals(1, tester.result().signatures().size());
+ TensorFlowModel.Signature signature = tester.result().signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
@@ -64,10 +53,10 @@ public class MnistSoftmaxImportTestCase {
output.getRoot().toString());
// Test execution
- tester.assertEqualResult(model, result, "Placeholder", "Variable/read");
- tester.assertEqualResult(model, result, "Placeholder", "Variable_1/read");
- tester.assertEqualResult(model, result, "Placeholder", "MatMul");
- tester.assertEqualResult(model, result, "Placeholder", "add");
+ tester.assertEqualResult("Placeholder", "Variable/read");
+ tester.assertEqualResult("Placeholder", "Variable_1/read");
+ tester.assertEqualResult("Placeholder", "MatMul");
+ tester.assertEqualResult("Placeholder", "add");
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java
index 5e5b474e445..c6623296d04 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java
@@ -14,18 +14,28 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
/**
- * Helper for TensorFlow import tests.
- * This currently assumes the TensorFlow model takes a single input named Placeholder, of type tensor(d0[1],d1[784])
+ * Helper for TensorFlow import tests: Imports a model and provides asserts on it.
+ * This currently assumes the TensorFlow model takes a single input of type tensor(d0[1],d1[784])
*
* @author bratseth
*/
public class TensorFlowImportTester {
- // Sizes of the "Placeholder" vector
+ private SavedModelBundle model;
+ private TensorFlowModel result;
+
+ // Sizes of the input vector
private final int d0Size = 1;
private final int d1Size = 784;
- public void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) {
+ public TensorFlowImportTester(String modelDir) {
+ model = SavedModelBundle.load(modelDir, "serve");
+ result = new TensorFlowImporter().importModel(model);
+ }
+
+ public TensorFlowModel result() { return result; }
+
+ public void assertEqualResult(String inputName, String operationName) {
Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
Context context = contextFrom(result);
Tensor placeholder = placeholderArgument();