diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-06 09:50:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-06 09:50:12 +0200 |
commit | 4ae8a32cd71cc23501f7e1737b27e0bcac7fbd41 (patch) | |
tree | 3e31ecaaab916dd6ce1bd51b3e552e09d578e1eb /model-integration | |
parent | 4878116a848f0ceff01c49b67657d63a4113789d (diff) | |
parent | 6c664b24186756021e6b39801b9694d1815311bf (diff) |
Merge pull request #27297 from vespa-engine/bjorncs/bert-embedder-services-xml
Bjorncs/bert embedder services xml
Diffstat (limited to 'model-integration')
4 files changed, 54 insertions, 49 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index b172ef7beee..a12424c7d12 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -10,7 +10,6 @@ import com.yahoo.language.process.Embedder; import com.yahoo.language.wordpiece.WordPieceEmbedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.util.ArrayList; @@ -39,7 +38,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { private final String attentionMaskName; private final String tokenTypeIdsName; private final String outputName; - private final String poolingStrategy; + private final PoolingStrategy poolingStrategy; private final WordPieceEmbedder tokenizer; private final OnnxEvaluator evaluator; @@ -53,7 +52,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { attentionMaskName = config.transformerAttentionMask(); tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); - poolingStrategy = config.poolingStrategy().toString(); + poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); options.setExecutionMode(config.onnxExecutionMode().toString()); @@ -124,20 +123,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { Tensor tokenEmbeddings = outputs.get(outputName); - Tensor.Builder builder = Tensor.Builder.of(type); - if (poolingStrategy.equals("mean")) { // average over tokens - Tensor summedEmbeddings = tokenEmbeddings.sum("d1"); - Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1"); - Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(averaged.get(TensorAddress.of(0,i)), i); - } - } else { // CLS - use first token - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); - } - } - return builder.build(); + return poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask); } private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) { diff --git a/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java new file mode 100644 index 00000000000..28104d8eeef --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.embedding; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +/** + * @author bjorncs + */ +public enum PoolingStrategy { + MEAN { + @Override + public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) { + var builder = Tensor.Builder.of(type); + var summedEmbeddings = tokenEmbeddings.sum("d1"); + var summedAttentionMask = attentionMask.expand("d0").sum("d1"); + var averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(averaged.get(TensorAddress.of(0, i)), i); + } + return builder.build(); + } + }, + CLS { + @Override + public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) { + var builder = Tensor.Builder.of(type); + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); + } + return builder.build(); + } + }; + + public abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask); + + public static PoolingStrategy fromString(String strategy) { + return switch (strategy.toLowerCase()) { + case "mean" -> MEAN; + case "cls" -> CLS; + default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy)); + }; + } +} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 01804656bb6..f93b1a3c1f8 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -1,5 +1,6 @@ package ai.vespa.embedding.huggingface; +import ai.vespa.embedding.PoolingStrategy; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; @@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final boolean normalize; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; + private final PoolingStrategy poolingStrategy; @Inject public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { @@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { .setTruncation(true) .setMaxLength(config.transformerMaxTokens()) .build(); + poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) onnxOpts.setGpuDevice(config.transformerGpuDevice()); diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def deleted file mode 100644 index 2d8e840377b..00000000000 --- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def +++ /dev/null @@ -1,32 +0,0 @@ - -namespace=embedding - -# Wordpiece tokenizer -tokenizerVocab model - -transformerModel model - -# Max length of token sequence model can handle -transformerMaxTokens int default=384 - -# Pooling strategy -poolingStrategy enum { cls, mean } default=mean - -# Input names -transformerInputIds string default=input_ids -transformerAttentionMask string default=attention_mask -transformerTokenTypeIds string default=token_type_ids - -# special token ids -transformerStartSequenceToken int default=101 -transformerEndSequenceToken int default=102 - -# Output name -transformerOutput string default=output_0 - -# Settings for ONNX model evaluation -onnxExecutionMode enum { parallel, sequential } default=sequential -onnxInterOpThreads int default=1 -onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n -# GPU device id, -1 for CPU -onnxGpuDevice int default=0 |