summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-06 09:50:12 +0200
committerGitHub <noreply@github.com>2023-06-06 09:50:12 +0200
commit4ae8a32cd71cc23501f7e1737b27e0bcac7fbd41 (patch)
tree3e31ecaaab916dd6ce1bd51b3e552e09d578e1eb /model-integration
parent4878116a848f0ceff01c49b67657d63a4113789d (diff)
parent6c664b24186756021e6b39801b9694d1815311bf (diff)
Merge pull request #27297 from vespa-engine/bjorncs/bert-embedder-services-xml
Bjorncs/bert embedder services xml
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java48
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java3
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def32
4 files changed, 54 insertions, 49 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b172ef7beee..a12424c7d12 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -10,7 +10,6 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
@@ -39,7 +38,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
private final String attentionMaskName;
private final String tokenTypeIdsName;
private final String outputName;
- private final String poolingStrategy;
+ private final PoolingStrategy poolingStrategy;
private final WordPieceEmbedder tokenizer;
private final OnnxEvaluator evaluator;
@@ -53,7 +52,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
- poolingStrategy = config.poolingStrategy().toString();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
@@ -124,20 +123,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor tokenEmbeddings = outputs.get(outputName);
- Tensor.Builder builder = Tensor.Builder.of(type);
- if (poolingStrategy.equals("mean")) { // average over tokens
- Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
- Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
- Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(averaged.get(TensorAddress.of(0,i)), i);
- }
- } else { // CLS - use first token
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
- }
- }
- return builder.build();
+ return poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask);
}
private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
new file mode 100644
index 00000000000..28104d8eeef
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.embedding;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * @author bjorncs
+ */
+public enum PoolingStrategy {
+ MEAN {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) {
+ var builder = Tensor.Builder.of(type);
+ var summedEmbeddings = tokenEmbeddings.sum("d1");
+ var summedAttentionMask = attentionMask.expand("d0").sum("d1");
+ var averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(averaged.get(TensorAddress.of(0, i)), i);
+ }
+ return builder.build();
+ }
+ },
+ CLS {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) {
+ var builder = Tensor.Builder.of(type);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
+ }
+ return builder.build();
+ }
+ };
+
+ public abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask);
+
+ public static PoolingStrategy fromString(String strategy) {
+ return switch (strategy.toLowerCase()) {
+ case "mean" -> MEAN;
+ case "cls" -> CLS;
+ default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy));
+ };
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 01804656bb6..f93b1a3c1f8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -1,5 +1,6 @@
package ai.vespa.embedding.huggingface;
+import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
@@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
+ private final PoolingStrategy poolingStrategy;
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
@@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
.setTruncation(true)
.setMaxLength(config.transformerMaxTokens())
.build();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
deleted file mode 100644
index 2d8e840377b..00000000000
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ /dev/null
@@ -1,32 +0,0 @@
-
-namespace=embedding
-
-# Wordpiece tokenizer
-tokenizerVocab model
-
-transformerModel model
-
-# Max length of token sequence model can handle
-transformerMaxTokens int default=384
-
-# Pooling strategy
-poolingStrategy enum { cls, mean } default=mean
-
-# Input names
-transformerInputIds string default=input_ids
-transformerAttentionMask string default=attention_mask
-transformerTokenTypeIds string default=token_type_ids
-
-# special token ids
-transformerStartSequenceToken int default=101
-transformerEndSequenceToken int default=102
-
-# Output name
-transformerOutput string default=output_0
-
-# Settings for ONNX model evaluation
-onnxExecutionMode enum { parallel, sequential } default=sequential
-onnxInterOpThreads int default=1
-onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
-# GPU device id, -1 for CPU
-onnxGpuDevice int default=0