diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2024-04-01 20:35:36 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2024-04-01 20:35:36 +0200 |
commit | 07b1050887f40b3a008c7b18de40e6eca40aabed (patch) | |
tree | 02311ba31e0d1c39a80450734ba7e37f553f545c /indexinglanguage | |
parent | a10de7fb58ce8b5167a6afd6082f49a0f8cc7b1b (diff) |
Expose cache to embedders
Diffstat (limited to 'indexinglanguage')
3 files changed, 13 insertions, 7 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 5d5410c2ef0..05ac73618e8 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -169,9 +169,9 @@ public class EmbedExpression extends Expression { private Tensor embed(String input, TensorType targetType, ExecutionContext context) { return embedder.embed(input, - new Embedder.Context(destination).setLanguage(context.getLanguage()).setEmbedderId(embedderId), + new Embedder.Context(destination, context.getCache()).setLanguage(context.getLanguage()) + .setEmbedderId(embedderId), targetType); - } @Override diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java index 1935664cddc..ba07fc00ca8 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.indexinglanguage.expressions; +import com.yahoo.collections.LazyMap; import com.yahoo.document.DataType; import com.yahoo.document.FieldPath; import com.yahoo.document.datatypes.FieldValue; @@ -21,7 +22,7 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter { private final FieldValueAdapter adapter; private FieldValue value; private Language language; - private Map<String, Object> cache = null; + private final Map<String, Object> cache = LazyMap.newHashMap(); public ExecutionContext() { this(null); @@ -120,17 +121,19 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter { } public void putCachedValue(String key, Object value) { - if (cache == null) - cache = new HashMap<>(); cache.put(key, value); } /** Returns a cached value, or null if not present. */ public Object getCachedValue(String key) { - if (cache == null) return null; return cache.get(key); } + /** Returns a mutable reference to the cache of this. */ + public Map<String, Object> getCache() { + return cache; + } + /** Clears all state in this except the cache. */ public ExecutionContext clear() { variables.clear(); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 4e1eae2ed46..f6995ac5a72 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -490,7 +490,9 @@ public class ScriptTestCase { assertTrue(adapter.values.containsKey("mySparseTensor")); var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"), - sparseTensor.getTensor().get()); + sparseTensor.getTensor().get()); + assertEquals("Cached value always set by MockMappedEmbedder is present", + "myCachedValue", context.getCachedValue("myCacheKey")); } /** Multiple paragraphs with sparse encoding (splade style) */ @@ -626,6 +628,7 @@ public class ScriptTestCase { @Override public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { verifyDestination(context); + context.putCachedValue("myCacheKey", "myCachedValue"); var b = Tensor.Builder.of(tensorType); for (int i = 0; i < text.length(); i++) b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition); |