diff options
author | Lester Solbakken <lesters@oath.com> | 2022-05-23 10:55:21 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-05-23 10:55:21 +0200 |
commit | 7593c064d3ecf3649cd27f5b9c820b5510f225ee (patch) | |
tree | e8e0cf3e574dfddc4e97c670ad53e6104c8f675a /model-integration | |
parent | e657c0a9618868c9dcf32cfa7e05ac73750b904c (diff) |
Add services.xml syntax for embedders
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java | 17 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java | 9 |
2 files changed, 19 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index 1831903d626..bc3f08ce3d6 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -11,6 +11,8 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -58,13 +60,22 @@ public class BertBaseEmbedder implements Embedder { options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads())); options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads())); - // Todo: use either file or url - tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocabUrl().getAbsolutePath()).build(); - evaluator = new OnnxEvaluator(config.transformerModelUrl().getAbsolutePath(), options); + String tokenizerFile = pathOrUrl(config.tokenizerVocabPath(), config.tokenizerVocabUrl()); + String modelFile = pathOrUrl(config.transformerModelPath(), config.transformerModelUrl()); + + tokenizer = new WordPieceEmbedder.Builder(tokenizerFile).build(); + evaluator = new OnnxEvaluator(modelFile, options); validateModel(); } + private String pathOrUrl(Path path, File url) { + if (path.endsWith("services.xml")) { + return url.getAbsolutePath(); + } + return path.toAbsolutePath().toString(); + } + private void validateModel() { Map<String, TensorType> inputs = evaluator.getInputInfo(); validateName(inputs, inputIdsName, "input"); diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index 464e5941e89..c224b87982d 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -1,6 +1,7 @@ package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import com.yahoo.config.FileReference; import com.yahoo.config.UrlReference; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.tensor.Tensor; @@ -14,8 +15,6 @@ import static org.junit.Assume.assumeTrue; public class BertBaseEmbedderTest { - - @Test public void testEmbedder() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; @@ -23,8 +22,10 @@ public class BertBaseEmbedderTest { assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); - builder.tokenizerVocabUrl(new UrlReference(vocabPath)); - builder.transformerModelUrl(new UrlReference(modelPath)); + builder.tokenizerVocabPath(new FileReference(vocabPath)); + builder.tokenizerVocabUrl(new UrlReference("")); + builder.transformerModelPath(new FileReference(modelPath)); + builder.transformerModelUrl(new UrlReference("")); BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); |