diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-08-31 22:50:14 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-08-31 22:50:14 +0200 |
commit | adcb1d4d55e71d78c662f798b033d3abea0d4b9e (patch) | |
tree | 5867c3ac85792c1578d6ce463e8e24dd2aea7fb0 /model-integration | |
parent | 2b83da619a3ee2f38a1a3b05576f44d7451b3daf (diff) |
Add 'model' config type
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java | 17 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java | 7 |
2 files changed, 5 insertions, 19 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index bc3f08ce3d6..3dd8d7eefc4 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -11,13 +11,10 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import java.io.File; -import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Map; - /** * A BERT Base compatible embedder. This embedder uses a WordPiece embedder to * produce a token sequence that is input to a transformer model. A BERT base @@ -60,22 +57,12 @@ public class BertBaseEmbedder implements Embedder { options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads())); options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads())); - String tokenizerFile = pathOrUrl(config.tokenizerVocabPath(), config.tokenizerVocabUrl()); - String modelFile = pathOrUrl(config.transformerModelPath(), config.transformerModelUrl()); - - tokenizer = new WordPieceEmbedder.Builder(tokenizerFile).build(); - evaluator = new OnnxEvaluator(modelFile, options); + tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().value()).build(); + evaluator = new OnnxEvaluator(config.transformerModel().value(), options); validateModel(); } - private String pathOrUrl(Path path, File url) { - if (path.endsWith("services.xml")) { - return url.getAbsolutePath(); - } - return path.toAbsolutePath().toString(); - } - private void validateModel() { Map<String, TensorType> inputs = evaluator.getInputInfo(); validateName(inputs, inputIdsName, "input"); diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index c224b87982d..82a78115e63 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -2,6 +2,7 @@ package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import com.yahoo.config.FileReference; +import com.yahoo.config.ModelReference; import com.yahoo.config.UrlReference; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.tensor.Tensor; @@ -22,10 +23,8 @@ public class BertBaseEmbedderTest { assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); - builder.tokenizerVocabPath(new FileReference(vocabPath)); - builder.tokenizerVocabUrl(new UrlReference("")); - builder.transformerModelPath(new FileReference(modelPath)); - builder.transformerModelUrl(new UrlReference("")); + builder.tokenizerVocab(ModelReference.fromPath(vocabPath)); + builder.transformerModel(ModelReference.fromPath(modelPath)); BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); |