diff options
author | Lester Solbakken <lesters@oath.com> | 2022-04-01 09:49:51 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-04-01 09:49:51 +0200 |
commit | 2157f9adc8d1a9d553eb9d7fbf202530faf8af43 (patch) | |
tree | 7114f26db5741519b7bbac0033ad9ba8d208fd98 /model-integration/src | |
parent | 4cc5a56f7006c40b059855f345ded365ace8550c (diff) |
Skip test if onnx is not available (arm64)
Diffstat (limited to 'model-integration/src')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 6 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java | 12 |
2 files changed, 15 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index c9ab9924214..bdcceddb04f 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -94,8 +94,12 @@ public class OnnxEvaluator { } public static boolean isRuntimeAvailable() { + return isRuntimeAvailable(""); + } + + public static boolean isRuntimeAvailable(String modelPath) { try { - new OnnxEvaluator(""); + new OnnxEvaluator(modelPath); return true; } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { return false; diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index 519f24795ca..0ecc78f7668 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -1,5 +1,6 @@ package ai.vespa.embedding; +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import com.yahoo.config.UrlReference; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -8,14 +9,21 @@ import org.junit.Test; import java.util.List; import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; public class BertBaseEmbedderTest { + + @Test public void testEmbedder() { + String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; + String modelPath = "src/test/models/onnx/transformer/dummy_transformer.onnx"; + assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); - builder.tokenizerVocabUrl(new UrlReference("src/test/models/onnx/transformer/dummy_vocab.txt")); - builder.transformerModelUrl(new UrlReference("src/test/models/onnx/transformer/dummy_transformer.onnx")); + builder.tokenizerVocabUrl(new UrlReference(vocabPath)); + builder.transformerModelUrl(new UrlReference(modelPath)); BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); |