summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-25 13:25:26 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-25 13:25:26 +0100
commit31b00d9cdbba6081c18dce9e2dae76c33e580557 (patch)
tree69e5a0cf56d321f99754b32e5252af9318a01484 /searchlib
parent880149e6380a52edb089d59752a8fd4ea669e400 (diff)
Refactor: Rename
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java20
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java)20
3 files changed, 24 insertions, 25 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
index 3d028b0775e..c6ee586a78c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
-import org.tensorflow.SavedModelBundle;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -15,16 +14,16 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
- TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/batch_norm/saved");
- TensorFlowModel.Signature signature = tester.result().signature("serving_default");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved");
+ TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
- 0, tester.result().signature("serving_default").skippedOutputs().size());
+ 0, model.get().signature("serving_default").skippedOutputs().size());
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
- tester.assertEqualResult("X", output.getName());
+ model.assertEqualResult("X", output.getName());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 044a6917e00..f12b9a2c628 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -16,26 +16,26 @@ public class MnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/mnist_softmax/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved");
// Check constants
- assertEquals(2, tester.result().constants().size());
+ assertEquals(2, model.get().constants().size());
- Tensor constant0 = tester.result().constants().get("Variable");
+ Tensor constant0 = model.get().constants().get("Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = tester.result().constants().get("Variable_1");
+ Tensor constant1 = model.get().constants().get("Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
constant1.type());
assertEquals(10, constant1.size());
// Check signatures
- assertEquals(1, tester.result().signatures().size());
- TensorFlowModel.Signature signature = tester.result().signatures().get("serving_default");
+ assertEquals(1, model.get().signatures().size());
+ TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
@@ -53,10 +53,10 @@ public class MnistSoftmaxImportTestCase {
output.getRoot().toString());
// Test execution
- tester.assertEqualResult("Placeholder", "Variable/read");
- tester.assertEqualResult("Placeholder", "Variable_1/read");
- tester.assertEqualResult("Placeholder", "MatMul");
- tester.assertEqualResult("Placeholder", "add");
+ model.assertEqualResult("Placeholder", "Variable/read");
+ model.assertEqualResult("Placeholder", "Variable_1/read");
+ model.assertEqualResult("Placeholder", "MatMul");
+ model.assertEqualResult("Placeholder", "add");
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index c6623296d04..186717d24cd 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -19,28 +19,28 @@ import static org.junit.Assert.assertEquals;
*
* @author bratseth
*/
-public class TensorFlowImportTester {
+public class TestableTensorFlowModel {
- private SavedModelBundle model;
- private TensorFlowModel result;
+ private SavedModelBundle tensorFlowModel;
+ private TensorFlowModel model;
// Sizes of the input vector
private final int d0Size = 1;
private final int d1Size = 784;
- public TensorFlowImportTester(String modelDir) {
- model = SavedModelBundle.load(modelDir, "serve");
- result = new TensorFlowImporter().importModel(model);
+ public TestableTensorFlowModel(String modelDir) {
+ tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
+ model = new TensorFlowImporter().importModel(tensorFlowModel);
}
- public TensorFlowModel result() { return result; }
+ public TensorFlowModel get() { return model; }
public void assertEqualResult(String inputName, String operationName) {
- Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
- Context context = contextFrom(result);
+ Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
+ Context context = contextFrom(model);
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
- Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
+ Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}