summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java4
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java11
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java5
-rw-r--r--linguistics/abi-spec.json5
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Embedder.java23
5 files changed, 40 insertions, 8 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 5d5410c2ef0..05ac73618e8 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -169,9 +169,9 @@ public class EmbedExpression extends Expression {
private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
return embedder.embed(input,
- new Embedder.Context(destination).setLanguage(context.getLanguage()).setEmbedderId(embedderId),
+ new Embedder.Context(destination, context.getCache()).setLanguage(context.getLanguage())
+ .setEmbedderId(embedderId),
targetType);
-
}
@Override
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
index 1935664cddc..ba07fc00ca8 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.indexinglanguage.expressions;
+import com.yahoo.collections.LazyMap;
import com.yahoo.document.DataType;
import com.yahoo.document.FieldPath;
import com.yahoo.document.datatypes.FieldValue;
@@ -21,7 +22,7 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
private final FieldValueAdapter adapter;
private FieldValue value;
private Language language;
- private Map<String, Object> cache = null;
+ private final Map<String, Object> cache = LazyMap.newHashMap();
public ExecutionContext() {
this(null);
@@ -120,17 +121,19 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
}
public void putCachedValue(String key, Object value) {
- if (cache == null)
- cache = new HashMap<>();
cache.put(key, value);
}
/** Returns a cached value, or null if not present. */
public Object getCachedValue(String key) {
- if (cache == null) return null;
return cache.get(key);
}
+ /** Returns a mutable reference to the cache of this. */
+ public Map<String, Object> getCache() {
+ return cache;
+ }
+
/** Clears all state in this except the cache. */
public ExecutionContext clear() {
variables.clear();
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 4e1eae2ed46..f6995ac5a72 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -490,7 +490,9 @@ public class ScriptTestCase {
assertTrue(adapter.values.containsKey("mySparseTensor"));
var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"),
- sparseTensor.getTensor().get());
+ sparseTensor.getTensor().get());
+ assertEquals("Cached value always set by MockMappedEmbedder is present",
+ "myCachedValue", context.getCachedValue("myCacheKey"));
}
/** Multiple paragraphs with sparse encoding (splade style) */
@@ -626,6 +628,7 @@ public class ScriptTestCase {
@Override
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
verifyDestination(context);
+ context.putCachedValue("myCacheKey", "myCachedValue");
var b = Tensor.Builder.of(tensorType);
for (int i = 0; i < text.length(); i++)
b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition);
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 1ffb879e57e..0bd4638bb05 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -338,13 +338,16 @@
],
"methods" : [
"public void <init>(java.lang.String)",
+ "public void <init>(java.lang.String, java.util.Map)",
"public com.yahoo.language.process.Embedder$Context copy()",
"public com.yahoo.language.Language getLanguage()",
"public com.yahoo.language.process.Embedder$Context setLanguage(com.yahoo.language.Language)",
"public java.lang.String getDestination()",
"public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)",
"public java.lang.String getEmbedderId()",
- "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)"
+ "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)",
+ "public void putCachedValue(java.lang.String, java.lang.Object)",
+ "public java.lang.Object getCachedValue(java.lang.String)"
],
"fields" : [ ]
},
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index c46c3ca690c..2ab2de303c2 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -2,12 +2,15 @@
package com.yahoo.language.process;
import com.yahoo.api.annotations.Beta;
+import com.yahoo.collections.LazyMap;
import com.yahoo.language.Language;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
/**
* An embedder converts a text string to a tensor
@@ -88,15 +91,26 @@ public interface Embedder {
private Language language = Language.UNKNOWN;
private String destination;
private String embedderId = "unknown";
+ private final Map<String, Object> cache;
public Context(String destination) {
+ this(destination, LazyMap.newHashMap());
+ }
+
+ /**
+ * @param destination the name of the recipient of this tensor
+ * @param cache a cache shared between all embed invocations for a single request
+ */
+ public Context(String destination, Map<String, Object> cache) {
this.destination = destination;
+ this.cache = Objects.requireNonNull(cache);
}
private Context(Context other) {
language = other.language;
destination = other.destination;
embedderId = other.embedderId;
+ this.cache = other.cache;
}
public Context copy() { return new Context(this); }
@@ -139,6 +153,15 @@ public interface Embedder {
return this;
}
+ public void putCachedValue(String key, Object value) {
+ cache.put(key, value);
+ }
+
+ /** Returns a cached value, or null if not present. */
+ public Object getCachedValue(String key) {
+ return cache.get(key);
+ }
+
}
class FailingEmbedder implements Embedder {