aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-08-31 22:50:14 +0200
committerJon Bratseth <bratseth@gmail.com>2022-08-31 22:50:14 +0200
commitadcb1d4d55e71d78c662f798b033d3abea0d4b9e (patch)
tree5867c3ac85792c1578d6ce463e8e24dd2aea7fb0 /model-integration
parent2b83da619a3ee2f38a1a3b05576f44d7451b3daf (diff)
Add 'model' config type
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java17
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java7
2 files changed, 5 insertions, 19 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index bc3f08ce3d6..3dd8d7eefc4 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -11,13 +11,10 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import java.io.File;
-import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-
/**
* A BERT Base compatible embedder. This embedder uses a WordPiece embedder to
* produce a token sequence that is input to a transformer model. A BERT base
@@ -60,22 +57,12 @@ public class BertBaseEmbedder implements Embedder {
options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads()));
options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads()));
- String tokenizerFile = pathOrUrl(config.tokenizerVocabPath(), config.tokenizerVocabUrl());
- String modelFile = pathOrUrl(config.transformerModelPath(), config.transformerModelUrl());
-
- tokenizer = new WordPieceEmbedder.Builder(tokenizerFile).build();
- evaluator = new OnnxEvaluator(modelFile, options);
+ tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().value()).build();
+ evaluator = new OnnxEvaluator(config.transformerModel().value(), options);
validateModel();
}
- private String pathOrUrl(Path path, File url) {
- if (path.endsWith("services.xml")) {
- return url.getAbsolutePath();
- }
- return path.toAbsolutePath().toString();
- }
-
private void validateModel() {
Map<String, TensorType> inputs = evaluator.getInputInfo();
validateName(inputs, inputIdsName, "input");
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index c224b87982d..82a78115e63 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -2,6 +2,7 @@ package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import com.yahoo.config.FileReference;
+import com.yahoo.config.ModelReference;
import com.yahoo.config.UrlReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
@@ -22,10 +23,8 @@ public class BertBaseEmbedderTest {
assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
- builder.tokenizerVocabPath(new FileReference(vocabPath));
- builder.tokenizerVocabUrl(new UrlReference(""));
- builder.transformerModelPath(new FileReference(modelPath));
- builder.transformerModelUrl(new UrlReference(""));
+ builder.tokenizerVocab(ModelReference.fromPath(vocabPath));
+ builder.transformerModel(ModelReference.fromPath(modelPath));
BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");