aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-04-25 15:32:04 +0200
committerJo Kristian Bergum <bergum@yahooinc.com>2024-04-25 15:32:04 +0200
commita374e90b5f95b3f3c533a4d0302ac0e66c32668f (patch)
tree5496e8603a071aea36a5f0f8d24bcdf266bb3b23
parent117cace612ab00de27b8ec5e77896056e449bf33 (diff)
add prepend support
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java13
-rw-r--r--config-model/src/main/resources/schema/common.rnc6
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml4
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java4
-rw-r--r--configdefinitions/src/vespa/hugging-face-embedder.def2
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java18
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java47
7 files changed, 93 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index fe0bb7c8075..91060579c4e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -12,6 +12,7 @@ import java.util.Set;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
+import static com.yahoo.text.XML.getChild;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
@@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Boolean normalize;
private final String poolingStrategy;
+ private String prependQuery;
+
+ private String prependDocument;
+
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
@@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
+ Element prepend = getChild(xml, "prepend");
+ if (prepend != null) {
+ prependQuery = getChildValue(prepend, "query").orElse(null);
+ prependDocument = getChildValue(prepend, "document").orElse(null);
+ }
+
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
@@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (transformerOutput != null) b.transformerOutput(transformerOutput);
if (normalize != null) b.normalize(normalize);
if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy));
+ if(prependQuery != null) b.prependQuery(prependQuery);
+ if(prependDocument != null) b.prependDocument(prependDocument);
onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 14fae90678d..d949edbcacf 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -94,9 +94,15 @@ HuggingFaceEmbedder =
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element normalize { xsd:boolean }? &
+ PrependResources? &
OnnxModelExecutionParams &
EmbedderPoolingStrategy
+PrependResources = element prepend {
+ element query { xsd:string }? &
+ element document { xsd:string }?
+}
+
SpladeEmbedder =
attribute type { "splade-embedder" } &
element transformer-model { ModelReference } &
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 089adbd7517..ae0d0952630 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -12,6 +12,10 @@
<transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
<transformer-output>my_output</transformer-output>
<normalize>true</normalize>
+ <prepend>
+ <query>Represent this sentence for searching relevant passages:</query>
+ <document>passage:</document>
+ </prepend>
<onnx-execution-mode>parallel</onnx-execution-mode>
<onnx-intraop-threads>10</onnx-intraop-threads>
<onnx-interop-threads>8</onnx-interop-threads>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index fb1e176f707..6f629d99c92 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -88,6 +88,8 @@ public class EmbedderTestCase {
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
assertEquals(-1, tokenizerCfg.maxLength());
+ assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery());
+ assertEquals("passage:", embedderCfg.prependDocument());
}
@Test
@@ -101,6 +103,8 @@ public class EmbedderTestCase {
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
assertEquals(-1, tokenizerCfg.maxLength());
+ assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery());
+ assertEquals("passage:", embedderCfg.prependDocument());
}
@Test
diff --git a/configdefinitions/src/vespa/hugging-face-embedder.def b/configdefinitions/src/vespa/hugging-face-embedder.def
index a26f6917443..d89d6923802 100644
--- a/configdefinitions/src/vespa/hugging-face-embedder.def
+++ b/configdefinitions/src/vespa/hugging-face-embedder.def
@@ -18,6 +18,8 @@ transformerTokenTypeIds string default=token_type_ids
# Output name
transformerOutput string default=last_hidden_state
+prependQuery string default=""
+prependDocument string default=""
# Normalize tensors from tokenizer
normalize bool default=false
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 20d8b6362d3..3e5dcfda3e9 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -39,6 +39,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final OnnxEvaluator evaluator;
private final PoolingStrategy poolingStrategy;
+ private final String prependQuery;
+
+ private final String prependDocument;
+
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config) {
this.runtime = runtime;
@@ -47,6 +51,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
+ prependQuery = config.prependQuery();
+ prependDocument = config.prependDocument();
var tokenizerPath = Paths.get(config.tokenizerPath().toString());
var builder = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
@@ -113,7 +119,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
if (!tensorType.dimensions().get(0).isIndexed()) {
throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
}
- var embeddingResult = lookupOrEvaluate(context, text);
+ var embeddingResult = lookupOrEvaluate(context, prependInstruction(text, context));
IndexedTensor tokenEmbeddings = embeddingResult.output;
if (tensorType.valueType() == TensorType.Value.INT8) {
return binaryQuantization(embeddingResult, tensorType);
@@ -123,6 +129,16 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
}
}
+ String prependInstruction(String text, Context context) {
+ if (prependQuery != null && !prependQuery.isEmpty() && context.getDestination().startsWith("query")) {
+ return prependQuery + " " + text;
+ }
+ if (prependDocument != null && !prependDocument.isEmpty()){
+ return prependDocument + " " + text;
+ }
+ return text;
+ }
+
Tensor normalize(Tensor embedding, TensorType tensorType) {
double sumOfSquares = 0.0;
diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index d504d77cc9b..c2c37db31f6 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -143,6 +143,39 @@ public class HuggingFaceEmbedderTest {
});
}
+ @Test
+ public void testEmbedderWithNormalizationAndPrefix() {
+ String input = "This is a test";
+ var context = new Embedder.Context("schema.indexing");
+ Tensor result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
+ assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+ result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])")));
+ assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+ Tensor binarizedResult = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])")));
+ assertEquals("tensor<int8>(x[2]):[125, 44]", binarizedResult.toAbbreviatedString());
+
+ var queryContext = new Embedder.Context("query.qt");
+ Tensor queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<float>(x[8])")));
+ assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3);
+ queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<float>(x[16])")));
+ assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3);
+ Tensor binarizedResultQuery = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<int8>(x[2])")));
+ assertNotEquals(binarizedResult.toAbbreviatedString(), binarizedResultQuery.toAbbreviatedString());
+ assertEquals("tensor<int8>(x[2]):[119, -116]", binarizedResultQuery.toAbbreviatedString());
+ }
+
+ @Test
+ public void testPrepend() {
+ var context = new Embedder.Context("schema.indexing");
+ String input = "This is a test";
+ var embedder = getNormalizePrefixdEmbedder();
+ var result = embedder.prependInstruction(input, context);
+ assertEquals("This is a document: This is a test", result);
+ var queryContext = new Embedder.Context("query.qt");
+ var queryResult = embedder.prependInstruction(input, queryContext);
+ assertEquals("Represent this text: This is a test", queryResult);
+ }
+
private static HuggingFaceEmbedder getEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
@@ -165,6 +198,20 @@ public class HuggingFaceEmbedderTest {
return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
}
+ private static HuggingFaceEmbedder getNormalizePrefixdEmbedder() {
+ String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
+ String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
+ HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
+ builder.tokenizerPath(ModelReference.valueOf(vocabPath));
+ builder.transformerModel(ModelReference.valueOf(modelPath));
+ builder.transformerGpuDevice(-1);
+ builder.normalize(true);
+ builder.prependQuery("Represent this text:");
+ builder.prependDocument("This is a document:");
+ return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ }
+
public static Tensor expandBitTensor(Tensor packed) {
var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big");
var context = new MapContext();