diff options
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 15 |
1 file changed, 3 insertions, 12 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 0c1cc80544e..d08bc8a3e8c 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -29,14 +29,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
     private final String attentionMaskName;
     private final String tokenTypeIdsName;
     private final String outputName;
-    private final int maxTokens;
     private final boolean normalize;
     private final HuggingFaceTokenizer tokenizer;
     private final OnnxEvaluator evaluator;
 
     @Inject
     public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
-        maxTokens = config.transformerMaxTokens();
         inputIdsName = config.transformerInputIds();
         attentionMaskName = config.transformerAttentionMask();
         tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -45,6 +43,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
         tokenizer = new HuggingFaceTokenizer.Builder()
                 .addSpecialTokens(true)
                 .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
+                .setTruncation(true)
+                .setMaxLength(config.transformerMaxTokens())
                 .build();
         var onnxOpts = new OnnxEvaluatorOptions();
         if (config.transformerGpuDevice() >= 0)
@@ -74,16 +74,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
 
     @Override
     public List<Integer> embed(String s, Context context) {
-        var tokenIds = tokenizer.embed(s, context);
-
-        int tokensSize = tokenIds.size();
-
-        if (tokensSize > maxTokens) {
-            Integer lastElement = tokenIds.get(tokensSize - 1);
-            tokenIds = tokenIds.subList(0, maxTokens - 1);
-            tokenIds.add(lastElement);
-        }
-        return tokenIds;
+        return tokenizer.embed(s, context);
     }
 
     @Override |