aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-04-04 09:15:10 +0200
committerJo Kristian Bergum <bergum@yahooinc.com>2024-04-04 09:15:10 +0200
commit531bc532c592703221e232d817850d802cdcfd11 (patch)
tree69d9a60d6a8ea48dbea331906e775589bce15dd7 /model-integration/src/main/java/ai
parenta009cdd704f427282c3c9ed3b70a7caf9d536c7e (diff)
Support for dimensionality flexibility and caching onnx inference output using Context cache
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java60
1 file changed, 34 insertions, 26 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index f43f3834a65..2f4c0343bf6 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -181,34 +181,25 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (tensorType.valueType() == TensorType.Value.INT8)
throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
-
var start = System.nanoTime();
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true);
-
Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings;
-
- int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
- if (dims != result.shape()[2]) {
- throw new IllegalArgumentException("Token vector dimensionality does not" +
- " match indexed dimensionality of " + dims);
- }
- Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size());
+ IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
+ Tensor resultTensor = toFloatTensor(modelOutput, tensorType, input.inputIds.size());
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return resultTensor;
}
-
+ @SuppressWarnings("unchecked")
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
var start = System.nanoTime();
+
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
@@ -218,19 +209,34 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
-
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings;
- Tensor contextualEmbeddings;
- int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens.
+ IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
+ Tensor resultEmbeddings;
+ int maxTokens = input.inputIds.size();
if (tensorType.valueType() == TensorType.Value.INT8) {
- contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
+ resultEmbeddings = toBitTensor(modelOutput, tensorType, maxTokens);
} else {
- contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens);
+ resultEmbeddings = toFloatTensor(modelOutput, tensorType, maxTokens);
}
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return contextualEmbeddings;
+ return resultEmbeddings;
+ }
+
+ /**
+ * Evaluate the model if the result is not present in the context cache.
+ * @param inputs the tensor inputs
+ * @param context the context accompanying the request, a singleton per embedder instance and request
+ * @param hashKey the key to the cached value
+ * @return the model output
+ */
+ @SuppressWarnings("unchecked")
+ protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
+ if (context.getCachedValue(hashKey) == null) {
+ Map<String, Tensor> outputs = evaluator.evaluate(inputs);
+ context.putCachedValue(hashKey, outputs);
+ return outputs;
+ } else {
+ return (Map<String, Tensor>) context.getCachedValue(hashKey);
+ }
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
@@ -241,13 +247,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
int resultDimensionality = (int)result.shape()[2];
- if (resultDimensionality != wantedDimensionality) {
+ if (wantedDimensionality > resultDimensionality) {
throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality +
" dimensions into tensor with " + wantedDimensionality);
}
Tensor.Builder builder = Tensor.Builder.of(type);
for (int token = 0; token < nTokens; token++) {
- for (int d = 0; d < resultDimensionality; d++) {
+ for (int d = 0; d < wantedDimensionality; d++) {
var value = result.get(0,token,d); // batch, sequence token, dimension
builder.cell(TensorAddress.of(token,d),value);
}
@@ -265,8 +271,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (size != 1)
throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
+ //Allow using the first n float dimensions to pack into int8
+ int floatDimensionality = 8 * wantedDimensionality;
int resultDimensionality = (int)result.shape()[2];
- if (resultDimensionality != 8 * wantedDimensionality) {
+ if (floatDimensionality > resultDimensionality) {
throw new IllegalArgumentException("Not possible to pack " + resultDimensionality +
" + dimensions into " + wantedDimensionality + " dimensions");
}
@@ -274,7 +282,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
for (int token = 0; token < nTokens; token++) {
BitSet bitSet = new BitSet(8);
int key = 0;
- for (int d = 0; d < result.shape()[2]; d++) {
+ for (int d = 0; d < floatDimensionality; d++) {
var value = result.get(0, token, d); // batch, sequence token, dimension
int bitIndex = 7 - (d % 8);
if (value > 0.0) {