diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-08 21:52:40 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-08 21:52:40 +0200 |
commit | 4d233b5379b8dc4b94901f8df8acda0a6f2c4420 (patch) | |
tree | 006e1ec72bc0ae46a86b6cded4c72f936ac45483 | |
parent | 6715471dceedbbda28d9d29ffb9d441ebfb848a2 (diff) |
cache more and re-factor
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 119 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java (renamed from model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java) | 58 |
2 files changed, 109 insertions, 68 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 1b9f9dd2fe3..20d8b6362d3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -106,54 +106,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @SuppressWarnings("unchecked") @Override - public Tensor embed(String s, Context context, TensorType tensorType) { - var start = System.nanoTime(); - var encoding = tokenizer.encode(s, context.getLanguage()); - runtime.sampleSequenceLength(encoding.ids().size(), context); - Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); - Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); - Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); - - Map<String, Tensor> inputs; - if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) { - inputs = Map.of(inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0")); - } else { - inputs = Map.of(inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0"), - tokenTypeIdsName, tokenTypeIds.expand("d0")); + public Tensor embed(String text, Context context, TensorType tensorType) { + if (tensorType.dimensions().size() != 1) { + throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension."); } - IndexedTensor tokenEmbeddings = (IndexedTensor) evaluateIfNotPresent(inputs,context,s).get(outputName); - long[] resultShape = tokenEmbeddings.shape(); - //shape batch, sequence, embedding dimensionality - if (resultShape.length != 3) { - throw new IllegalArgumentException("" + - "Expected 3 output dimensions for output name '" + - outputName + "': [batch, sequence, embedding], got " + resultShape.length); + if (!tensorType.dimensions().get(0).isIndexed()) { + throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed."); } - Tensor result; - if (tensorType.valueType() == TensorType.Value.INT8) { // binary quantization - long outputDimensions = resultShape[2]; - long targetDim = tensorType.dimensions().get(0).size().get(); - //🪆 flexibility - packing only the first 8*targetDim float values from the model output - long floatDimensions = 8 * targetDim; - if(floatDimensions > outputDimensions) { - throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); - } - //perform pooling and normalizing using float version before binary quantization - TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT). - indexed(tensorType.indexedSubtype().dimensions().get(0).name(), - floatDimensions).build(); - result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask); - result = normalize? normalize(result, poolingType) : result; - result = binarize((IndexedTensor) result, tensorType); - - } else { // regular float embeddings up to the target dimensionality - result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); - result = normalize ? normalize(result, tensorType) : result; + var embeddingResult = lookupOrEvaluate(context, text); + IndexedTensor tokenEmbeddings = embeddingResult.output; + if (tensorType.valueType() == TensorType.Value.INT8) { + return binaryQuantization(embeddingResult, tensorType); + } else { + Tensor result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, embeddingResult.attentionMask); + return normalize ? normalize(result, tensorType) : result; } - runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); - return result; } Tensor normalize(Tensor embedding, TensorType tensorType) { @@ -175,15 +142,56 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - /** - * Evaluate the model if the result is not present in the context cache. - * @param inputs the tensor inputs - * @param context the context accompanying the request, a singleton per embedder instance and request - * @param hashKey the key to the cached value - * @return the model output - */ - protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) { - return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs)); + private HuggingFaceEmbedder.HFEmbeddingResult lookupOrEvaluate(Context context, String text) { + var key = new HFEmbedderCacheKey(context.getEmbedderId(), text); + return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text)); + } + + private HuggingFaceEmbedder.HFEmbeddingResult evaluate(Context context, String text) { + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); + Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); + Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); + + Map<String, Tensor> inputs; + if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0")); + } else { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0"), + tokenTypeIdsName, tokenTypeIds.expand("d0")); + } + IndexedTensor tokenEmbeddings = (IndexedTensor) evaluator.evaluate(inputs).get(outputName); + long[] resultShape = tokenEmbeddings.shape(); + //shape batch, sequence, embedding dimensionality + if (resultShape.length != 3) { + throw new IllegalArgumentException("" + + "Expected 3 output dimensions for output name '" + + outputName + "': [batch, sequence, embedding], got " + resultShape.length); + } + runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); + return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId()); + } + + private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType tensorType) { + long outputDimensions = embeddingResult.output().shape()[2]; + long targetDim = tensorType.dimensions().get(0).size().get(); + //🪆 flexibility - packing only the first 8*targetDim float values from the model output + long floatDimensions = 8 * targetDim; + if(floatDimensions > outputDimensions) { + throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); + } + //perform pooling and normalizing using float version before binary quantization + TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT). + indexed(tensorType.indexedSubtype().dimensions().get(0).name(), + floatDimensions).build(); + Tensor result = poolingStrategy.toSentenceEmbedding(poolingType, embeddingResult.output(), embeddingResult.attentionMask()); + result = normalize? normalize(result, poolingType) : result; + result = binarize((IndexedTensor) result, tensorType); + return result; } /** @@ -222,6 +230,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - + protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, String embedderId) {} + protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) { } } diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java index 89f9c63ad5f..d504d77cc9b 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java @@ -1,7 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.embedding; +package ai.vespa.embedding.huggingface; + -import ai.vespa.embedding.huggingface.HuggingFaceEmbedder; import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; @@ -15,8 +15,7 @@ import org.junit.Test; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assume.assumeTrue; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -48,8 +47,8 @@ public class HuggingFaceEmbedderTest { private void assertPackRight(String input, String expected, TensorType type) { Tensor inputTensor = Tensor.from(input); Tensor result = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type); - assertEquals(expected.toString(), result.toString()); - //Verify against what is done in ranking with unpack_bits + assertEquals(expected, result.toString()); + //Verify that the unpack_bits ranking feature produce compatible output Tensor unpacked = expandBitTensor(result); assertEquals(inputTensor.toString(), unpacked.toString()); } @@ -57,19 +56,34 @@ public class HuggingFaceEmbedderTest { @Test public void testCaching() { var context = new Embedder.Context("schema.indexing"); + var myEmbedderId = "my-hf-embedder"; + context.setEmbedderId(myEmbedderId); var input = "This is a test string to embed"; - embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[8])")); - var modelOuput = context.getCachedValue(input); + Tensor result = embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[8])")); + HuggingFaceEmbedder.HFEmbedderCacheKey key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, input); + var modelOuput = context.getCachedValue(key); + assertNotNull(modelOuput); - embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[4])")); - var modelOuput2 = context.getCachedValue(input); + Tensor binaryResult = embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(x[4])")); + var modelOuput2 = context.getCachedValue(key); assertEquals(modelOuput, modelOuput2); + assertNotEquals(result, binaryResult); - var input2 = "This is a different test string to embed"; - embedder.embed(input2, context,TensorType.fromSpec("tensor<float>(x[4])")); - var modelOuput3 = context.getCachedValue(input2); + var anotherInput = "This is a different test string to embed with the same embedder"; + embedder.embed(anotherInput, context,TensorType.fromSpec("tensor<float>(x[4])")); + key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, anotherInput); + var modelOuput3 = context.getCachedValue(key); assertNotEquals(modelOuput, modelOuput3); + + //context cache is shared + var copyContext = context.copy(); + var anotherEmbedderId = "another-hf-embedder"; + copyContext.setEmbedderId(anotherEmbedderId); + key = new HuggingFaceEmbedder.HFEmbedderCacheKey(anotherEmbedderId, input); + assertNull(copyContext.getCachedValue(key)); + embedder.embed(input, copyContext,TensorType.fromSpec("tensor<int8>(x[2])")); + assertNotEquals(modelOuput, copyContext.getCachedValue(key)); } @Test public void testEmbedder() { @@ -111,6 +125,24 @@ public class HuggingFaceEmbedderTest { assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString()); } + @Test + public void testThatWrongTensorTypeThrows() { + var context = new Embedder.Context("schema.indexing"); + String input = "This is a test"; + assertThrows(IllegalArgumentException.class, () -> { + // throws because the target tensor type is mapped + embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x{})"))); + }); + assertThrows(IllegalArgumentException.class, () -> { + // throws because the target tensor is 0d + embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[0]"))); + }); + assertThrows(IllegalArgumentException.class, () -> { + // throws because the target tensor is 2d + embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x{}, y[2])"))); + }); + } + private static HuggingFaceEmbedder getEmbedder() { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx"; |