summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-12-12 09:42:54 -0800
committerLester Solbakken <lesters@oath.com>2019-12-12 09:42:54 -0800
commitf99ef6d4d400be906d26fbf59762bc27553ed32b (patch)
treead25e844c0237673b4549b7d30fd1e420aebb7d3 /model-integration/src/test/java/ai
parent14b0a54720077edf95d270741d207f9015a1c7aa (diff)
Initial conversion of TF to ONNX for testing
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java59
1 files changed, 59 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java
new file mode 100644
index 00000000000..dc53678483d
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java
@@ -0,0 +1,59 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.tensorflow;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class Tf2OnnxImportTestCase {
+
+ @Test
+ public void testConversionFromTensorFlowToOnnx() {
+ String modelPath = "src/test/models/tensorflow/mnist_softmax/saved";
+ String modelPathToConvert = "src/test/models/tensorflow/mnist_softmax/tf_2_onnx";
+
+ Tensor argument = placeholderArgument();
+ Tensor tensorFlowResult = evaluateTensorFlowModel(modelPath, argument, "Placeholder", "add");
+ Tensor tf2OnnxResult = evaluateTensorFlowModel(modelPathToConvert, argument, "Placeholder", "add");
+
+ assertEquals("Operation 'add' produces equal results", tensorFlowResult, tf2OnnxResult);
+ }
+
+ private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
+ ImportedModel model = new TensorFlowImporter().importModel("test", path);
+ String outputExpr = model.signatures().values().iterator().next().outputs().values().iterator().next();
+ return evaluateExpression(model.expressions().get(outputExpr), contextFrom(model), argument, input);
+ }
+
+ private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) {
+ context.put(input, new TensorValue(argument));
+ return expression.evaluate(context).asTensor();
+ }
+
+ private Context contextFrom(ImportedModel result) {
+ MapContext context = new MapContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ return context;
+ }
+
+ private Tensor placeholderArgument() {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build());
+ for (int d0 = 0; d0 < 1; d0++)
+ for (int d1 = 0; d1 < 784; d1++)
+ b.cell(d1 * 1.0 / 784, d0, d1);
+ return b.build();
+ }
+
+
+}