diff options
author | Bjørn Christian Seime <bjorn.christian@seime.no> | 2024-04-26 07:56:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-26 07:56:30 +0200 |
commit | 95cccd4698808d320f41461697f74a1c9d161bac (patch) | |
tree | cd9336653627e4f3274a29f5f9da435a97b2150c | |
parent | face477e177b1fdd5462a5c961e2d00d43895b97 (diff) | |
parent | a374e90b5f95b3f3c533a4d0302ac0e66c32668f (diff) |
Merge pull request #31049 from vespa-engine/jobergum/add-prepend-embedder-support
add prepend support to embedder
7 files changed, 93 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index fe0bb7c8075..91060579c4e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -12,6 +12,7 @@ import java.util.Set; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; +import static com.yahoo.text.XML.getChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; @@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Boolean normalize; private final String poolingStrategy; + private String prependQuery; + + private String prependDocument; + public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); @@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm transformerOutput = getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); + Element prepend = getChild(xml, "prepend"); + if (prepend != null) { + prependQuery = getChildValue(prepend, "query").orElse(null); + prependDocument = getChildValue(prepend, "document").orElse(null); + } + model.registerOnnxModelCost(cluster, onnxModelOptions); } @@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (transformerOutput != null) b.transformerOutput(transformerOutput); if (normalize != null) b.normalize(normalize); if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy)); + if(prependQuery != null) b.prependQuery(prependQuery); + if(prependDocument != null) b.prependDocument(prependDocument); onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 14fae90678d..d949edbcacf 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -94,9 +94,15 @@ HuggingFaceEmbedder = element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element normalize { xsd:boolean }? & + PrependResources? & OnnxModelExecutionParams & EmbedderPoolingStrategy +PrependResources = element prepend { + element query { xsd:string }? & + element document { xsd:string }? +} + SpladeEmbedder = attribute type { "splade-embedder" } & element transformer-model { ModelReference } & diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 089adbd7517..ae0d0952630 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -12,6 +12,10 @@ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> <transformer-output>my_output</transformer-output> <normalize>true</normalize> + <prepend> + <query>Represent this sentence for searching relevant passages:</query> + <document>passage:</document> + </prepend> <onnx-execution-mode>parallel</onnx-execution-mode> <onnx-intraop-threads>10</onnx-intraop-threads> <onnx-interop-threads>8</onnx-interop-threads> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index fb1e176f707..6f629d99c92 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -88,6 +88,8 @@ public class EmbedderTestCase { var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); + assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery()); + assertEquals("passage:", embedderCfg.prependDocument()); } @Test @@ -101,6 +103,8 @@ public class EmbedderTestCase { var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); + assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery()); + assertEquals("passage:", embedderCfg.prependDocument()); } @Test diff --git a/configdefinitions/src/vespa/hugging-face-embedder.def b/configdefinitions/src/vespa/hugging-face-embedder.def index a26f6917443..d89d6923802 100644 --- a/configdefinitions/src/vespa/hugging-face-embedder.def +++ b/configdefinitions/src/vespa/hugging-face-embedder.def @@ -18,6 +18,8 @@ transformerTokenTypeIds string default=token_type_ids # Output name transformerOutput string default=last_hidden_state +prependQuery string default="" +prependDocument string default="" # Normalize tensors from tokenizer normalize bool default=false diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 20d8b6362d3..3e5dcfda3e9 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -39,6 +39,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final OnnxEvaluator evaluator; private final PoolingStrategy poolingStrategy; + private final String prependQuery; + + private final String prependDocument; + @Inject public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config) { this.runtime = runtime; @@ -47,6 +51,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); normalize = config.normalize(); + prependQuery = config.prependQuery(); + prependDocument = config.prependDocument(); var tokenizerPath = Paths.get(config.tokenizerPath().toString()); var builder = new HuggingFaceTokenizer.Builder() .addSpecialTokens(true) @@ -113,7 +119,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { if (!tensorType.dimensions().get(0).isIndexed()) { throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed."); } - var embeddingResult = lookupOrEvaluate(context, text); + var embeddingResult = lookupOrEvaluate(context, prependInstruction(text, context)); IndexedTensor tokenEmbeddings = embeddingResult.output; if (tensorType.valueType() == TensorType.Value.INT8) { return binaryQuantization(embeddingResult, tensorType); @@ -123,6 +129,16 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { } } + String prependInstruction(String text, Context context) { + if (prependQuery != null && !prependQuery.isEmpty() && context.getDestination().startsWith("query")) { + return prependQuery + " " + text; + } + if (prependDocument != null && !prependDocument.isEmpty()){ + return prependDocument + " " + text; + } + return text; + } + Tensor normalize(Tensor embedding, TensorType tensorType) { double sumOfSquares = 0.0; diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java index d504d77cc9b..c2c37db31f6 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java @@ -143,6 +143,39 @@ public class HuggingFaceEmbedderTest { }); } + @Test + public void testEmbedderWithNormalizationAndPrefix() { + String input = "This is a test"; + var context = new Embedder.Context("schema.indexing"); + Tensor result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])"))); + assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3); + result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])"))); + assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3); + Tensor binarizedResult = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])"))); + assertEquals("tensor<int8>(x[2]):[125, 44]", binarizedResult.toAbbreviatedString()); + + var queryContext = new Embedder.Context("query.qt"); + Tensor queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<float>(x[8])"))); + assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3); + queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<float>(x[16])"))); + assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3); + Tensor binarizedResultQuery = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<int8>(x[2])"))); + assertNotEquals(binarizedResult.toAbbreviatedString(), binarizedResultQuery.toAbbreviatedString()); + assertEquals("tensor<int8>(x[2]):[119, -116]", binarizedResultQuery.toAbbreviatedString()); + } + + @Test + public void testPrepend() { + var context = new Embedder.Context("schema.indexing"); + String input = "This is a test"; + var embedder = getNormalizePrefixdEmbedder(); + var result = embedder.prependInstruction(input, context); + assertEquals("This is a document: This is a test", result); + var queryContext = new Embedder.Context("query.qt"); + var queryResult = embedder.prependInstruction(input, queryContext); + assertEquals("Represent this text: This is a test", queryResult); + } + private static HuggingFaceEmbedder getEmbedder() { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx"; @@ -165,6 +198,20 @@ public class HuggingFaceEmbedderTest { return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); } + private static HuggingFaceEmbedder getNormalizePrefixdEmbedder() { + String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; + String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx"; + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); + HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder(); + builder.tokenizerPath(ModelReference.valueOf(vocabPath)); + builder.transformerModel(ModelReference.valueOf(modelPath)); + builder.transformerGpuDevice(-1); + builder.normalize(true); + builder.prependQuery("Represent this text:"); + builder.prependDocument("This is a document:"); + return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + } + public static Tensor expandBitTensor(Tensor packed) { var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big"); var context = new MapContext(); |