diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index 9631bddd93d..abecf4f5cb4 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -35,6 +35,15 @@ public class SimpleImportTestCase { } @Test + public void testConstant() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/const.onnx"); + + MapContext context = new MapContext(); + Tensor result = model.expressions().get("output").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor():0.42")); + } + + @Test public void testGather() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx"); @@ -48,6 +57,19 @@ public class SimpleImportTestCase { assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]")); } + @Test + public void testConcat() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx"); + + MapContext context = new MapContext(); + context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]"))); + context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]"))); + context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]"))); + + Tensor result = model.expressions().get("y").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]")); + } + private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { RankingExpression e = RankingExpression.from(model.functions().get(functionName)); |