summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Jo Kristian Bergum <bergum@yahooinc.com>, 2024-04-04 09:16:40 +0200
committer: Jo Kristian Bergum <bergum@yahooinc.com>, 2024-04-04 09:16:40 +0200
commit: 3aadce672938ac990c261b97d0ca9d752c0d0cf6 (patch)
tree: 7b529f897ef0f25a38dc53a96df0f4a57a5cc5c3
parent: 531bc532c592703221e232d817850d802cdcfd11 (diff)
Add caching of onnx inference output using Context cache
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 49
-rw-r--r-- model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java | 24
2 files changed, 55 insertions, 18 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 169648967d7..08c98fedf3e 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -104,6 +104,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenizer.close();
}
+ @SuppressWarnings("unchecked")
@Override
public Tensor embed(String s, Context context, TensorType tensorType) {
var start = System.nanoTime();
@@ -113,7 +114,6 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1");
-
Map<String, Tensor> inputs;
if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) {
inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
@@ -123,9 +123,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
attentionMaskName, attentionMask.expand("d0"),
tokenTypeIdsName, tokenTypeIds.expand("d0"));
}
-
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName);
+ IndexedTensor tokenEmbeddings = (IndexedTensor) evaluateIfNotPresent(inputs,context,s).get(outputName);
long[] resultShape = tokenEmbeddings.shape();
//shape batch, sequence, embedding dimensionality
if (resultShape.length != 3) {
@@ -134,24 +132,23 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
outputName + "': [batch, sequence, embedding], got " + resultShape.length);
}
Tensor result;
- if (tensorType.valueType() == TensorType.Value.INT8) {
+ if (tensorType.valueType() == TensorType.Value.INT8) { // binary quantization
long outputDimensions = resultShape[2];
long targetDim = tensorType.dimensions().get(0).size().get();
-
- if(targetDim * 8 > outputDimensions) {
+ //🪆 flexibility - packing only the first 8*targetDim float values from the model output
+ long floatDimensions = 8 * targetDim;
+ if(floatDimensions > outputDimensions) {
throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s");
}
- //Dimensionality flexibility 🪆 - packing only the first 8*targetDim values from the model output
- long firstDimensions = 8 * targetDim;
- String name = tensorType.indexedSubtype().dimensions().get(0).name();
- //perform pooling and normalizing using floating point embeddings before binarizing
- //using the firstDimensions as the target dimensionality
- TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build();
+ //perform pooling and normalizing using float version before binary quantization
+ TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).
+ indexed(tensorType.indexedSubtype().dimensions().get(0).name(),
+ floatDimensions).build();
result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask);
result = normalize? normalize(result, poolingType) : result;
result = binarize((IndexedTensor) result, tensorType);
- } else { // regular floating points embeddings
+ } else { // regular float embeddings up to the target dimensionality
result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
result = normalize ? normalize(result, tensorType) : result;
}
@@ -178,6 +175,30 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ /**
+ * Returns the model output cached under the given key, evaluating the model only on a cache miss.
+ * @param inputs the tensor inputs
+ * @param context the context accompanying the request, a singleton per embedder instance and request
+ * @param hashKey the key to the cached value
+ * @return the model output, either freshly evaluated (and then cached) or taken from the context cache
+ */
+ @SuppressWarnings("unchecked")
+ protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
+ var cached = (Map<String, Tensor>) context.getCachedValue(hashKey); // read the cache once, not twice
+ if (cached != null) {
+ return cached;
+ }
+ Map<String, Tensor> outputs = evaluator.evaluate(inputs);
+ context.putCachedValue(hashKey, outputs);
+ return outputs;
+ }
+
+ /**
+ * Binary quantization of the embedding into a tensor of type int8 with the specified dimensions.
+ * @param embedding the embedding to quantize
+ * @param tensorType the type of the tensor to build
+ * @return the quantized tensor
+ */
static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);
BitSet bitSet = new BitSet(8);
diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
index 1ce1d955b00..89f9c63ad5f 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorAddress;
import org.junit.Test;
+import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -26,7 +27,6 @@ public class HuggingFaceEmbedderTest {
static HuggingFaceEmbedder embedder = getEmbedder();
static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder();
- static Embedder.Context context = new Embedder.Context("schema.indexing");
@Test
public void testBinarization() {
@@ -55,9 +55,26 @@ public class HuggingFaceEmbedderTest {
}
@Test
+ public void testCaching() {
+ // Use a context local to this test so the cache contents are not shared with other tests
+ var context = new Embedder.Context("schema.indexing");
+ var text = "This is a test string to embed";
+ embedder.embed(text, context, TensorType.fromSpec("tensor<float>(x[8])"));
+ var firstOutput = context.getCachedValue(text);
+ // Same text with a different target type: the value cached under the text must be unchanged
+ embedder.embed(text, context, TensorType.fromSpec("tensor<float>(x[4])"));
+ var secondOutput = context.getCachedValue(text);
+ assertEquals(firstOutput, secondOutput);
+ // A different text is cached under its own key with a different value
+ var otherText = "This is a different test string to embed";
+ embedder.embed(otherText, context, TensorType.fromSpec("tensor<float>(x[4])"));
+ var otherOutput = context.getCachedValue(otherText);
+ assertNotEquals(firstOutput, otherOutput);
+ }
+ @Test
public void testEmbedder() {
+ var context = new Embedder.Context("schema.indexing");
String input = "This is a test";
-
Tensor expected = Tensor.from("tensor<float>(x[8]):[-0.666, 0.335, 0.227, 0.0919, -0.069, 0.323, 0.422, 0.270]");
Tensor result = embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
for(int i = 0; i < 8; i++) {
@@ -85,10 +102,9 @@ public class HuggingFaceEmbedderTest {
@Test
public void testEmbedderWithNormalization() {
String input = "This is a test";
-
+ var context = new Embedder.Context("schema.indexing");
Tensor result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
-
result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])")));
assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
Tensor binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])")));