aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--linguistics/abi-spec.json3
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Embedder.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java12
4 files changed, 12 insertions, 22 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 0bd4638bb05..91574133658 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -347,7 +347,8 @@
"public java.lang.String getEmbedderId()",
"public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)",
"public void putCachedValue(java.lang.String, java.lang.Object)",
- "public java.lang.Object getCachedValue(java.lang.String)"
+ "public java.lang.Object getCachedValue(java.lang.String)",
+ "public java.lang.Object computeCachedValueIfAbsent(java.lang.String, java.util.function.Supplier)"
],
"fields" : [ ]
},
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index 2ab2de303c2..e53f79d98ec 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -7,10 +7,10 @@ import com.yahoo.language.Language;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.function.Supplier;
/**
* An embedder converts a text string to a tensor
@@ -162,6 +162,12 @@ public interface Embedder {
return cache.get(key);
}
+ /** Returns the cached value, or computes and caches it if not present. */
+ @SuppressWarnings("unchecked")
+ public <T> T computeCachedValueIfAbsent(String key, Supplier<? extends T> supplier) {
+ return (T) cache.computeIfAbsent(key, __ -> supplier.get());
+ }
+
}
class FailingEmbedder implements Embedder {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 2f4c0343bf6..a9d6d308df8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -196,7 +196,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return resultTensor;
}
- @SuppressWarnings("unchecked")
+
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
var start = System.nanoTime();
@@ -228,15 +228,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
* @param hashKey the key to the cached value
* @return the model output
*/
- @SuppressWarnings("unchecked")
protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- if (context.getCachedValue(hashKey) == null) {
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- context.putCachedValue(hashKey, outputs);
- return outputs;
- } else {
- return (Map<String, Tensor>) context.getCachedValue(hashKey);
- }
+ return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 08c98fedf3e..1b9f9dd2fe3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -182,22 +182,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
* @param hashKey the key to the cached value
* @return the model output
*/
- @SuppressWarnings("unchecked")
protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- if (context.getCachedValue(hashKey) == null) {
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- context.putCachedValue(hashKey, outputs);
- return outputs;
- } else {
- return (Map<String, Tensor>) context.getCachedValue(hashKey);
- }
+ return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
}
/**
* Binary quantization of the embedding into a tensor of type int8 with the specified dimensions.
- * @param embedding
- * @param tensorType
- * @return
*/
static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);