aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-05 17:17:01 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-05 17:51:00 +0200
commitb2e9037a14c2865d8c6377f9de3e07ad06627d9d (patch)
tree3204c48c4d3a69119801bdc83d13e5e36a196aae /model-integration
parent684eb28e3b38ce015f5af060a5ba7b14f8f999c9 (diff)
Make pooling strategy configurable for Huggingface embedder
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java48
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java3
3 files changed, 54 insertions, 17 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b172ef7beee..a12424c7d12 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -10,7 +10,6 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
@@ -39,7 +38,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
private final String attentionMaskName;
private final String tokenTypeIdsName;
private final String outputName;
- private final String poolingStrategy;
+ private final PoolingStrategy poolingStrategy;
private final WordPieceEmbedder tokenizer;
private final OnnxEvaluator evaluator;
@@ -53,7 +52,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
- poolingStrategy = config.poolingStrategy().toString();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
@@ -124,20 +123,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor tokenEmbeddings = outputs.get(outputName);
- Tensor.Builder builder = Tensor.Builder.of(type);
- if (poolingStrategy.equals("mean")) { // average over tokens
- Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
- Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
- Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(averaged.get(TensorAddress.of(0,i)), i);
- }
- } else { // CLS - use first token
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
- }
- }
- return builder.build();
+ return poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask);
}
private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
new file mode 100644
index 00000000000..28104d8eeef
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.embedding;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * @author bjorncs
+ */
+public enum PoolingStrategy {
+ MEAN {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) {
+ var builder = Tensor.Builder.of(type);
+ var summedEmbeddings = tokenEmbeddings.sum("d1");
+ var summedAttentionMask = attentionMask.expand("d0").sum("d1");
+ var averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(averaged.get(TensorAddress.of(0, i)), i);
+ }
+ return builder.build();
+ }
+ },
+ CLS {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) {
+ var builder = Tensor.Builder.of(type);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
+ }
+ return builder.build();
+ }
+ };
+
+ public abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask);
+
+ public static PoolingStrategy fromString(String strategy) {
+ return switch (strategy.toLowerCase()) {
+ case "mean" -> MEAN;
+ case "cls" -> CLS;
+ default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy));
+ };
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 01804656bb6..f93b1a3c1f8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -1,5 +1,6 @@
package ai.vespa.embedding.huggingface;
+import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
@@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
+ private final PoolingStrategy poolingStrategy;
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
@@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
.setTruncation(true)
.setMaxLength(config.transformerMaxTokens())
.build();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());