about | summary | refs | log | tree | commit | diff | stats
path: root/model-integration
diff options
context:
space:
mode:
author: Bjørn Christian Seime <bjorncs@vespa.ai> 2024-04-04 11:35:52 +0200
committer: Bjørn Christian Seime <bjorncs@vespa.ai> 2024-04-04 11:36:01 +0200
commit: 9f9160985a4f4848fa3f89d83a9f859958bd8e3c (patch)
tree: 646c33b3b4a23ad58caf73d03c808f50b5ff6e28 /model-integration
parent: 3aadce672938ac990c261b97d0ca9d752c0d0cf6 (diff)
Add an equivalent to `Map.computeIfAbsent()` to simplify typical usage of the cache.
The current interface requires a lot of boilerplate code.
Diffstat (limited to 'model-integration')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java 11
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java 12
2 files changed, 3 insertions, 20 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 2f4c0343bf6..a9d6d308df8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -196,7 +196,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return resultTensor;
}
- @SuppressWarnings("unchecked")
+
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
var start = System.nanoTime();
@@ -228,15 +228,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
* @param hashKey the key to the cached value
* @return the model output
*/
- @SuppressWarnings("unchecked")
protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- if (context.getCachedValue(hashKey) == null) {
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- context.putCachedValue(hashKey, outputs);
- return outputs;
- } else {
- return (Map<String, Tensor>) context.getCachedValue(hashKey);
- }
+ return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 08c98fedf3e..1b9f9dd2fe3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -182,22 +182,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
* @param hashKey the key to the cached value
* @return the model output
*/
- @SuppressWarnings("unchecked")
protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- if (context.getCachedValue(hashKey) == null) {
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- context.putCachedValue(hashKey, outputs);
- return outputs;
- } else {
- return (Map<String, Tensor>) context.getCachedValue(hashKey);
- }
+ return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
}
/**
* Binary quantization of the embedding into a tensor of type int8 with the specified dimensions.
- * @param embedding
- * @param tensorType
- * @return
*/
static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);