summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2022-05-23 10:55:21 +0200
committerLester Solbakken <lesters@oath.com>2022-05-23 10:55:21 +0200
commit7593c064d3ecf3649cd27f5b9c820b5510f225ee (patch)
treee8e0cf3e574dfddc4e97c670ad53e6104c8f675a /model-integration
parente657c0a9618868c9dcf32cfa7e05ac73750b904c (diff)
Add services.xml syntax for embedders
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java17
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java9
2 files changed, 19 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 1831903d626..bc3f08ce3d6 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -11,6 +11,8 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.io.File;
+import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -58,13 +60,22 @@ public class BertBaseEmbedder implements Embedder {
options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads()));
options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads()));
- // Todo: use either file or url
- tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocabUrl().getAbsolutePath()).build();
- evaluator = new OnnxEvaluator(config.transformerModelUrl().getAbsolutePath(), options);
+ String tokenizerFile = pathOrUrl(config.tokenizerVocabPath(), config.tokenizerVocabUrl());
+ String modelFile = pathOrUrl(config.transformerModelPath(), config.transformerModelUrl());
+
+ tokenizer = new WordPieceEmbedder.Builder(tokenizerFile).build();
+ evaluator = new OnnxEvaluator(modelFile, options);
validateModel();
}
+ private String pathOrUrl(Path path, File url) {
+ if (path.endsWith("services.xml")) {
+ return url.getAbsolutePath();
+ }
+ return path.toAbsolutePath().toString();
+ }
+
private void validateModel() {
Map<String, TensorType> inputs = evaluator.getInputInfo();
validateName(inputs, inputIdsName, "input");
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index 464e5941e89..c224b87982d 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,6 +1,7 @@
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.config.FileReference;
import com.yahoo.config.UrlReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
@@ -14,8 +15,6 @@ import static org.junit.Assume.assumeTrue;
public class BertBaseEmbedderTest {
-
-
@Test
public void testEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
@@ -23,8 +22,10 @@ public class BertBaseEmbedderTest {
assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
- builder.tokenizerVocabUrl(new UrlReference(vocabPath));
- builder.transformerModelUrl(new UrlReference(modelPath));
+ builder.tokenizerVocabPath(new FileReference(vocabPath));
+ builder.tokenizerVocabUrl(new UrlReference(""));
+ builder.transformerModelPath(new FileReference(modelPath));
+ builder.transformerModelUrl(new UrlReference(""));
BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");