summaryrefslogtreecommitdiffstats
path: root/linguistics
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2024-04-07 20:31:37 +0200
committerJon Bratseth <bratseth@vespa.ai>2024-04-07 20:31:37 +0200
commit6715471dceedbbda28d9d29ffb9d441ebfb848a2 (patch)
treee6255566e40817243f3df7a4667cf9e6822baa62 /linguistics
parent9f9160985a4f4848fa3f89d83a9f859958bd8e3c (diff)
Key by embedder id and don't recompute inputs
Diffstat (limited to 'linguistics')
-rw-r--r--linguistics/abi-spec.json6
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Embedder.java15
2 files changed, 11 insertions, 10 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 91574133658..58e28fd7975 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -346,9 +346,9 @@
"public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)",
"public java.lang.String getEmbedderId()",
"public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)",
- "public void putCachedValue(java.lang.String, java.lang.Object)",
- "public java.lang.Object getCachedValue(java.lang.String)",
- "public java.lang.Object computeCachedValueIfAbsent(java.lang.String, java.util.function.Supplier)"
+ "public void putCachedValue(java.lang.Object, java.lang.Object)",
+ "public java.lang.Object getCachedValue(java.lang.Object)",
+ "public java.lang.Object computeCachedValueIfAbsent(java.lang.Object, java.util.function.Supplier)"
],
"fields" : [ ]
},
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index e53f79d98ec..989edcdb18a 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -73,9 +73,10 @@ public interface Embedder {
*/
@Beta
interface Runtime {
- /** Sample latency metric for embedding */
+
+ /** Add a sample embedding latency to this */
void sampleEmbeddingLatency(double millis, Context ctx);
- /** Sample sequence length metric for embedding */
+ /** Add a sample embedding length to this */
void sampleSequenceLength(long length, Context ctx);
static Runtime testInstance() {
@@ -91,7 +92,7 @@ public interface Embedder {
private Language language = Language.UNKNOWN;
private String destination;
private String embedderId = "unknown";
- private final Map<String, Object> cache;
+ private final Map<Object, Object> cache;
public Context(String destination) {
this(destination, LazyMap.newHashMap());
@@ -101,7 +102,7 @@ public interface Embedder {
* @param destination the name of the recipient of this tensor
* @param cache a cache shared between all embed invocations for a single request
*/
- public Context(String destination, Map<String, Object> cache) {
+ public Context(String destination, Map<Object, Object> cache) {
this.destination = destination;
this.cache = Objects.requireNonNull(cache);
}
@@ -153,18 +154,18 @@ public interface Embedder {
return this;
}
- public void putCachedValue(String key, Object value) {
+ public void putCachedValue(Object key, Object value) {
cache.put(key, value);
}
/** Returns a cached value, or null if not present. */
- public Object getCachedValue(String key) {
+ public Object getCachedValue(Object key) {
return cache.get(key);
}
/** Returns the cached value, or computes and caches it if not present. */
@SuppressWarnings("unchecked")
- public <T> T computeCachedValueIfAbsent(String key, Supplier<? extends T> supplier) {
+ public <T> T computeCachedValueIfAbsent(Object key, Supplier<? extends T> supplier) {
return (T) cache.computeIfAbsent(key, __ -> supplier.get());
}