From 9f9160985a4f4848fa3f89d83a9f859958bd8e3c Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Thu, 4 Apr 2024 11:35:52 +0200 Subject: Add equivalent to `Map.computeIfAbsent()` to simplify typical usage of the cache Current interface requires a lot of boilerplate code. --- .../src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 11 ++--------- .../ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 12 +----------- 2 files changed, 3 insertions(+), 20 deletions(-) (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 2f4c0343bf6..a9d6d308df8 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -196,7 +196,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return resultTensor; } - @SuppressWarnings("unchecked") + protected Tensor embedDocument(String text, Context context, TensorType tensorType) { var start = System.nanoTime(); @@ -228,15 +228,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { * @param hashKey the key to the cached value * @return the model output */ - @SuppressWarnings("unchecked") protected Map evaluateIfNotPresent(Map inputs, Context context, String hashKey) { - if (context.getCachedValue(hashKey) == null) { - Map outputs = evaluator.evaluate(inputs); - context.putCachedValue(hashKey, outputs); - return outputs; - } else { - return (Map) context.getCachedValue(hashKey); - } + return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs)); } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 08c98fedf3e..1b9f9dd2fe3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -182,22 +182,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { * @param hashKey the key to the cached value * @return the model output */ - @SuppressWarnings("unchecked") protected Map evaluateIfNotPresent(Map inputs, Context context, String hashKey) { - if (context.getCachedValue(hashKey) == null) { - Map outputs = evaluator.evaluate(inputs); - context.putCachedValue(hashKey, outputs); - return outputs; - } else { - return (Map) context.getCachedValue(hashKey); - } + return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs)); } /** * Binary quantization of the embedding into a tensor of type int8 with the specified dimensions. - * @param embedding - * @param tensorType - * @return */ static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) { Tensor.Builder builder = Tensor.Builder.of(tensorType); -- cgit v1.2.3