diff options
Diffstat (limited to 'model-integration/src')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 17b63fb1c7d..b035541bb0f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType;
 import java.nio.file.Paths;
 import java.util.List;
 import java.util.Map;
+import java.util.logging.Logger;
+
+import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
 
 @Beta
 public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
 
+    private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
+
     private final String inputIdsName;
     private final String attentionMaskName;
     private final String tokenTypeIdsName;
@@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
         tokenTypeIdsName = config.transformerTokenTypeIds();
         outputName = config.transformerOutput();
         normalize = config.normalize();
-        tokenizer = new HuggingFaceTokenizer.Builder()
+        var tokenizerPath = Paths.get(config.tokenizerPath().toString());
+        var builder = new HuggingFaceTokenizer.Builder()
                 .addSpecialTokens(true)
-                .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
-                .setTruncation(true)
-                .setPadding(false)
-                .setMaxLength(config.transformerMaxTokens())
-                .build();
+                .addDefaultModel(tokenizerPath)
+                .setPadding(false);
+        var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
+        log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info));
+        if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
+            // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration
+            int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
+                    ? info.maxLength()
+                    : config.transformerMaxTokens();
+            builder.setTruncation(true).setMaxLength(maxLength);
+        }
+        this.tokenizer = builder.build();
         poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
         var onnxOpts = new OnnxEvaluatorOptions();
         if (config.transformerGpuDevice() >= 0)