aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java5
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def5
2 files changed, 8 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 44593fa2e57..7715ae2c896 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -26,6 +26,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final String attentionMaskName;
private final String outputName;
private final int maxTokens;
+ private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
@@ -35,6 +36,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
outputName = config.transformerOutput();
+ normalize = config.normalize();
tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
@@ -107,7 +109,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
builder.cell(averaged.get(TensorAddress.of(0,i)), i);
}
- return normalize(builder.build(), tensorType);
+ Tensor result = builder.build();
+ return normalize ? normalize(result, tensorType) : result;
}
Tensor normalize(Tensor embedding, TensorType tensorType) {
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
index 5ecdb59eae3..3eac14afc12 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
@@ -19,4 +19,7 @@ transformerOutput string default=last_hidden_state
# GPU configuration
transformerGpuDevice int default=-1
-transformerGpuRequired bool default=false \ No newline at end of file
+transformerGpuRequired bool default=false
+
+# Normalize tensors from tokenizer
+normalize bool default=false