summary | refs | log | tree | commit | diff | stats
path: root/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java 25
1 files changed, 19 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 17b63fb1c7d..b035541bb0f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
+import java.util.logging.Logger;
+
+import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
@Beta
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
+ private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
+
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
- tokenizer = new HuggingFaceTokenizer.Builder()
+ var tokenizerPath = Paths.get(config.tokenizerPath().toString());
+ var builder = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
- .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
- .setTruncation(true)
- .setPadding(false)
- .setMaxLength(config.transformerMaxTokens())
- .build();
+ .addDefaultModel(tokenizerPath)
+ .setPadding(false);
+ var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
+ log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info));
+ if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
+ // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration
+ int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
+ ? info.maxLength()
+ : config.transformerMaxTokens();
+ builder.setTruncation(true).setMaxLength(maxLength);
+ }
+ this.tokenizer = builder.build();
poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)