summaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com
diff options
context:
space:
mode:
authorHarald Musum <musum@oath.com>2018-06-06 14:23:55 +0200
committerGitHub <noreply@github.com>2018-06-06 14:23:55 +0200
commit681963959794b47102d1a1cf72f215c72b0e2b51 (patch)
tree12ea280192a44f26b9718018c7cfb39b0c4c4735 /searchlib/src/test/java/com
parent240176d60c44507f4e6733c7512620e80554c8de (diff)
Revert "Refactor ONNX and TF import to use same code base"
Diffstat (limited to 'searchlib/src/test/java/com')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java)22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java)6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java)14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java)2
8 files changed, 35 insertions, 25 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
index a7926cd2e02..4b68cd40a08 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -1,9 +1,11 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -22,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
+ OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
@@ -46,7 +48,7 @@ public class OnnxMnistSoftmaxImportTestCase {
model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.defaultSignature().outputExpression("add");
+ RankingExpression output = model.outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
@@ -66,12 +68,13 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
- ImportedModel model = new TensorFlowImporter().importModel("test", path);
+ SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve");
+ TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- ImportedModel model = new OnnxImporter().importModel("test", path);
+ OnnxModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
@@ -80,7 +83,14 @@ public class OnnxMnistSoftmaxImportTestCase {
return expression.evaluate(context).asTensor();
}
- private Context contextFrom(ImportedModel result) {
+ private Context contextFrom(TensorFlowModel result) {
+ MapContext context = new MapContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private Context contextFrom(OnnxModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
index bf9684082f4..0f5eec93feb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -15,7 +15,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
- ImportedModel.Signature signature = model.get().signature("serving_default");
+ TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
index c8c7ec798bb..74b0d11f1d6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import org.junit.Test;
import static org.junit.Assert.assertTrue;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index a63c7346335..50a467ec581 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
@@ -24,7 +24,7 @@ public class DropoutImportTestCase {
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
model.get().requiredMacros().get("X"));
- ImportedModel.Signature signature = model.get().signature("serving_default");
+ TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index bd7644be23b..9f919c452d6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase {
// Check signatures
assertEquals(1, model.get().signatures().size());
- ImportedModel.Signature signature = model.get().signatures().get("serving_default");
+ TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
index b2443082ab1..beec2ab1ead 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 723c5f27914..7ca16939477 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -1,11 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals;
public class TestableTensorFlowModel {
private SavedModelBundle tensorFlowModel;
- private ImportedModel model;
+ private TensorFlowModel model;
// Sizes of the input vector
private final int d0Size = 1;
@@ -39,7 +39,7 @@ public class TestableTensorFlowModel {
model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
}
- public ImportedModel get() { return model; }
+ public TensorFlowModel get() { return model; }
public void assertEqualResult(String inputName, String operationName) {
Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
@@ -66,7 +66,7 @@ public class TestableTensorFlowModel {
return TensorConverter.toVespaTensor(results.get(0));
}
- private Context contextFrom(ImportedModel result) {
+ private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
@@ -81,7 +81,7 @@ public class TestableTensorFlowModel {
return b.build();
}
- private void evaluateMacro(Context context, ImportedModel model, String macroName) {
+ private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
if (!context.names().contains(macroName)) {
RankingExpression e = model.macros().get(macroName);
evaluateMacroDependencies(context, model, e.getRoot());
@@ -89,7 +89,7 @@ public class TestableTensorFlowModel {
}
}
- private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) {
if (node instanceof ReferenceNode) {
String name = node.toString();
if (model.macros().containsKey(name)) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
index f94098e6255..051c2c60c95 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
@@ -1,4 +1,4 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import org.junit.Test;