author     Bjørn Christian Seime <bjorncs@yahooinc.com>    2023-06-12 16:41:37 +0200
committer  Bjørn Christian Seime <bjorncs@yahooinc.com>    2023-06-12 16:51:26 +0200
commit     4f722322cc9f8df5146ffb27d74239b3b4f2d634 (patch)
tree       dad0f0a70513a861844d10a35ba93c1901b48057 /model-integration
parent     838f918baf2f64b5cb737a59e624f20773d95baa (diff)
Prefer truncation configuration from tokenizer model
Only override truncation if it is not specified or its max length exceeds the max number of tokens accepted by the model. Use the JNI wrapper directly to determine the existing truncation configuration (the JSON format is not really documented). Simplify configuration for the pure tokenizer embedder. Disable DJL usage telemetry.
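
In short, the tokenizer model's own truncation settings now take precedence; the embedder only forces truncation when tokenizer.json specifies none (or an unusable one), capping the length at the max token count the transformer model accepts. A minimal standalone sketch of that decision follows; the class, method and enum names here (TruncationPolicy, resolveTruncation, the simplified TruncationStrategy) are illustrative stand-ins, not Vespa or DJL API — only the logic mirrors the diff below.

    import java.util.OptionalInt;

    // Illustrative sketch of the decision introduced by this patch.
    // TruncationPolicy/resolveTruncation are hypothetical names.
    final class TruncationPolicy {

        // Simplified stand-in for com.yahoo.language.huggingface.ModelInfo.TruncationStrategy
        enum TruncationStrategy { LONGEST_FIRST, DO_NOT_TRUNCATE }

        /**
         * Returns the max length the embedder should force truncation to,
         * or empty when the tokenizer model's own configuration should be kept as-is.
         */
        static OptionalInt resolveTruncation(TruncationStrategy strategy, int maxLengthFromModel, int transformerMaxTokens) {
            boolean modelConfigUsable = maxLengthFromModel != -1 && strategy == TruncationStrategy.LONGEST_FIRST;
            if (modelConfigUsable) return OptionalInt.empty();       // prefer what tokenizer.json already says
            int maxLength = (maxLengthFromModel > 0 && maxLengthFromModel <= transformerMaxTokens)
                    ? maxLengthFromModel                             // model's own limit fits the transformer: keep it
                    : transformerMaxTokens;                          // otherwise cap at what the ONNX model accepts
            return OptionalInt.of(maxLength);
        }

        public static void main(String[] args) {
            // No truncation in tokenizer.json -> force truncation at the transformer's limit (512 here).
            System.out.println(resolveTruncation(TruncationStrategy.DO_NOT_TRUNCATE, -1, 512));   // OptionalInt[512]
            // tokenizer.json already configures LONGEST_FIRST with a valid length -> keep it untouched.
            System.out.println(resolveTruncation(TruncationStrategy.LONGEST_FIRST, 384, 512));    // OptionalInt.empty
        }
    }
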
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java  25
1 file changed, 19 insertions(+), 6 deletions(-)
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 17b63fb1c7d..b035541bb0f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
+import java.util.logging.Logger;
+
+import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
@Beta
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
+ private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
+
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
- tokenizer = new HuggingFaceTokenizer.Builder()
+ var tokenizerPath = Paths.get(config.tokenizerPath().toString());
+ var builder = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
- .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
- .setTruncation(true)
- .setPadding(false)
- .setMaxLength(config.transformerMaxTokens())
- .build();
+ .addDefaultModel(tokenizerPath)
+ .setPadding(false);
+ var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
+ log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info));
+ if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
+ // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration
+ int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
+ ? info.maxLength()
+ : config.transformerMaxTokens();
+ builder.setTruncation(true).setMaxLength(maxLength);
+ }
+ this.tokenizer = builder.build();
poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
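
The commit message also mentions disabling DJL usage telemetry; that part is outside the diff shown above, which is limited to model-integration. DJL exposes an opt-out through the OPT_OUT_TRACKING system property / environment variable, so the change likely amounts to something in the spirit of this sketch (the class name and exact placement are assumptions, and the property name should be verified against the DJL version in use):

    // Hedged sketch only: not part of the patch above. Assumes DJL's documented
    // OPT_OUT_TRACKING opt-out for usage tracking.
    public final class DjlTelemetry {

        /** Call once before any DJL class is loaded, e.g. from a static initializer. */
        static void disableUsageTelemetry() {
            if (System.getProperty("OPT_OUT_TRACKING") == null) {
                System.setProperty("OPT_OUT_TRACKING", "true");
            }
        }

        private DjlTelemetry() {}
    }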