From 710c04989b94d51a772c635fffcc93b8b8a52895 Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Fri, 5 May 2023 16:46:57 +0200 Subject: Make normalization optional --- .../java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 5 ++++- .../embedding.huggingface.hugging-face-embedder.def | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 44593fa2e57..7715ae2c896 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -26,6 +26,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final String attentionMaskName; private final String outputName; private final int maxTokens; + private final boolean normalize; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; @@ -35,6 +36,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); + normalize = config.normalize(); tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString())); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) @@ -107,7 +109,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { builder.cell(averaged.get(TensorAddress.of(0,i)), i); } - return normalize(builder.build(), tensorType); + Tensor result = builder.build(); + return normalize ? normalize(result, tensorType) : result; } Tensor normalize(Tensor embedding, TensorType tensorType) { diff --git a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def index 5ecdb59eae3..3eac14afc12 100644 --- a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def +++ b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def @@ -19,4 +19,7 @@ transformerOutput string default=last_hidden_state # GPU configuration transformerGpuDevice int default=-1 -transformerGpuRequired bool default=false \ No newline at end of file +transformerGpuRequired bool default=false + +# Normalize tensors from tokenizer +normalize bool default=false -- cgit v1.2.3