summary | refs | log | tree | commit | diff | stats
path: root/model-integration
diff options
context:
space:
mode:
author: Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-05 16:46:57 +0200
committer: Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-05 16:47:13 +0200
commit: 710c04989b94d51a772c635fffcc93b8b8a52895 (patch)
tree: 2063fd59a51f75f05fb61f3ff7cad1c5e72b7d83 /model-integration
parent: ea351a9d4d393cbf9a2018197557f42ce3c490c1 (diff)
Make normalization optional
Diffstat (limited to 'model-integration')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 5
-rw-r--r-- model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def | 5
2 files changed, 8 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 44593fa2e57..7715ae2c896 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -26,6 +26,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final String attentionMaskName;
private final String outputName;
private final int maxTokens;
+ private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
@@ -35,6 +36,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
outputName = config.transformerOutput();
+ normalize = config.normalize();
tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
@@ -107,7 +109,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
builder.cell(averaged.get(TensorAddress.of(0,i)), i);
}
- return normalize(builder.build(), tensorType);
+ Tensor result = builder.build();
+ return normalize ? normalize(result, tensorType) : result;
}
Tensor normalize(Tensor embedding, TensorType tensorType) {
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
index 5ecdb59eae3..3eac14afc12 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
@@ -19,4 +19,7 @@ transformerOutput string default=last_hidden_state
# GPU configuration
transformerGpuDevice int default=-1
-transformerGpuRequired bool default=false
\ No newline at end of file
+transformerGpuRequired bool default=false
+
+# Normalize the embedding tensor produced by the embedder before returning it
+normalize bool default=false