From 9f9160985a4f4848fa3f89d83a9f859958bd8e3c Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Thu, 4 Apr 2024 11:35:52 +0200 Subject: Add equivalent to `Map.computeIfAbsent()` to simplify typical usage of the cache Current interface requires a lot of boilerplate code. --- linguistics/abi-spec.json | 3 ++- .../src/main/java/com/yahoo/language/process/Embedder.java | 8 +++++++- .../src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 11 ++--------- .../ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 12 +----------- 4 files changed, 12 insertions(+), 22 deletions(-) diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 0bd4638bb05..91574133658 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -347,7 +347,8 @@ "public java.lang.String getEmbedderId()", "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)", "public void putCachedValue(java.lang.String, java.lang.Object)", - "public java.lang.Object getCachedValue(java.lang.String)" + "public java.lang.Object getCachedValue(java.lang.String)", + "public java.lang.Object computeCachedValueIfAbsent(java.lang.String, java.util.function.Supplier)" ], "fields" : [ ] }, diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index 2ab2de303c2..e53f79d98ec 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -7,10 +7,10 @@ import com.yahoo.language.Language; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; /** * An embedder converts a text string to a tensor @@ -162,6 +162,12 @@ public interface Embedder { return cache.get(key); } + /** Returns the cached value, or computes and caches it if not present. */ + @SuppressWarnings("unchecked") + public T computeCachedValueIfAbsent(String key, Supplier supplier) { + return (T) cache.computeIfAbsent(key, __ -> supplier.get()); + } + } class FailingEmbedder implements Embedder { diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 2f4c0343bf6..a9d6d308df8 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -196,7 +196,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return resultTensor; } - @SuppressWarnings("unchecked") + protected Tensor embedDocument(String text, Context context, TensorType tensorType) { var start = System.nanoTime(); @@ -228,15 +228,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { * @param hashKey the key to the cached value * @return the model output */ - @SuppressWarnings("unchecked") protected Map evaluateIfNotPresent(Map inputs, Context context, String hashKey) { - if (context.getCachedValue(hashKey) == null) { - Map outputs = evaluator.evaluate(inputs); - context.putCachedValue(hashKey, outputs); - return outputs; - } else { - return (Map) context.getCachedValue(hashKey); - } + return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs)); } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 08c98fedf3e..1b9f9dd2fe3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -182,22 +182,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { * @param hashKey the key to the cached value * @return the model output */ - @SuppressWarnings("unchecked") protected Map evaluateIfNotPresent(Map inputs, Context context, String hashKey) { - if (context.getCachedValue(hashKey) == null) { - Map outputs = evaluator.evaluate(inputs); - context.putCachedValue(hashKey, outputs); - return outputs; - } else { - return (Map) context.getCachedValue(hashKey); - } + return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs)); } /** * Binary quantization of the embedding into a tensor of type int8 with the specified dimensions. - * @param embedding - * @param tensorType - * @return */ static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) { Tensor.Builder builder = Tensor.Builder.of(tensorType); -- cgit v1.2.3