diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-04 09:15:10 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-04 09:15:10 +0200 |
commit | 531bc532c592703221e232d817850d802cdcfd11 (patch) | |
tree | 69d9a60d6a8ea48dbea331906e775589bce15dd7 /model-integration/src/main/java/ai | |
parent | a009cdd704f427282c3c9ed3b70a7caf9d536c7e (diff) |
Support for dimensionality flexibility and caching ONNX inference output using the Context cache
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 60 |
1 file changed, 34 insertions, 26 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index f43f3834a65..2f4c0343bf6 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -181,34 +181,25 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (tensorType.valueType() == TensorType.Value.INT8) throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); - var start = System.nanoTime(); var encoding = tokenizer.encode(text, context.getLanguage()); runtime.sampleSequenceLength(encoding.ids().size(), context); TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true); - Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings; - - int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); - if (dims != result.shape()[2]) { - throw new IllegalArgumentException("Token vector dimensionality does not" + - " match indexed dimensionality of " + dims); - } - Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size()); + IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName); + Tensor resultTensor = toFloatTensor(modelOutput, tensorType, input.inputIds.size()); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return resultTensor; } - + @SuppressWarnings("unchecked") protected Tensor embedDocument(String 
text, Context context, TensorType tensorType) { var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); runtime.sampleSequenceLength(encoding.ids().size(), context); @@ -218,19 +209,34 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); - - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings; - Tensor contextualEmbeddings; - int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens. + IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName); + Tensor resultEmbeddings; + int maxTokens = input.inputIds.size(); if (tensorType.valueType() == TensorType.Value.INT8) { - contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); + resultEmbeddings = toBitTensor(modelOutput, tensorType, maxTokens); } else { - contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens); + resultEmbeddings = toFloatTensor(modelOutput, tensorType, maxTokens); } runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return contextualEmbeddings; + return resultEmbeddings; + } + + /** + * Evaluate the model if the result is not present in the context cache. 
+ * @param inputs the tensor inputs + * @param context the context accompanying the request, a singleton per embedder instance and request + * @param hashKey the key to the cached value + * @return the model output + */ + @SuppressWarnings("unchecked") + protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) { + if (context.getCachedValue(hashKey) == null) { + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + context.putCachedValue(hashKey, outputs); + return outputs; + } else { + return (Map<String, Tensor>) context.getCachedValue(hashKey); + } } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { @@ -241,13 +247,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[2]; - if (resultDimensionality != wantedDimensionality) { + if (wantedDimensionality > resultDimensionality) { throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " dimensions into tensor with " + wantedDimensionality); } Tensor.Builder builder = Tensor.Builder.of(type); for (int token = 0; token < nTokens; token++) { - for (int d = 0; d < resultDimensionality; d++) { + for (int d = 0; d < wantedDimensionality; d++) { var value = result.get(0,token,d); // batch, sequence token, dimension builder.cell(TensorAddress.of(token,d),value); } @@ -265,8 +271,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (size != 1) throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + //Allow using the first n float dimensions to pack into int8 + 
int floatDimensionality = 8 * wantedDimensionality; int resultDimensionality = (int)result.shape()[2]; - if (resultDimensionality != 8 * wantedDimensionality) { + if (floatDimensionality > resultDimensionality) { throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " + dimensions into " + wantedDimensionality + " dimensions"); } @@ -274,7 +282,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { for (int token = 0; token < nTokens; token++) { BitSet bitSet = new BitSet(8); int key = 0; - for (int d = 0; d < result.shape()[2]; d++) { + for (int d = 0; d < floatDimensionality; d++) { var value = result.get(0, token, d); // batch, sequence token, dimension int bitIndex = 7 - (d % 8); if (value > 0.0) { |