aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java12
2 files changed, 3 insertions, 20 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 2f4c0343bf6..a9d6d308df8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -196,7 +196,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return resultTensor;
}
- @SuppressWarnings("unchecked")
+
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
var start = System.nanoTime();
@@ -228,15 +228,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
* @param hashKey the key to the cached value
* @return the model output
*/
- @SuppressWarnings("unchecked")
protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- if (context.getCachedValue(hashKey) == null) {
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- context.putCachedValue(hashKey, outputs);
- return outputs;
- } else {
- return (Map<String, Tensor>) context.getCachedValue(hashKey);
- }
+ return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 08c98fedf3e..1b9f9dd2fe3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -182,22 +182,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
* @param hashKey the key to the cached value
* @return the model output
*/
- @SuppressWarnings("unchecked")
protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- if (context.getCachedValue(hashKey) == null) {
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- context.putCachedValue(hashKey, outputs);
- return outputs;
- } else {
- return (Map<String, Tensor>) context.getCachedValue(hashKey);
- }
+ return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
}
/**
* Binary quantization of the embedding into a tensor of type int8 with the specified dimensions.
- * @param embedding
- * @param tensorType
- * @return
*/
static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);