aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java22
1 files changed, 22 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index 9631bddd93d..abecf4f5cb4 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -35,6 +35,15 @@ public class SimpleImportTestCase {
}
@Test
+ public void testConstant() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/const.onnx");
+
+ MapContext context = new MapContext();
+ Tensor result = model.expressions().get("output").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor():0.42"));
+ }
+
+ @Test
public void testGather() {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx");
@@ -48,6 +57,19 @@ public class SimpleImportTestCase {
assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"));
}
+ @Test
+ public void testConcat() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx");
+
+ MapContext context = new MapContext();
+ context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]")));
+ context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]")));
+ context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]")));
+
+ Tensor result = model.expressions().get("y").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]"));
+ }
+
private void evaluateFunction(Context context, ImportedModel model, String functionName) {
if (!context.names().contains(functionName)) {
RankingExpression e = RankingExpression.from(model.functions().get(functionName));