diff options
Diffstat (limited to 'model-integration/src/test')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java | 24 |
1 files changed, 5 insertions, 19 deletions
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index bfb441443fc..6266dcef174 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -16,15 +16,11 @@ import static org.junit.Assume.assumeTrue; * @author lesters */ public class OnnxEvaluatorTest { - private static final String simpleModelPath = "src/test/models/onnx/simple/simple.onnx"; - // Check if onnxruntime is available, needs to be done only once due to static instance - // variable in OrtEnvironment - private static final boolean onnxRuntimeIsAvailable = onnxRuntimeIsAvailable(); @Test public void testSimpleModel() { - assumeTrue(onnxRuntimeIsAvailable); - OnnxEvaluator evaluator = new OnnxEvaluator(simpleModelPath); + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx"); // Input types Map<String, TensorType> inputTypes = evaluator.getInputInfo(); @@ -48,7 +44,7 @@ public class OnnxEvaluatorTest { @Test public void testBatchDimension() { - assumeTrue(onnxRuntimeIsAvailable()); + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx"); // Input types @@ -67,7 +63,7 @@ public class OnnxEvaluatorTest { @Test public void testMatMul() { - assumeTrue(onnxRuntimeIsAvailable()); + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]"; String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]"; String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]"; @@ -76,7 +72,7 @@ public class OnnxEvaluatorTest { @Test public void testTypes() { - assumeTrue(onnxRuntimeIsAvailable()); + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); @@ -99,14 +95,4 @@ public class OnnxEvaluatorTest { assertEquals(expected.type().valueType(), result.type().valueType()); } - private static boolean onnxRuntimeIsAvailable() { - try { - new OnnxEvaluator(simpleModelPath); - return true; - } catch (UnsatisfiedLinkError e) { - System.out.println("onnxruntime not available, test will be ignored"); - return false; - } - } - } |