diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-26 15:44:33 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-26 15:44:33 +0200 |
commit | da7eaaf9389b6b2d1bfd1bcff3770d61b150fd1f (patch) | |
tree | 9562eb4764e15ea9452f8f0911bfe3737c5666b7 /model-integration | |
parent | 17f53418cfaaf49b87f06123af0aaf73f3593fe8 (diff) |
Make truncation and max length configurable
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 15 |
1 file changed, 3 insertions, 12 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 0c1cc80544e..d08bc8a3e8c 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -29,14 +29,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
     private final String attentionMaskName;
     private final String tokenTypeIdsName;
     private final String outputName;
-    private final int maxTokens;
     private final boolean normalize;
     private final HuggingFaceTokenizer tokenizer;
     private final OnnxEvaluator evaluator;
 
     @Inject
     public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
-        maxTokens = config.transformerMaxTokens();
         inputIdsName = config.transformerInputIds();
         attentionMaskName = config.transformerAttentionMask();
         tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -45,6 +43,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
         tokenizer = new HuggingFaceTokenizer.Builder()
                 .addSpecialTokens(true)
                 .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
+                .setTruncation(true)
+                .setMaxLength(config.transformerMaxTokens())
                 .build();
         var onnxOpts = new OnnxEvaluatorOptions();
         if (config.transformerGpuDevice() >= 0)
@@ -74,16 +74,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
 
     @Override
     public List<Integer> embed(String s, Context context) {
-        var tokenIds = tokenizer.embed(s, context);
-
-        int tokensSize = tokenIds.size();
-
-        if (tokensSize > maxTokens) {
-            Integer lastElement = tokenIds.get(tokensSize - 1);
-            tokenIds = tokenIds.subList(0, maxTokens - 1);
-            tokenIds.add(lastElement);
-        }
-        return tokenIds;
+        return tokenizer.embed(s, context);
     }
 
     @Override