summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-12 19:33:11 +0200
committerLester Solbakken <lesters@oath.com>2020-06-12 19:33:11 +0200
commit3905dbf4455c4426f86f08ec925d7f66a06e85b8 (patch)
tree621bbb8832ec4a0c8e674c709bcc99aecdfbc528 /model-integration/src/test
parent599ad95a4e5003b903e464f91210892c1bee44ce (diff)
Revert "Import Tensorflow models vis ONNX conversion"
This reverts commit 0a886d74d4c9ffde41eef1f7e3c186b60b9f3726.
Diffstat (limited to 'model-integration/src/test')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java43
1 files changed, 43 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 09455abc380..35c853bd746 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -55,4 +55,47 @@ public class OnnxMnistSoftmaxImportTestCase {
assertEquals("{Placeholder=tensor<float>(d0[],d1[784])}", output.argumentTypes().toString());
}
+ @Test
+ public void testComparisonBetweenOnnxAndTensorflow() {
+ String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved";
+ String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx";
+
+ Tensor argument = placeholderArgument();
+ Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
+ Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add");
+
+ assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult);
+ }
+
+ private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
+ ImportedModel model = new TensorFlowImporter().importModel("test", path);
+ return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
+ }
+
+ private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
+ ImportedModel model = new OnnxImporter().importModel("test", path);
+ return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
+ }
+
+ private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) {
+ context.put(input, new TensorValue(argument));
+ return expression.evaluate(context).asTensor();
+ }
+
+ private Context contextFrom(ImportedModel result) {
+ MapContext context = new MapContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ return context;
+ }
+
+ private Tensor placeholderArgument() {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build());
+ for (int d0 = 0; d0 < 1; d0++)
+ for (int d1 = 0; d1 < 784; d1++)
+ b.cell(d1 * 1.0 / 784, d0, d1);
+ return b.build();
+ }
+
+
}