From b2e9037a14c2865d8c6377f9de3e07ad06627d9d Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Mon, 5 Jun 2023 17:17:01 +0200 Subject: Make pooling strategy configurable for Huggingface embedder --- .../container/component/HuggingFaceEmbedder.java | 3 ++ config-model/src/main/resources/schema/common.rnc | 11 +++-- .../src/test/cfg/application/embed/services.xml | 1 + .../src/vespa/hugging-face-embedder.def | 2 + .../java/ai/vespa/embedding/BertBaseEmbedder.java | 20 ++------- .../java/ai/vespa/embedding/PoolingStrategy.java | 48 ++++++++++++++++++++++ .../embedding/huggingface/HuggingFaceEmbedder.java | 3 ++ 7 files changed, 67 insertions(+), 21 deletions(-) create mode 100644 model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index 1c36716699e..bb26b7e4fd7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -31,6 +31,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Integer onnxInteropThreads; private final Integer onnxIntraopThreads; private final Integer onnxGpuDevice; + private final String poolingStrategy; public HuggingFaceEmbedder(Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); @@ -50,6 +51,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null); } private static ModelReference resolveDefaultVocab(Element model, boolean hosted) { @@ -75,5 +77,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); } } diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index b2b71950a0c..061e54740f1 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -94,7 +94,8 @@ HuggingFaceEmbedder = element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element normalize { xsd:boolean }? & - OnnxModelExecutionParams + OnnxModelExecutionParams & + EmbedderPoolingStrategy HuggingFaceTokenizer = attribute type { "hugging-face-tokenizer" } & @@ -108,17 +109,19 @@ BertBaseEmbedder = element transformer-model { ModelReference } & element tokenizer-vocab { ModelReference } & element max-tokens { xsd:nonNegativeInteger }? & - element pooling-strategy { "cls" | "mean" }? & element transformer-input-ids { xsd:string }? & element transformer-attention-mask { xsd:string }? & element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element transformer-start-sequence-token { xsd:integer }? & element transformer-end-sequence-token { xsd:integer }? & - OnnxModelExecutionParams + OnnxModelExecutionParams & + EmbedderPoolingStrategy OnnxModelExecutionParams = element onnx-execution-mode { "parallel" | "sequential" }? & element onnx-interop-threads { xsd:integer }? & element onnx-intraop-threads { xsd:integer }? & - element onnx-gpu-device { xsd:integer }? \ No newline at end of file + element onnx-gpu-device { xsd:integer }? + +EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }? \ No newline at end of file diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 2b54d850452..0ce61b8ddf8 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -16,6 +16,7 @@ 10 8 1 + mean diff --git a/configdefinitions/src/vespa/hugging-face-embedder.def b/configdefinitions/src/vespa/hugging-face-embedder.def index 36957004e02..7ea4227b3cd 100644 --- a/configdefinitions/src/vespa/hugging-face-embedder.def +++ b/configdefinitions/src/vespa/hugging-face-embedder.def @@ -21,6 +21,8 @@ transformerOutput string default=last_hidden_state # Normalize tensors from tokenizer normalize bool default=false +poolingStrategy enum { cls, mean } default=mean + # Settings for ONNX model evaluation transformerExecutionMode enum { parallel, sequential } default=sequential transformerInterOpThreads int default=1 diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index b172ef7beee..a12424c7d12 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -10,7 +10,6 @@ import com.yahoo.language.process.Embedder; import com.yahoo.language.wordpiece.WordPieceEmbedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.util.ArrayList; @@ -39,7 +38,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { private final String attentionMaskName; private final String tokenTypeIdsName; private final String outputName; - private final String poolingStrategy; + private final PoolingStrategy poolingStrategy; private final WordPieceEmbedder tokenizer; private final OnnxEvaluator evaluator; @@ -53,7 +52,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { attentionMaskName = config.transformerAttentionMask(); tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); - poolingStrategy = config.poolingStrategy().toString(); + poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); options.setExecutionMode(config.onnxExecutionMode().toString()); @@ -124,20 +123,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { Tensor tokenEmbeddings = outputs.get(outputName); - Tensor.Builder builder = Tensor.Builder.of(type); - if (poolingStrategy.equals("mean")) { // average over tokens - Tensor summedEmbeddings = tokenEmbeddings.sum("d1"); - Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1"); - Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(averaged.get(TensorAddress.of(0,i)), i); - } - } else { // CLS - use first token - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); - } - } - return builder.build(); + return poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask); } private List embedWithSeparatorTokens(String text, Context context, int maxLength) { diff --git a/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java new file mode 100644 index 00000000000..28104d8eeef --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.embedding; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +/** + * @author bjorncs + */ +public enum PoolingStrategy { + MEAN { + @Override + public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) { + var builder = Tensor.Builder.of(type); + var summedEmbeddings = tokenEmbeddings.sum("d1"); + var summedAttentionMask = attentionMask.expand("d0").sum("d1"); + var averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(averaged.get(TensorAddress.of(0, i)), i); + } + return builder.build(); + } + }, + CLS { + @Override + public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) { + var builder = Tensor.Builder.of(type); + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); + } + return builder.build(); + } + }; + + public abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask); + + public static PoolingStrategy fromString(String strategy) { + return switch (strategy.toLowerCase()) { + case "mean" -> MEAN; + case "cls" -> CLS; + default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy)); + }; + } +} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 01804656bb6..f93b1a3c1f8 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -1,5 +1,6 @@ package ai.vespa.embedding.huggingface; +import ai.vespa.embedding.PoolingStrategy; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; @@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final boolean normalize; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; + private final PoolingStrategy poolingStrategy; @Inject public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { @@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { .setTruncation(true) .setMaxLength(config.transformerMaxTokens()) .build(); + poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) onnxOpts.setGpuDevice(config.transformerGpuDevice()); -- cgit v1.2.3