diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-12 16:41:37 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-12 16:51:26 +0200 |
commit | 4f722322cc9f8df5146ffb27d74239b3b4f2d634 (patch) | |
tree | dad0f0a70513a861844d10a35ba93c1901b48057 /model-integration | |
parent | 838f918baf2f64b5cb737a59e624f20773d95baa (diff) |
Prefer truncation configuration from tokenizer model
Only override truncation if not specified or max length exceeds max tokens accepted by model.
Use JNI wrapper directly to determine existing truncation configuration (JSON format is not really documented).
Simply configuration for pure tokenizer embedder.
Disable DJL usage telemetry.
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 17b63fb1c7d..b035541bb0f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; import java.util.Map; +import java.util.logging.Logger; + +import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; @Beta public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { + private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName()); + private final String inputIdsName; private final String attentionMaskName; private final String tokenTypeIdsName; @@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); normalize = config.normalize(); - tokenizer = new HuggingFaceTokenizer.Builder() + var tokenizerPath = Paths.get(config.tokenizerPath().toString()); + var builder = new HuggingFaceTokenizer.Builder() .addSpecialTokens(true) - .addDefaultModel(Paths.get(config.tokenizerPath().toString())) - .setTruncation(true) - .setPadding(false) - .setMaxLength(config.transformerMaxTokens()) - .build(); + .addDefaultModel(tokenizerPath) + .setPadding(false); + var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); + log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info)); + if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { + // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration + int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens() + ? info.maxLength() + : config.transformerMaxTokens(); + builder.setTruncation(true).setMaxLength(maxLength); + } + this.tokenizer = builder.build(); poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) |