From a374e90b5f95b3f3c533a4d0302ac0e66c32668f Mon Sep 17 00:00:00 2001
From: Jo Kristian Bergum
Date: Thu, 25 Apr 2024 15:32:04 +0200
Subject: add prepend support

---
 .../embedding/huggingface/HuggingFaceEmbedder.java | 18 ++++++++-
 .../huggingface/HuggingFaceEmbedderTest.java       | 47 ++++++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

(limited to 'model-integration')

diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 20d8b6362d3..3e5dcfda3e9 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -39,6 +39,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
     private final OnnxEvaluator evaluator;
     private final PoolingStrategy poolingStrategy;
 
+    private final String prependQuery;
+
+    private final String prependDocument;
+
     @Inject
     public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config) {
         this.runtime = runtime;
@@ -47,6 +51,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
         tokenTypeIdsName = config.transformerTokenTypeIds();
         outputName = config.transformerOutput();
         normalize = config.normalize();
+        prependQuery = config.prependQuery();
+        prependDocument = config.prependDocument();
         var tokenizerPath = Paths.get(config.tokenizerPath().toString());
         var builder = new HuggingFaceTokenizer.Builder()
                 .addSpecialTokens(true)
@@ -113,7 +119,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
         if (!tensorType.dimensions().get(0).isIndexed()) {
             throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
         }
-        var embeddingResult = lookupOrEvaluate(context, text);
+        var embeddingResult = lookupOrEvaluate(context, prependInstruction(text, context));
         IndexedTensor tokenEmbeddings = embeddingResult.output;
         if (tensorType.valueType() == TensorType.Value.INT8) {
             return binaryQuantization(embeddingResult, tensorType);
@@ -123,6 +129,16 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
         }
     }
 
+    String prependInstruction(String text, Context context) {
+        if (prependQuery != null && !prependQuery.isEmpty() && context.getDestination().startsWith("query")) {
+            return prependQuery + " " + text;
+        }
+        if (prependDocument != null && !prependDocument.isEmpty()){
+            return prependDocument + " " + text;
+        }
+        return text;
+    }
+
     Tensor normalize(Tensor embedding, TensorType tensorType) {
         double sumOfSquares = 0.0;
 
diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index d504d77cc9b..c2c37db31f6 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -143,6 +143,39 @@ public class HuggingFaceEmbedderTest {
         });
     }
 
+    @Test
+    public void testEmbedderWithNormalizationAndPrefix() {
+        String input = "This is a test";
+        var context = new Embedder.Context("schema.indexing");
+        Tensor result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor(x[8])")));
+        assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+        result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor(x[16])")));
+        assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+        Tensor binarizedResult = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor(x[2])")));
+        assertEquals("tensor(x[2]):[125, 44]", binarizedResult.toAbbreviatedString());
+
+        var queryContext = new Embedder.Context("query.qt");
+        Tensor queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor(x[8])")));
+        assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3);
+        queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor(x[16])")));
+        assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3);
+        Tensor binarizedResultQuery = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor(x[2])")));
+        assertNotEquals(binarizedResult.toAbbreviatedString(), binarizedResultQuery.toAbbreviatedString());
+        assertEquals("tensor(x[2]):[119, -116]", binarizedResultQuery.toAbbreviatedString());
+    }
+
+    @Test
+    public void testPrepend() {
+        var context = new Embedder.Context("schema.indexing");
+        String input = "This is a test";
+        var embedder = getNormalizePrefixdEmbedder();
+        var result = embedder.prependInstruction(input, context);
+        assertEquals("This is a document: This is a test", result);
+        var queryContext = new Embedder.Context("query.qt");
+        var queryResult = embedder.prependInstruction(input, queryContext);
+        assertEquals("Represent this text: This is a test", queryResult);
+    }
+
     private static HuggingFaceEmbedder getEmbedder() {
         String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
         String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
@@ -165,6 +198,20 @@ public class HuggingFaceEmbedderTest {
         return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
     }
 
+    private static HuggingFaceEmbedder getNormalizePrefixdEmbedder() {
+        String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
+        String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
+        assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
+        HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
+        builder.tokenizerPath(ModelReference.valueOf(vocabPath));
+        builder.transformerModel(ModelReference.valueOf(modelPath));
+        builder.transformerGpuDevice(-1);
+        builder.normalize(true);
+        builder.prependQuery("Represent this text:");
+        builder.prependDocument("This is a document:");
+        return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+    }
+
     public static Tensor expandBitTensor(Tensor packed) {
         var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big");
         var context = new MapContext();
--
cgit v1.2.3