aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-05 17:17:01 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-05 17:51:00 +0200
commitb2e9037a14c2865d8c6377f9de3e07ad06627d9d (patch)
tree3204c48c4d3a69119801bdc83d13e5e36a196aae
parent684eb28e3b38ce015f5af060a5ba7b14f8f999c9 (diff)
Make pooling strategy configurable for Huggingface embedder
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java3
-rw-r--r--config-model/src/main/resources/schema/common.rnc11
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml1
-rw-r--r--configdefinitions/src/vespa/hugging-face-embedder.def2
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java48
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java3
7 files changed, 67 insertions, 21 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index 1c36716699e..bb26b7e4fd7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -31,6 +31,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Integer onnxInteropThreads;
private final Integer onnxIntraopThreads;
private final Integer onnxGpuDevice;
+ private final String poolingStrategy;
public HuggingFaceEmbedder(Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
@@ -50,6 +51,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null);
}
private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
@@ -75,5 +77,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+ if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
}
}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index b2b71950a0c..061e54740f1 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -94,7 +94,8 @@ HuggingFaceEmbedder =
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element normalize { xsd:boolean }? &
- OnnxModelExecutionParams
+ OnnxModelExecutionParams &
+ EmbedderPoolingStrategy
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
@@ -108,17 +109,19 @@ BertBaseEmbedder =
element transformer-model { ModelReference } &
element tokenizer-vocab { ModelReference } &
element max-tokens { xsd:nonNegativeInteger }? &
- element pooling-strategy { "cls" | "mean" }? &
element transformer-input-ids { xsd:string }? &
element transformer-attention-mask { xsd:string }? &
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element transformer-start-sequence-token { xsd:integer }? &
element transformer-end-sequence-token { xsd:integer }? &
- OnnxModelExecutionParams
+ OnnxModelExecutionParams &
+ EmbedderPoolingStrategy
OnnxModelExecutionParams =
element onnx-execution-mode { "parallel" | "sequential" }? &
element onnx-interop-threads { xsd:integer }? &
element onnx-intraop-threads { xsd:integer }? &
- element onnx-gpu-device { xsd:integer }? \ No newline at end of file
+ element onnx-gpu-device { xsd:integer }?
+
+EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }? \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 2b54d850452..0ce61b8ddf8 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -16,6 +16,7 @@
<onnx-intraop-threads>10</onnx-intraop-threads>
<onnx-interop-threads>8</onnx-interop-threads>
<onnx-gpu-device>1</onnx-gpu-device>
+ <pooling-strategy>mean</pooling-strategy>
</component>
<component id="hf-tokenizer" type="hugging-face-tokenizer">
diff --git a/configdefinitions/src/vespa/hugging-face-embedder.def b/configdefinitions/src/vespa/hugging-face-embedder.def
index 36957004e02..7ea4227b3cd 100644
--- a/configdefinitions/src/vespa/hugging-face-embedder.def
+++ b/configdefinitions/src/vespa/hugging-face-embedder.def
@@ -21,6 +21,8 @@ transformerOutput string default=last_hidden_state
# Normalize tensors from tokenizer
normalize bool default=false
+poolingStrategy enum { cls, mean } default=mean
+
# Settings for ONNX model evaluation
transformerExecutionMode enum { parallel, sequential } default=sequential
transformerInterOpThreads int default=1
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b172ef7beee..a12424c7d12 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -10,7 +10,6 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
@@ -39,7 +38,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
private final String attentionMaskName;
private final String tokenTypeIdsName;
private final String outputName;
- private final String poolingStrategy;
+ private final PoolingStrategy poolingStrategy;
private final WordPieceEmbedder tokenizer;
private final OnnxEvaluator evaluator;
@@ -53,7 +52,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
- poolingStrategy = config.poolingStrategy().toString();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
@@ -124,20 +123,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor tokenEmbeddings = outputs.get(outputName);
- Tensor.Builder builder = Tensor.Builder.of(type);
- if (poolingStrategy.equals("mean")) { // average over tokens
- Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
- Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
- Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(averaged.get(TensorAddress.of(0,i)), i);
- }
- } else { // CLS - use first token
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
- }
- }
- return builder.build();
+ return poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask);
}
private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
new file mode 100644
index 00000000000..28104d8eeef
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.embedding;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * @author bjorncs
+ */
+public enum PoolingStrategy {
+ MEAN {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) {
+ var builder = Tensor.Builder.of(type);
+ var summedEmbeddings = tokenEmbeddings.sum("d1");
+ var summedAttentionMask = attentionMask.expand("d0").sum("d1");
+ var averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(averaged.get(TensorAddress.of(0, i)), i);
+ }
+ return builder.build();
+ }
+ },
+ CLS {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) {
+ var builder = Tensor.Builder.of(type);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
+ }
+ return builder.build();
+ }
+ };
+
+ public abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask);
+
+ public static PoolingStrategy fromString(String strategy) {
+ return switch (strategy.toLowerCase()) {
+ case "mean" -> MEAN;
+ case "cls" -> CLS;
+ default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy));
+ };
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 01804656bb6..f93b1a3c1f8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -1,5 +1,6 @@
package ai.vespa.embedding.huggingface;
+import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
@@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
+ private final PoolingStrategy poolingStrategy;
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
@@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
.setTruncation(true)
.setMaxLength(config.transformerMaxTokens())
.build();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());