summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorHarald Musum <musum@yahooinc.com>2022-02-10 14:13:12 +0100
committerHarald Musum <musum@yahooinc.com>2022-02-10 14:13:12 +0100
commit4213e2ce69f07e9719958c280d331c5e28eaa568 (patch)
treeb347dd0b40dd08b82499cbd24a7dbe5c6f869973 /model-integration
parent9bfa86b49129ca89f59ef4f79428407a109db96c (diff)
Skip tests if onnxruntime unavailable
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java23
1 files changed, 21 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 4f8ea362467..bfb441443fc 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -10,15 +10,21 @@ import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
/**
* @author lesters
*/
public class OnnxEvaluatorTest {
+ private static final String simpleModelPath = "src/test/models/onnx/simple/simple.onnx";
+ // Check if onnxruntime is available, needs to be done only once due to static instance
+ // variable in OrtEnvironment
+ private static final boolean onnxRuntimeIsAvailable = onnxRuntimeIsAvailable();
@Test
- public void testSimpleMoodel() {
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx");
+ public void testSimpleModel() {
+ assumeTrue(onnxRuntimeIsAvailable);
+ OnnxEvaluator evaluator = new OnnxEvaluator(simpleModelPath);
// Input types
Map<String, TensorType> inputTypes = evaluator.getInputInfo();
@@ -42,6 +48,7 @@ public class OnnxEvaluatorTest {
@Test
public void testBatchDimension() {
+ assumeTrue(onnxRuntimeIsAvailable());
OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx");
// Input types
@@ -60,6 +67,7 @@ public class OnnxEvaluatorTest {
@Test
public void testMatMul() {
+ assumeTrue(onnxRuntimeIsAvailable());
String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
@@ -68,6 +76,7 @@ public class OnnxEvaluatorTest {
@Test
public void testTypes() {
+ assumeTrue(onnxRuntimeIsAvailable());
assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
@@ -90,4 +99,14 @@ public class OnnxEvaluatorTest {
assertEquals(expected.type().valueType(), result.type().valueType());
}
+ private static boolean onnxRuntimeIsAvailable() {
+ try {
+ new OnnxEvaluator(simpleModelPath);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ System.out.println("onnxruntime not available, test will be ignored");
+ return false;
+ }
+ }
+
}