diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2024-04-07 20:31:37 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2024-04-07 20:31:37 +0200 |
commit | 6715471dceedbbda28d9d29ffb9d441ebfb848a2 (patch) | |
tree | e6255566e40817243f3df7a4667cf9e6822baa62 /linguistics | |
parent | 9f9160985a4f4848fa3f89d83a9f859958bd8e3c (diff) |
Key by embedder id and don't recompute inputs
Diffstat (limited to 'linguistics')
-rw-r--r-- | linguistics/abi-spec.json | 6 | ||||
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/process/Embedder.java | 15 |
2 files changed, 11 insertions, 10 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 91574133658..58e28fd7975 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -346,9 +346,9 @@ "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)", "public java.lang.String getEmbedderId()", "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)", - "public void putCachedValue(java.lang.String, java.lang.Object)", - "public java.lang.Object getCachedValue(java.lang.String)", - "public java.lang.Object computeCachedValueIfAbsent(java.lang.String, java.util.function.Supplier)" + "public void putCachedValue(java.lang.Object, java.lang.Object)", + "public java.lang.Object getCachedValue(java.lang.Object)", + "public java.lang.Object computeCachedValueIfAbsent(java.lang.Object, java.util.function.Supplier)" ], "fields" : [ ] }, diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index e53f79d98ec..989edcdb18a 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -73,9 +73,10 @@ public interface Embedder { */ @Beta interface Runtime { - /** Sample latency metric for embedding */ + + /** Add a sample embedding latency to this */ void sampleEmbeddingLatency(double millis, Context ctx); - /** Sample sequence length metric for embedding */ + /** Add a sample embedding length to this */ void sampleSequenceLength(long length, Context ctx); static Runtime testInstance() { @@ -91,7 +92,7 @@ public interface Embedder { private Language language = Language.UNKNOWN; private String destination; private String embedderId = "unknown"; - private final Map<String, Object> cache; + private final Map<Object, Object> cache; public Context(String destination) { this(destination, LazyMap.newHashMap()); @@ -101,7 +102,7 @@ public interface Embedder { * @param destination the name of the recipient of this tensor * @param cache a cache shared between all embed invocations for a single request */ - public Context(String destination, Map<String, Object> cache) { + public Context(String destination, Map<Object, Object> cache) { this.destination = destination; this.cache = Objects.requireNonNull(cache); } @@ -153,18 +154,18 @@ public interface Embedder { return this; } - public void putCachedValue(String key, Object value) { + public void putCachedValue(Object key, Object value) { cache.put(key, value); } /** Returns a cached value, or null if not present. */ - public Object getCachedValue(String key) { + public Object getCachedValue(Object key) { return cache.get(key); } /** Returns the cached value, or computes and caches it if not present. */ @SuppressWarnings("unchecked") - public <T> T computeCachedValueIfAbsent(String key, Supplier<? extends T> supplier) { + public <T> T computeCachedValueIfAbsent(Object key, Supplier<? extends T> supplier) { return (T) cache.computeIfAbsent(key, __ -> supplier.get()); } |