summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-03-30 15:56:09 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-03-30 15:56:12 +0200
commit5f641cbe5a558550b787945cea9ee4e20a3a659a (patch)
tree1a226bd2c744e885d50484ae3fac13d06de8d012
parent73f5b777ba374c3a0a92ca661ce8cbb35beb509f (diff)
Don't reuse runtime between methods
Caching evaluators between test methods may have unwanted side effects
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java39
1 files changed, 18 insertions, 21 deletions
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 5aba54de11b..b2d76baa566 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -5,30 +5,23 @@ package ai.vespa.modelintegration.evaluator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import org.junit.jupiter.api.BeforeAll;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
/**
* @author lesters
*/
public class OnnxEvaluatorTest {
- private static OnnxRuntime runtime;
-
- @BeforeAll
- public static void beforeAll() {
- if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime();
- }
-
@Test
public void testSimpleModel() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx");
// Input types
@@ -53,7 +46,8 @@ public class OnnxEvaluatorTest {
@Test
public void testBatchDimension() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx");
// Input types
@@ -72,21 +66,23 @@ public class OnnxEvaluatorTest {
@Test
public void testMatMul() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
- assertEvaluate("simple/matmul.onnx", expected, input1, input2);
+ assertEvaluate(runtime, "simple/matmul.onnx", expected, input1, input2);
}
@Test
public void testTypes() {
- assumeNotNull(runtime);
- assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
- assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
- assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
- assertEvaluate("cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
- assertEvaluate("cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ assertEvaluate(runtime, "add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
+ assertEvaluate(runtime, "add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
+ assertEvaluate(runtime, "add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
+ assertEvaluate(runtime, "cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
+ assertEvaluate(runtime, "cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
// ONNX Runtime 1.8.0 does not support much of bfloat16 yet
// assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
@@ -94,7 +90,8 @@ public class OnnxEvaluatorTest {
@Test
public void testNotIdentifiers() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx");
var inputInfo = evaluator.getInputInfo();
var outputInfo = evaluator.getOutputInfo();
@@ -159,7 +156,7 @@ public class OnnxEvaluatorTest {
assertEquals(3, allResults.size());
}
- private void assertEvaluate(String model, String output, String... input) {
+ private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... input) {
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
for (int i = 0; i < input.length; ++i) {