diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-03-30 15:56:09 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-03-30 15:56:12 +0200 |
commit | 5f641cbe5a558550b787945cea9ee4e20a3a659a (patch) | |
tree | 1a226bd2c744e885d50484ae3fac13d06de8d012 | |
parent | 73f5b777ba374c3a0a92ca661ce8cbb35beb509f (diff) |
Don't reuse runtime between methods
Caching evaluators between test methods may have unwanted side effects
-rw-r--r-- | model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java | 39 |
1 files changed, 18 insertions, 21 deletions
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 5aba54de11b..b2d76baa566 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -5,30 +5,23 @@ package ai.vespa.modelintegration.evaluator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import org.junit.jupiter.api.BeforeAll; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeNotNull; +import static org.junit.Assume.assumeTrue; /** * @author lesters */ public class OnnxEvaluatorTest { - private static OnnxRuntime runtime; - - @BeforeAll - public static void beforeAll() { - if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime(); - } - @Test public void testSimpleModel() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx"); // Input types @@ -53,7 +46,8 @@ public class OnnxEvaluatorTest { @Test public void testBatchDimension() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx"); // Input types @@ -72,21 +66,23 @@ public class OnnxEvaluatorTest { @Test public void testMatMul() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]"; String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]"; String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]"; - assertEvaluate("simple/matmul.onnx", expected, input1, input2); + assertEvaluate(runtime, "simple/matmul.onnx", expected, input1, input2); } @Test public void testTypes() { - assumeNotNull(runtime); - assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); - assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); - assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); - assertEvaluate("cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]"); - assertEvaluate("cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]"); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + assertEvaluate(runtime, "add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); + assertEvaluate(runtime, "add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); + assertEvaluate(runtime, "add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); + assertEvaluate(runtime, "cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]"); + assertEvaluate(runtime, "cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]"); // ONNX Runtime 1.8.0 does not support much of bfloat16 yet // assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]"); @@ -94,7 +90,8 @@ public class OnnxEvaluatorTest { @Test public void testNotIdentifiers() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx"); var inputInfo = evaluator.getInputInfo(); var outputInfo = evaluator.getOutputInfo(); @@ -159,7 +156,7 @@ public class OnnxEvaluatorTest { assertEquals(3, allResults.size()); } - private void assertEvaluate(String model, String output, String... input) { + private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... input) { OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model); Map<String, Tensor> inputs = new HashMap<>(); for (int i = 0; i < input.length; ++i) { |