summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-05-07 14:55:45 +0200
committerLester Solbakken <lesters@oath.com>2019-05-07 14:55:45 +0200
commit35995299077105e10f08187c508d5ffafaef50b6 (patch)
treef9666671404895421bbd0b5ed38ecad6541acace /model-integration/src/test
parent50ccc80a1242a1013c0510e11ff3f8fc18f2ded0 (diff)
ONNX: use node output as name if not explicitly set
Diffstat (limited to 'model-integration/src/test')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java41
-rw-r--r--model-integration/src/test/models/onnx/simple/simple.onnx23
-rwxr-xr-xmodel-integration/src/test/models/onnx/simple/simple.py32
3 files changed, 96 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
new file mode 100644
index 00000000000..d1dea730da5
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -0,0 +1,41 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class SimpleImportTestCase {
+
+ @Test
+ public void testSimpleOnnxModelImport() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx");
+
+ MapContext context = new MapContext();
+ context.put("query_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[4])")).
+ cell(0.1, 0, 0).
+ cell(0.2, 0, 1).
+ cell(0.3, 0, 2).
+ cell(0.4, 0, 3).build()));
+ context.put("attribute_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[4],d1[1])")).
+ cell(0.1, 0, 0).
+ cell(0.2, 1, 0).
+ cell(0.3, 2, 0).
+ cell(0.4, 3, 0).build()));
+ context.put("bias_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[1])")).
+ cell(1.0, 0, 0).build()));
+
+ Tensor result = model.expressions().get("output").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}"));
+ }
+
+}
diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx
new file mode 100644
index 00000000000..1c746c90efa
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/simple.onnx
@@ -0,0 +1,23 @@
+ simple.py:ã
+0
+ query_tensor
+attribute_tensormatmul"MatMul
+"
+matmul
+ bias_tensoroutput"addsimple_scoringZ
+ query_tensor
+ 
+
+Z"
+attribute_tensor
+ 
+
+Z
+ bias_tensor
+ 
+
+b
+output
+ 
+
+B
diff --git a/model-integration/src/test/models/onnx/simple/simple.py b/model-integration/src/test/models/onnx/simple/simple.py
new file mode 100755
index 00000000000..4471ed812b8
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/simple.py
@@ -0,0 +1,32 @@
+import onnx
+from onnx import helper, TensorProto
+
+QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.FLOAT, [1, 4])
+ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.FLOAT, [4, 1])
+BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.FLOAT, [1, 1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 1])
+
+nodes = [
+ helper.make_node(
+ 'MatMul',
+ ['query_tensor', 'attribute_tensor'],
+ ['matmul'],
+ ),
+ helper.make_node(
+ 'add',
+ ['matmul', 'bias_tensor'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'simple_scoring',
+ [
+ QUERY_TENSOR,
+ ATTRIBUTE_TENSOR,
+ BIAS_TENSOR,
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='simple.py')
+onnx.save(model_def, 'simple.onnx')