aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2022-04-01 09:49:51 +0200
committerLester Solbakken <lesters@oath.com>2022-04-01 09:49:51 +0200
commit2157f9adc8d1a9d553eb9d7fbf202530faf8af43 (patch)
tree7114f26db5741519b7bbac0033ad9ba8d208fd98 /model-integration/src
parent4cc5a56f7006c40b059855f345ded365ace8550c (diff)
Skip test if onnx is not available (arm64)
Diffstat (limited to 'model-integration/src')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java6
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java12
2 files changed, 15 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index c9ab9924214..bdcceddb04f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -94,8 +94,12 @@ public class OnnxEvaluator {
}
public static boolean isRuntimeAvailable() {
+ return isRuntimeAvailable("");
+ }
+
+ public static boolean isRuntimeAvailable(String modelPath) {
try {
- new OnnxEvaluator("");
+ new OnnxEvaluator(modelPath);
return true;
} catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) {
return false;
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index 519f24795ca..0ecc78f7668 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,5 +1,6 @@
package ai.vespa.embedding;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import com.yahoo.config.UrlReference;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -8,14 +9,21 @@ import org.junit.Test;
import java.util.List;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
public class BertBaseEmbedderTest {
+
+
@Test
public void testEmbedder() {
+ String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
+ String modelPath = "src/test/models/onnx/transformer/dummy_transformer.onnx";
+ assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
- builder.tokenizerVocabUrl(new UrlReference("src/test/models/onnx/transformer/dummy_vocab.txt"));
- builder.transformerModelUrl(new UrlReference("src/test/models/onnx/transformer/dummy_transformer.onnx"));
+ builder.tokenizerVocabUrl(new UrlReference(vocabPath));
+ builder.transformerModelUrl(new UrlReference(modelPath));
BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");