aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-04-25 15:32:04 +0200
committerJo Kristian Bergum <bergum@yahooinc.com>2024-04-25 15:32:04 +0200
commita374e90b5f95b3f3c533a4d0302ac0e66c32668f (patch)
tree5496e8603a071aea36a5f0f8d24bcdf266bb3b23 /model-integration/src/main/java/ai
parent117cace612ab00de27b8ec5e77896056e449bf33 (diff)
add prepend support
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java18
1 files changed, 17 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 20d8b6362d3..3e5dcfda3e9 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -39,6 +39,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final OnnxEvaluator evaluator;
private final PoolingStrategy poolingStrategy;
+ private final String prependQuery;
+
+ private final String prependDocument;
+
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config) {
this.runtime = runtime;
@@ -47,6 +51,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
+ prependQuery = config.prependQuery();
+ prependDocument = config.prependDocument();
var tokenizerPath = Paths.get(config.tokenizerPath().toString());
var builder = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
@@ -113,7 +119,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
if (!tensorType.dimensions().get(0).isIndexed()) {
throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
}
- var embeddingResult = lookupOrEvaluate(context, text);
+ var embeddingResult = lookupOrEvaluate(context, prependInstruction(text, context));
IndexedTensor tokenEmbeddings = embeddingResult.output;
if (tensorType.valueType() == TensorType.Value.INT8) {
return binaryQuantization(embeddingResult, tensorType);
@@ -123,6 +129,16 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
}
}
+ String prependInstruction(String text, Context context) {
+ if (prependQuery != null && !prependQuery.isEmpty() && context.getDestination().startsWith("query")) {
+ return prependQuery + " " + text;
+ }
+ if (prependDocument != null && !prependDocument.isEmpty()){
+ return prependDocument + " " + text;
+ }
+ return text;
+ }
+
Tensor normalize(Tensor embedding, TensorType tensorType) {
double sumOfSquares = 0.0;