From 4213e2ce69f07e9719958c280d331c5e28eaa568 Mon Sep 17 00:00:00 2001 From: Harald Musum Date: Thu, 10 Feb 2022 14:13:12 +0100 Subject: Skip tests if onnxruntime unavailable --- .../evaluator/OnnxEvaluatorTest.java | 23 ++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) (limited to 'model-integration/src/test/java/ai') diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 4f8ea362467..bfb441443fc 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -10,15 +10,21 @@ import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; /** * @author lesters */ public class OnnxEvaluatorTest { + private static final String simpleModelPath = "src/test/models/onnx/simple/simple.onnx"; + // Check if onnxruntime is available, needs to be done only once due to static instance + // variable in OrtEnvironment + private static final boolean onnxRuntimeIsAvailable = onnxRuntimeIsAvailable(); @Test - public void testSimpleMoodel() { - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx"); + public void testSimpleModel() { + assumeTrue(onnxRuntimeIsAvailable); + OnnxEvaluator evaluator = new OnnxEvaluator(simpleModelPath); // Input types Map inputTypes = evaluator.getInputInfo(); @@ -42,6 +48,7 @@ public class OnnxEvaluatorTest { @Test public void testBatchDimension() { + assumeTrue(onnxRuntimeIsAvailable()); OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx"); // Input types @@ -60,6 +67,7 @@ public class OnnxEvaluatorTest { @Test public void testMatMul() { + assumeTrue(onnxRuntimeIsAvailable()); String expected = "tensor(d0[2],d1[4]):[38,44,50,56,83,98,113,128]"; String input1 = "tensor(d0[2],d1[3]):[1,2,3,4,5,6]"; String input2 = "tensor(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]"; @@ -68,6 +76,7 @@ public class OnnxEvaluatorTest { @Test public void testTypes() { + assumeTrue(onnxRuntimeIsAvailable()); assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); assertEvaluate("add_float.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); assertEvaluate("add_int64.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); @@ -90,4 +99,14 @@ public class OnnxEvaluatorTest { assertEquals(expected.type().valueType(), result.type().valueType()); } + private static boolean onnxRuntimeIsAvailable() { + try { + new OnnxEvaluator(simpleModelPath); + return true; + } catch (UnsatisfiedLinkError e) { + System.out.println("onnxruntime not available, test will be ignored"); + return false; + } + } + } -- cgit v1.2.3