summaryrefslogtreecommitdiffstats
path: root/linguistics
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2024-04-01 20:35:36 +0200
committerJon Bratseth <bratseth@vespa.ai>2024-04-01 20:35:36 +0200
commit07b1050887f40b3a008c7b18de40e6eca40aabed (patch)
tree02311ba31e0d1c39a80450734ba7e37f553f545c /linguistics
parenta10de7fb58ce8b5167a6afd6082f49a0f8cc7b1b (diff)
Expose cache to embedders
Diffstat (limited to 'linguistics')
-rw-r--r--linguistics/abi-spec.json5
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Embedder.java23
2 files changed, 27 insertions, 1 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 1ffb879e57e..0bd4638bb05 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -338,13 +338,16 @@
],
"methods" : [
"public void <init>(java.lang.String)",
+ "public void <init>(java.lang.String, java.util.Map)",
"public com.yahoo.language.process.Embedder$Context copy()",
"public com.yahoo.language.Language getLanguage()",
"public com.yahoo.language.process.Embedder$Context setLanguage(com.yahoo.language.Language)",
"public java.lang.String getDestination()",
"public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)",
"public java.lang.String getEmbedderId()",
- "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)"
+ "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)",
+ "public void putCachedValue(java.lang.String, java.lang.Object)",
+ "public java.lang.Object getCachedValue(java.lang.String)"
],
"fields" : [ ]
},
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index c46c3ca690c..2ab2de303c2 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -2,12 +2,15 @@
package com.yahoo.language.process;
import com.yahoo.api.annotations.Beta;
+import com.yahoo.collections.LazyMap;
import com.yahoo.language.Language;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
/**
* An embedder converts a text string to a tensor
@@ -88,15 +91,26 @@ public interface Embedder {
private Language language = Language.UNKNOWN;
private String destination;
private String embedderId = "unknown";
+ private final Map<String, Object> cache;
public Context(String destination) {
+ this(destination, LazyMap.newHashMap());
+ }
+
+ /**
+ * @param destination the name of the recipient of this tensor
+ * @param cache a cache shared between all embed invocations for a single request
+ */
+ public Context(String destination, Map<String, Object> cache) {
this.destination = destination;
+ this.cache = Objects.requireNonNull(cache);
}
private Context(Context other) {
language = other.language;
destination = other.destination;
embedderId = other.embedderId;
+ this.cache = other.cache;
}
public Context copy() { return new Context(this); }
@@ -139,6 +153,15 @@ public interface Embedder {
return this;
}
+ public void putCachedValue(String key, Object value) {
+ cache.put(key, value);
+ }
+
+ /** Returns a cached value, or null if not present. */
+ public Object getCachedValue(String key) {
+ return cache.get(key);
+ }
+
}
class FailingEmbedder implements Embedder {