author    Jo Kristian Bergum <bergum@yahooinc.com>    2024-04-08 21:52:40 +0200
committer Jo Kristian Bergum <bergum@yahooinc.com>    2024-04-08 21:52:40 +0200
commit    4d233b5379b8dc4b94901f8df8acda0a6f2c4420 (patch)
tree      006e1ec72bc0ae46a86b6cded4c72f936ac45483
parent    6715471dceedbbda28d9d29ffb9d441ebfb848a2 (diff)
Cache more and refactor
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java   119
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java (renamed from model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java)   58
2 files changed, 109 insertions, 68 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 1b9f9dd2fe3..20d8b6362d3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -106,54 +106,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
@SuppressWarnings("unchecked")
@Override
- public Tensor embed(String s, Context context, TensorType tensorType) {
- var start = System.nanoTime();
- var encoding = tokenizer.encode(s, context.getLanguage());
- runtime.sampleSequenceLength(encoding.ids().size(), context);
- Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
- Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
- Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1");
-
- Map<String, Tensor> inputs;
- if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) {
- inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
- attentionMaskName, attentionMask.expand("d0"));
- } else {
- inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
- attentionMaskName, attentionMask.expand("d0"),
- tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ public Tensor embed(String text, Context context, TensorType tensorType) {
+ if (tensorType.dimensions().size() != 1) {
+ throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension.");
}
- IndexedTensor tokenEmbeddings = (IndexedTensor) evaluateIfNotPresent(inputs,context,s).get(outputName);
- long[] resultShape = tokenEmbeddings.shape();
- //shape batch, sequence, embedding dimensionality
- if (resultShape.length != 3) {
- throw new IllegalArgumentException("" +
- "Expected 3 output dimensions for output name '" +
- outputName + "': [batch, sequence, embedding], got " + resultShape.length);
+ if (!tensorType.dimensions().get(0).isIndexed()) {
+ throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
}
- Tensor result;
- if (tensorType.valueType() == TensorType.Value.INT8) { // binary quantization
- long outputDimensions = resultShape[2];
- long targetDim = tensorType.dimensions().get(0).size().get();
- //🪆 flexibility - packing only the first 8*targetDim float values from the model output
- long floatDimensions = 8 * targetDim;
- if(floatDimensions > outputDimensions) {
- throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s");
- }
- //perform pooling and normalizing using float version before binary quantization
- TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).
- indexed(tensorType.indexedSubtype().dimensions().get(0).name(),
- floatDimensions).build();
- result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask);
- result = normalize? normalize(result, poolingType) : result;
- result = binarize((IndexedTensor) result, tensorType);
-
- } else { // regular float embeddings up to the target dimensionality
- result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
- result = normalize ? normalize(result, tensorType) : result;
+ var embeddingResult = lookupOrEvaluate(context, text);
+ IndexedTensor tokenEmbeddings = embeddingResult.output;
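+ // an int8 target type selects binary quantization; other value types get regular float pooling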
+ if (tensorType.valueType() == TensorType.Value.INT8) {
+ return binaryQuantization(embeddingResult, tensorType);
+ } else {
+ Tensor result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, embeddingResult.attentionMask);
+ return normalize ? normalize(result, tensorType) : result;
}
- runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
- return result;
}
Tensor normalize(Tensor embedding, TensorType tensorType) {
@@ -175,15 +142,56 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- /**
- * Evaluate the model if the result is not present in the context cache.
- * @param inputs the tensor inputs
- * @param context the context accompanying the request, a singleton per embedder instance and request
- * @param hashKey the key to the cached value
- * @return the model output
- */
- protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
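+ /**
+ * Returns the embedding result for the given text, computing and caching it in the request
+ * context on a miss. The cache key combines the embedder id with the input text, so multiple
+ * embedder instances can share the same context without collisions.
+ */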
+ private HuggingFaceEmbedder.HFEmbeddingResult lookupOrEvaluate(Context context, String text) {
+ var key = new HFEmbedderCacheKey(context.getEmbedderId(), text);
+ return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text));
+ }
+
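+ /**
+ * Tokenizes the text, builds the model inputs and evaluates the ONNX model, returning the
+ * token embeddings and attention mask along with the id of the embedder that produced them.
+ */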
+ private HuggingFaceEmbedder.HFEmbeddingResult evaluate(Context context, String text) {
+ var start = System.nanoTime();
+ var encoding = tokenizer.encode(text, context.getLanguage());
+ runtime.sampleSequenceLength(encoding.ids().size(), context);
+ Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
+ Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
+ Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1");
+
+ Map<String, Tensor> inputs;
+ if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"));
+ } else {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"),
+ tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ }
+ IndexedTensor tokenEmbeddings = (IndexedTensor) evaluator.evaluate(inputs).get(outputName);
+ long[] resultShape = tokenEmbeddings.shape();
+ // shape: [batch, sequence, embedding dimensionality]
+ if (resultShape.length != 3) {
+ throw new IllegalArgumentException("Expected 3 output dimensions for output name '" +
+ outputName + "': [batch, sequence, embedding], got " + resultShape.length);
+ }
+ runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
+ return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId());
+ }
+
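+ /**
+ * Pools (and optionally normalizes) the first 8 * targetDim float values of the model output,
+ * then packs each group of 8 values into a single int8 via binarize, giving targetDim packed bytes.
+ */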
+ private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType tensorType) {
+ long outputDimensions = embeddingResult.output().shape()[2];
+ long targetDim = tensorType.dimensions().get(0).size().get();
+ // 🪆 Matryoshka flexibility: pack only the first 8 * targetDim float values from the model output
+ long floatDimensions = 8 * targetDim;
+ if (floatDimensions > outputDimensions) {
+ throw new IllegalArgumentException("Cannot pack " + floatDimensions + " float values into " +
+ targetDim + " int8s; the model output only has " + outputDimensions + " dimensions");
+ }
+ // perform pooling and normalization on the float representation before binary quantization
+ TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).
+ indexed(tensorType.indexedSubtype().dimensions().get(0).name(),
+ floatDimensions).build();
+ Tensor result = poolingStrategy.toSentenceEmbedding(poolingType, embeddingResult.output(), embeddingResult.attentionMask());
+ result = normalize ? normalize(result, poolingType) : result;
+ result = binarize((IndexedTensor) result, tensorType);
+ return result;
}
/**
@@ -222,6 +230,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
-
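+ /** The token embeddings and attention mask produced by one model evaluation, tagged with the embedder id. */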
+ protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, String embedderId) {}
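+ /** Key for the context value cache: the id of the embedder and the value that was embedded. */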
+ protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) { }
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index 89f9c63ad5f..d504d77cc9b 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -1,7 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.embedding;
+package ai.vespa.embedding.huggingface;
+
-import ai.vespa.embedding.huggingface.HuggingFaceEmbedder;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
@@ -15,8 +15,7 @@ import org.junit.Test;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeTrue;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.*;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -48,8 +47,8 @@ public class HuggingFaceEmbedderTest {
private void assertPackRight(String input, String expected, TensorType type) {
Tensor inputTensor = Tensor.from(input);
Tensor result = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type);
- assertEquals(expected.toString(), result.toString());
- //Verify against what is done in ranking with unpack_bits
+ assertEquals(expected, result.toString());
+ // Verify that the unpack_bits ranking feature produces compatible output
Tensor unpacked = expandBitTensor(result);
assertEquals(inputTensor.toString(), unpacked.toString());
}
@@ -57,19 +56,34 @@ public class HuggingFaceEmbedderTest {
@Test
public void testCaching() {
var context = new Embedder.Context("schema.indexing");
+ var myEmbedderId = "my-hf-embedder";
+ context.setEmbedderId(myEmbedderId);
var input = "This is a test string to embed";
- embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[8])"));
- var modelOuput = context.getCachedValue(input);
+ Tensor result = embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[8])"));
+ HuggingFaceEmbedder.HFEmbedderCacheKey key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, input);
+ var modelOuput = context.getCachedValue(key);
+ assertNotNull(modelOuput);
- embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[4])"));
- var modelOuput2 = context.getCachedValue(input);
+ Tensor binaryResult = embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(x[4])"));
+ var modelOuput2 = context.getCachedValue(key);
assertEquals(modelOuput, modelOuput2);
+ assertNotEquals(result, binaryResult);
- var input2 = "This is a different test string to embed";
- embedder.embed(input2, context,TensorType.fromSpec("tensor<float>(x[4])"));
- var modelOuput3 = context.getCachedValue(input2);
+ var anotherInput = "This is a different test string to embed with the same embedder";
+ embedder.embed(anotherInput, context,TensorType.fromSpec("tensor<float>(x[4])"));
+ key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, anotherInput);
+ var modelOuput3 = context.getCachedValue(key);
assertNotEquals(modelOuput, modelOuput3);
+
+ // the context cache is shared between a context and its copies, but keys include the embedder id
+ var copyContext = context.copy();
+ var anotherEmbedderId = "another-hf-embedder";
+ copyContext.setEmbedderId(anotherEmbedderId);
+ key = new HuggingFaceEmbedder.HFEmbedderCacheKey(anotherEmbedderId, input);
+ assertNull(copyContext.getCachedValue(key));
+ embedder.embed(input, copyContext,TensorType.fromSpec("tensor<int8>(x[2])"));
+ assertNotEquals(modelOuput, copyContext.getCachedValue(key));
}
@Test
public void testEmbedder() {
@@ -111,6 +125,24 @@ public class HuggingFaceEmbedderTest {
assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString());
}
+ @Test
+ public void testThatWrongTensorTypeThrows() {
+ var context = new Embedder.Context("schema.indexing");
+ String input = "This is a test";
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the target tensor type is mapped
+ embedder.embed(input, context, TensorType.fromSpec("tensor<float>(x{})"));
+ });
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the target tensor type has zero dimensions
+ embedder.embed(input, context, TensorType.fromSpec("tensor<float>()"));
+ });
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the target tensor is 2d
+ embedder.embed(input, context, TensorType.fromSpec("tensor<float>(x{}, y[2])"));
+ });
+ }
+
private static HuggingFaceEmbedder getEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";