aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java11
1 files changed, 2 insertions, 9 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 2f4c0343bf6..a9d6d308df8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -196,7 +196,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return resultTensor;
}
- @SuppressWarnings("unchecked")
+
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
var start = System.nanoTime();
@@ -228,15 +228,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
* @param hashKey the key to the cached value
* @return the model output
*/
- @SuppressWarnings("unchecked")
protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- if (context.getCachedValue(hashKey) == null) {
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- context.putCachedValue(hashKey, outputs);
- return outputs;
- } else {
- return (Map<String, Tensor>) context.getCachedValue(hashKey);
- }
+ return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {