diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-25 13:25:26 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-25 13:25:26 +0100 |
commit | 31b00d9cdbba6081c18dce9e2dae76c33e580557 (patch) | |
tree | 69e5a0cf56d321f99754b32e5252af9318a01484 | |
parent | 880149e6380a52edb089d59752a8fd4ea669e400 (diff) |
Refactor: Rename
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java | 9 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java | 20 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java) | 20 |
3 files changed, 24 insertions, 25 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java index 3d028b0775e..c6ee586a78c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; -import org.tensorflow.SavedModelBundle; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -15,16 +14,16 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { - TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/batch_norm/saved"); - TensorFlowModel.Signature signature = tester.result().signature("serving_default"); + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", - 0, tester.result().signature("serving_default").skippedOutputs().size()); + 0, model.get().signature("serving_default").skippedOutputs().size()); RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - tester.assertEqualResult("X", output.getName()); + model.assertEqualResult("X", output.getName()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 044a6917e00..f12b9a2c628 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -16,26 +16,26 @@ public class MnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - TensorFlowImportTester tester = new TensorFlowImportTester("src/test/files/integration/tensorflow/mnist_softmax/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, tester.result().constants().size()); + assertEquals(2, model.get().constants().size()); - Tensor constant0 = tester.result().constants().get("Variable"); + Tensor constant0 = model.get().constants().get("Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = tester.result().constants().get("Variable_1"); + Tensor constant1 = model.get().constants().get("Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d0", 10).build(), constant1.type()); assertEquals(10, constant1.size()); // Check signatures - assertEquals(1, tester.result().signatures().size()); - TensorFlowModel.Signature signature = tester.result().signatures().get("serving_default"); + assertEquals(1, model.get().signatures().size()); + TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); // ... signature inputs @@ -53,10 +53,10 @@ public class MnistSoftmaxImportTestCase { output.getRoot().toString()); // Test execution - tester.assertEqualResult("Placeholder", "Variable/read"); - tester.assertEqualResult("Placeholder", "Variable_1/read"); - tester.assertEqualResult("Placeholder", "MatMul"); - tester.assertEqualResult("Placeholder", "add"); + model.assertEqualResult("Placeholder", "Variable/read"); + model.assertEqualResult("Placeholder", "Variable_1/read"); + model.assertEqualResult("Placeholder", "MatMul"); + model.assertEqualResult("Placeholder", "add"); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index c6623296d04..186717d24cd 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -19,28 +19,28 @@ import static org.junit.Assert.assertEquals; * * @author bratseth */ -public class TensorFlowImportTester { +public class TestableTensorFlowModel { - private SavedModelBundle model; - private TensorFlowModel result; + private SavedModelBundle tensorFlowModel; + private TensorFlowModel model; // Sizes of the input vector private final int d0Size = 1; private final int d1Size = 784; - public TensorFlowImportTester(String modelDir) { - model = SavedModelBundle.load(modelDir, "serve"); - result = new TensorFlowImporter().importModel(model); + public TestableTensorFlowModel(String modelDir) { + tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); + model = new TensorFlowImporter().importModel(tensorFlowModel); } - public TensorFlowModel result() { return result; } + public TensorFlowModel get() { return model; } public void assertEqualResult(String inputName, String operationName) { - Tensor tfResult = tensorFlowExecute(model, inputName, operationName); - Context context = contextFrom(result); + Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); + Context context = contextFrom(model); Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } |