summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java10
3 files changed, 10 insertions, 4 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 09455abc380..96bf2c64485 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -24,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx").asNative();
// Check constants
assertEquals(2, model.largeConstants().size());
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
index f03c629df78..cd13afca77b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java
@@ -14,7 +14,7 @@ public class PyTorchImportTestCase extends TestableModel {
@Test
public void testPyTorchExport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx").asNative();
Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult);
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index 04db902073b..4c7f6139032 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -14,6 +14,8 @@ import com.yahoo.tensor.Tensor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
/**
* @author lesters
@@ -24,6 +26,10 @@ public class SimpleImportTestCase {
public void testSimpleOnnxModelImport() {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx");
+ assertFalse(model.isNative());
+ model = model.asNative();
+ assertTrue(model.isNative());
+
MapContext context = new MapContext();
context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]")));
context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]")));
@@ -35,7 +41,7 @@ public class SimpleImportTestCase {
@Test
public void testGather() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx").asNative();
MapContext context = new MapContext();
context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]")));
@@ -49,7 +55,7 @@ public class SimpleImportTestCase {
@Test
public void testConcat() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx").asNative();
MapContext context = new MapContext();
context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]")));