summaryrefslogtreecommitdiffstats
path: root/indexinglanguage
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2024-04-01 20:35:36 +0200
committerJon Bratseth <bratseth@vespa.ai>2024-04-01 20:35:36 +0200
commit07b1050887f40b3a008c7b18de40e6eca40aabed (patch)
tree02311ba31e0d1c39a80450734ba7e37f553f545c /indexinglanguage
parenta10de7fb58ce8b5167a6afd6082f49a0f8cc7b1b (diff)
Expose cache to embedders
Diffstat (limited to 'indexinglanguage')
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java4
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java11
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java5
3 files changed, 13 insertions, 7 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 5d5410c2ef0..05ac73618e8 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -169,9 +169,9 @@ public class EmbedExpression extends Expression {
private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
return embedder.embed(input,
- new Embedder.Context(destination).setLanguage(context.getLanguage()).setEmbedderId(embedderId),
+ new Embedder.Context(destination, context.getCache()).setLanguage(context.getLanguage())
+ .setEmbedderId(embedderId),
targetType);
-
}
@Override
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
index 1935664cddc..ba07fc00ca8 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.indexinglanguage.expressions;
+import com.yahoo.collections.LazyMap;
import com.yahoo.document.DataType;
import com.yahoo.document.FieldPath;
import com.yahoo.document.datatypes.FieldValue;
@@ -21,7 +22,7 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
private final FieldValueAdapter adapter;
private FieldValue value;
private Language language;
- private Map<String, Object> cache = null;
+ private final Map<String, Object> cache = LazyMap.newHashMap();
public ExecutionContext() {
this(null);
@@ -120,17 +121,19 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
}
public void putCachedValue(String key, Object value) {
- if (cache == null)
- cache = new HashMap<>();
cache.put(key, value);
}
/** Returns a cached value, or null if not present. */
public Object getCachedValue(String key) {
- if (cache == null) return null;
return cache.get(key);
}
+ /** Returns a mutable reference to the cache of this. */
+ public Map<String, Object> getCache() {
+ return cache;
+ }
+
/** Clears all state in this except the cache. */
public ExecutionContext clear() {
variables.clear();
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 4e1eae2ed46..f6995ac5a72 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -490,7 +490,9 @@ public class ScriptTestCase {
assertTrue(adapter.values.containsKey("mySparseTensor"));
var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"),
- sparseTensor.getTensor().get());
+ sparseTensor.getTensor().get());
+ assertEquals("Cached value always set by MockMappedEmbedder is present",
+ "myCachedValue", context.getCachedValue("myCacheKey"));
}
/** Multiple paragraphs with sparse encoding (splade style) */
@@ -626,6 +628,7 @@ public class ScriptTestCase {
@Override
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
verifyDestination(context);
+ context.putCachedValue("myCacheKey", "myCachedValue");
var b = Tensor.Builder.of(tensorType);
for (int i = 0; i < text.length(); i++)
b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition);