summaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-25 13:01:18 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-25 13:01:18 +0100
commita0f6d44333202731d07139ba6f0256dd4443da78 (patch)
tree7e8aace9cae769ba9b1a4e0990c02d23117380ca /searchlib/src/test/java/com/yahoo
parent01f2897bce20939c5716fc19876c2541a3d9bbc5 (diff)
Refactor: Extract test helper logic
Diffstat (limited to 'searchlib/src/test/java/com/yahoo')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java32
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java73
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java61
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java125
4 files changed, 166 insertions, 125 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
new file mode 100644
index 00000000000..770ba168f19
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -0,0 +1,32 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author lesters
+ */
+public class BatchNormImportTestCase {
+
+ @Test
+ public void testBatchNormImport() {
+ TensorFlowImportTester tester = new TensorFlowImportTester();
+ String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved";
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ TensorFlowModel result = new TensorFlowImporter().importModel(model);
+ TensorFlowModel.Signature signature = result.signature("serving_default");
+
+ assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size());
+
+ RankingExpression output = signature.outputExpression("y");
+ assertNotNull(output);
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
+ tester.assertEqualResult(model, result, "X", output.getName());
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
new file mode 100644
index 00000000000..11063690e2a
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -0,0 +1,73 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class MnistSoftmaxImportTestCase {
+
+ @Test
+ public void testMnistSoftmaxImport() {
+ TensorFlowImportTester tester = new TensorFlowImportTester();
+ String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ TensorFlowModel result = new TensorFlowImporter().importModel(model);
+
+ // Check constants
+ assertEquals(2, result.constants().size());
+
+ Tensor constant0 = result.constants().get("Variable");
+ assertNotNull(constant0);
+ assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ constant0.type());
+ assertEquals(7840, constant0.size());
+
+ Tensor constant1 = result.constants().get("Variable_1");
+ assertNotNull(constant1);
+ assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ constant1.type());
+ assertEquals(10, constant1.size());
+
+ // Check signatures
+ assertEquals(1, result.signatures().size());
+ TensorFlowModel.Signature signature = result.signatures().get("serving_default");
+ assertNotNull(signature);
+
+ // ... signature inputs
+ assertEquals(1, signature.inputs().size());
+ TensorType argument0 = signature.inputArgument("x");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // ... signature outputs
+ assertEquals(1, signature.outputs().size());
+ RankingExpression output = signature.outputExpression("y");
+ assertNotNull(output);
+ assertEquals("add", output.getName());
+ assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
+ output.getRoot().toString());
+
+ // Test execution
+ tester.assertEqualResult(model, result, "Placeholder", "Variable/read");
+ tester.assertEqualResult(model, result, "Placeholder", "Variable_1/read");
+ tester.assertEqualResult(model, result, "Placeholder", "MatMul");
+ tester.assertEqualResult(model, result, "Placeholder", "add");
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java
new file mode 100644
index 00000000000..5e5b474e445
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImportTester.java
@@ -0,0 +1,61 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Helper for TensorFlow import tests.
+ * This currently assumes the TensorFlow model takes a single input named Placeholder, of type tensor(d0[1],d1[784])
+ *
+ * @author bratseth
+ */
+public class TensorFlowImportTester {
+
+ // Sizes of the "Placeholder" vector
+ private final int d0Size = 1;
+ private final int d1Size = 784;
+
+ public void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) {
+ Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
+ Context context = contextFrom(result);
+ Tensor placeholder = placeholderArgument();
+ context.put(inputName, new TensorValue(placeholder));
+ Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
+ assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
+ }
+
+ private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
+ Session.Runner runner = model.session().runner();
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size },
+ FloatBuffer.allocate(d0Size * d1Size));
+ runner.feed(inputName, placeholder);
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return new TensorConverter().toVespaTensor(results.get(0));
+ }
+
+ private Context contextFrom(TensorFlowModel result) {
+ MapContext context = new MapContext();
+ result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private Tensor placeholderArgument() {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; d1++)
+ b.cell(0, d0, d1);
+ return b.build();
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
deleted file mode 100644
index c01b92fb1c7..00000000000
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-
-import java.nio.FloatBuffer;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-
-/**
- * @author bratseth
- */
-public class TensorflowImportTestCase {
-
- @Test
- public void testMnistSoftmaxImport() {
- String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- TensorFlowModel result = new TensorFlowImporter().importModel(model);
-
- // Check constants
- assertEquals(2, result.constants().size());
-
- Tensor constant0 = result.constants().get("Variable");
- assertNotNull(constant0);
- assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
- constant0.type());
- assertEquals(7840, constant0.size());
-
- Tensor constant1 = result.constants().get("Variable_1");
- assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
- constant1.type());
- assertEquals(10, constant1.size());
-
- // Check signatures
- assertEquals(1, result.signatures().size());
- TensorFlowModel.Signature signature = result.signatures().get("serving_default");
- assertNotNull(signature);
-
- // ... signature inputs
- assertEquals(1, signature.inputs().size());
- TensorType argument0 = signature.inputArgument("x");
- assertNotNull(argument0);
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
-
- // ... signature outputs
- assertEquals(1, signature.outputs().size());
- RankingExpression output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("add", output.getName());
- assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
- output.getRoot().toString());
-
- // Test execution
- assertEqualResult(model, result, "Placeholder", "Variable/read");
- assertEqualResult(model, result, "Placeholder", "Variable_1/read");
- assertEqualResult(model, result, "Placeholder", "MatMul");
- assertEqualResult(model, result, "Placeholder", "add");
- }
-
- @Test
- public void testBatchNormImport() {
- String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved";
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- TensorFlowModel result = new TensorFlowImporter().importModel(model);
- TensorFlowModel.Signature signature = result.signature("serving_default");
-
- assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size());
-
- RankingExpression output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
- assertEqualResult(model, result, "X", output.getName());
-
- }
-
- private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) {
- Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
- Context context = contextFrom(result);
- Tensor placeholder = placeholderArgument();
- context.put(inputName, new TensorValue(placeholder));
- Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
- assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
- }
-
- // Sizes of the "Placeholder" vector
- private final int d0Size = 1;
- private final int d1Size = 784;
-
- private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
- Session.Runner runner = model.session().runner();
- org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size },
- FloatBuffer.allocate(d0Size * d1Size));
- runner.feed(inputName, placeholder);
- List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
- assertEquals(1, results.size());
- return new TensorConverter().toVespaTensor(results.get(0));
- }
-
- private Context contextFrom(TensorFlowModel result) {
- MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
- return context;
- }
-
- private Tensor placeholderArgument() {
- Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
- for (int d0 = 0; d0 < d0Size; d0++)
- for (int d1 = 0; d1 < d1Size; d1++)
- b.cell(0, d0, d1);
- return b.build();
- }
-
-}