diff options
Diffstat (limited to 'model-integration/src/test/java/ai')
3 files changed, 10 insertions, 4 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 09455abc380..96bf2c64485 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -24,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx").asNative(); // Check constants assertEquals(2, model.largeConstants().size()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java index f03c629df78..cd13afca77b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/PyTorchImportTestCase.java @@ -14,7 +14,7 @@ public class PyTorchImportTestCase extends TestableModel { @Test public void testPyTorchExport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/pytorch.onnx").asNative(); Tensor onnxResult = evaluateVespa(model, "output", model.inputs()); assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index 04db902073b..4c7f6139032 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -14,6 +14,8 @@ import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; /** * @author lesters @@ -24,6 +26,10 @@ public class SimpleImportTestCase { public void testSimpleOnnxModelImport() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx"); + assertFalse(model.isNative()); + model = model.asNative(); + assertTrue(model.isNative()); + MapContext context = new MapContext(); context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]"))); context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]"))); @@ -35,7 +41,7 @@ public class SimpleImportTestCase { @Test public void testGather() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx").asNative(); MapContext context = new MapContext(); context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"))); @@ -49,7 +55,7 @@ public class SimpleImportTestCase { @Test public void testConcat() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx").asNative(); MapContext context = new MapContext(); context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]"))); |