diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-04 09:16:40 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-04 09:16:40 +0200 |
commit | 3aadce672938ac990c261b97d0ca9d752c0d0cf6 (patch) | |
tree | 7b529f897ef0f25a38dc53a96df0f4a57a5cc5c3 | |
parent | 531bc532c592703221e232d817850d802cdcfd11 (diff) |
Add caching of ONNX inference output using the Context cache
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 49 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java | 24 |
2 files changed, 55 insertions, 18 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 169648967d7..08c98fedf3e 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -104,6 +104,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenizer.close(); } + @SuppressWarnings("unchecked") @Override public Tensor embed(String s, Context context, TensorType tensorType) { var start = System.nanoTime(); @@ -113,7 +114,6 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); - Map<String, Tensor> inputs; if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) { inputs = Map.of(inputIdsName, inputSequence.expand("d0"), @@ -123,9 +123,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { attentionMaskName, attentionMask.expand("d0"), tokenTypeIdsName, tokenTypeIds.expand("d0")); } - - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName); + IndexedTensor tokenEmbeddings = (IndexedTensor) evaluateIfNotPresent(inputs,context,s).get(outputName); long[] resultShape = tokenEmbeddings.shape(); //shape batch, sequence, embedding dimensionality if (resultShape.length != 3) { @@ -134,24 +132,23 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { outputName + "': [batch, sequence, embedding], got " + resultShape.length); } Tensor result; - if (tensorType.valueType() == TensorType.Value.INT8) { + if (tensorType.valueType() == 
TensorType.Value.INT8) { // binary quantization long outputDimensions = resultShape[2]; long targetDim = tensorType.dimensions().get(0).size().get(); - - if(targetDim * 8 > outputDimensions) { + //🪆 flexibility - packing only the first 8*targetDim float values from the model output + long floatDimensions = 8 * targetDim; + if(floatDimensions > outputDimensions) { throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); } - //Dimensionality flexibility 🪆 - packing only the first 8*targetDim values from the model output - long firstDimensions = 8 * targetDim; - String name = tensorType.indexedSubtype().dimensions().get(0).name(); - //perform pooling and normalizing using floating point embeddings before binarizing - //using the firstDimensions as the target dimensionality - TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build(); + //perform pooling and normalizing using float version before binary quantization + TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT). + indexed(tensorType.indexedSubtype().dimensions().get(0).name(), + floatDimensions).build(); result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask); result = normalize? normalize(result, poolingType) : result; result = binarize((IndexedTensor) result, tensorType); - } else { // regular floating points embeddings + } else { // regular float embeddings up to the target dimensionality result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); result = normalize ? normalize(result, tensorType) : result; } @@ -178,6 +175,30 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + /** + * Evaluate the model if the result is not present in the context cache. 
+ * @param inputs the tensor inputs + * @param context the context accompanying the request, a singleton per embedder instance and request + * @param hashKey the key to the cached value + * @return the model output + */ + @SuppressWarnings("unchecked") + protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) { + if (context.getCachedValue(hashKey) == null) { + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + context.putCachedValue(hashKey, outputs); + return outputs; + } else { + return (Map<String, Tensor>) context.getCachedValue(hashKey); + } + } + + /** + * Binary quantization of the embedding into a tensor of type int8 with the specified dimensions. + * @param embedding + * @param tensorType + * @return + */ static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) { Tensor.Builder builder = Tensor.Builder.of(tensorType); BitSet bitSet = new BitSet(8); diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java index 1ce1d955b00..89f9c63ad5f 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorAddress; import org.junit.Test; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assume.assumeTrue; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -26,7 +27,6 @@ public class HuggingFaceEmbedderTest { static HuggingFaceEmbedder embedder = getEmbedder(); static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder(); - static Embedder.Context context = new Embedder.Context("schema.indexing"); @Test public void testBinarization() { @@ -55,9 +55,26 @@ public class 
HuggingFaceEmbedderTest { } @Test + public void testCaching() { + var context = new Embedder.Context("schema.indexing"); + + var input = "This is a test string to embed"; + embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[8])")); + var modelOuput = context.getCachedValue(input); + + embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[4])")); + var modelOuput2 = context.getCachedValue(input); + assertEquals(modelOuput, modelOuput2); + + var input2 = "This is a different test string to embed"; + embedder.embed(input2, context,TensorType.fromSpec("tensor<float>(x[4])")); + var modelOuput3 = context.getCachedValue(input2); + assertNotEquals(modelOuput, modelOuput3); + } + @Test public void testEmbedder() { + var context = new Embedder.Context("schema.indexing"); String input = "This is a test"; - Tensor expected = Tensor.from("tensor<float>(x[8]):[-0.666, 0.335, 0.227, 0.0919, -0.069, 0.323, 0.422, 0.270]"); Tensor result = embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])"))); for(int i = 0; i < 8; i++) { @@ -85,10 +102,9 @@ public class HuggingFaceEmbedderTest { @Test public void testEmbedderWithNormalization() { String input = "This is a test"; - + var context = new Embedder.Context("schema.indexing"); Tensor result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])"))); assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3); - result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])"))); assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3); Tensor binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])"))); |