summary | refs | log | tree | commit | diff | stats
path: root/model-integration
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-26 15:44:33 +0200
committer Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-26 15:44:33 +0200
commit da7eaaf9389b6b2d1bfd1bcff3770d61b150fd1f (patch)
tree 9562eb4764e15ea9452f8f0911bfe3737c5666b7 /model-integration
parent 17f53418cfaaf49b87f06123af0aaf73f3593fe8 (diff)
Make truncation and max length configurable
Diffstat (limited to 'model-integration')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java 15
1 files changed, 3 insertions, 12 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 0c1cc80544e..d08bc8a3e8c 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -29,14 +29,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final String attentionMaskName;
private final String tokenTypeIdsName;
private final String outputName;
- private final int maxTokens;
private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
- maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -45,6 +43,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenizer = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
.addDefaultModel(Paths.get(config.tokenizerPath().toString()))
+ .setTruncation(true)
+ .setMaxLength(config.transformerMaxTokens())
.build();
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
@@ -74,16 +74,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
@Override
public List<Integer> embed(String s, Context context) {
- var tokenIds = tokenizer.embed(s, context);
-
- int tokensSize = tokenIds.size();
-
- if (tokensSize > maxTokens) {
- Integer lastElement = tokenIds.get(tokensSize - 1);
- tokenIds = tokenIds.subList(0, maxTokens - 1);
- tokenIds.add(lastElement);
- }
- return tokenIds;
+ return tokenizer.embed(s, context);
}
@Override