diff options
8 files changed, 36 insertions, 6 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java index 5630d3cc186..f246b87d9bf 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java @@ -1,12 +1,14 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.ml; +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; /** * Tests the ModelsEvaluatorTester. @@ -17,6 +19,7 @@ public class ModelsEvaluatorTest { @Test public void testModelsEvaluatorTester() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval"); assertEquals(3, modelsEvaluator.models().size()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 7627ba6319b..c60817704cd 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; import ai.vespa.models.handler.ModelsEvaluationHandler; @@ -30,6 +31,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; /** * Tests stateless model evaluation (turned on by the "model-evaluation" tag in "container") @@ -61,6 +63,7 @@ public class ModelEvaluationTest { @Test public void testMl_serving() throws IOException { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); Path appDir = Path.fromString("src/test/cfg/application/ml_serving"); Path storedAppDir = appDir.append("copy"); try { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java index 6e096dd68e4..7372e871e5d 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; @@ -29,6 +30,7 @@ import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; /** * Tests stateless model evaluation (turned on by the "model-evaluation" tag in "container") @@ -39,7 +41,8 @@ import static org.junit.Assert.assertTrue; public class StatelessOnnxEvaluationTest { @Test - public void testStatelessOnnxModelNameCollision() throws IOException { + public void testStatelessOnnxModelNameCollision() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); Path appDir = Path.fromString("src/test/cfg/application/onnx_name_collision"); try { ImportedModelTester tester = new ImportedModelTester("onnx", appDir); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java index c5559d9bed5..27d1c08ea39 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java @@ -1,11 +1,10 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import com.yahoo.config.subscription.ConfigGetter; -import com.yahoo.config.subscription.FileSource; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; -import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; @@ -19,6 +18,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; /** * @author lesters @@ -30,6 +30,7 @@ public class OnnxEvaluatorTest { @Test public void testOnnxEvaluation() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); ModelsEvaluator models = createModels(); assertTrue(models.models().containsKey("add_mul")); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index bb442d76763..215e230b45d 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.handler; +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.models.evaluation.ModelsEvaluator; import ai.vespa.models.evaluation.RankProfilesConfigImporterWithMockedConstants; import com.yahoo.config.subscription.ConfigGetter; @@ -18,6 +19,8 @@ import org.junit.Test; import java.util.HashMap; import java.util.Map; +import static org.junit.Assume.assumeTrue; + public class ModelsEvaluationHandlerTest { private static final String MODELS_DIR = "src/test/resources/config/models/"; @@ -244,6 +247,7 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedEvaluateSpecificFunction() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); Map<String, String> properties = new HashMap<>(); properties.put("input", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval"; diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index ccd303990c8..f065435ec15 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -1,12 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.handler; +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.config.subscription.ConfigGetter; -import com.yahoo.config.subscription.FileSource; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; -import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; @@ -19,6 +18,8 @@ import java.io.File; import java.util.HashMap; import java.util.Map; +import static org.junit.Assume.assumeTrue; + public class OnnxEvaluationHandlerTest { private static HandlerTester handler; @@ -26,6 +27,7 @@ public class OnnxEvaluationHandlerTest { @BeforeClass static public void setUp() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); handler = new HandlerTester(createModels()); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 87b964a2c56..c9ab9924214 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -93,4 +93,13 @@ public class OnnxEvaluator { } } + public static boolean isRuntimeAvailable() { + try { + new OnnxEvaluator(""); + return true; + } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { + return false; + } + } + } diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 4f8ea362467..6266dcef174 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -10,6 +10,7 @@ import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; /** * @author lesters @@ -17,7 +18,8 @@ import static org.junit.Assert.assertEquals; public class OnnxEvaluatorTest { @Test - public void testSimpleMoodel() { + public void testSimpleModel() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx"); // Input types @@ -42,6 +44,7 @@ public class OnnxEvaluatorTest { @Test public void testBatchDimension() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx"); // Input types @@ -60,6 +63,7 @@ public class OnnxEvaluatorTest { @Test public void testMatMul() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]"; String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]"; String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]"; @@ -68,6 +72,7 @@ public class OnnxEvaluatorTest { @Test public void testTypes() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); |