aboutsummaryrefslogtreecommitdiffstats
path: root/indexinglanguage
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-10-01 11:09:08 +0200
committerJon Bratseth <bratseth@gmail.com>2021-10-01 11:09:08 +0200
commitac2519a8842a6397e4abd434439e9dddd2924394 (patch)
tree792275efbb88966a27a7ce54cc31465b563d7ad0 /indexinglanguage
parent380b9fa780ead9bcce0e824f7b6ee305e37dec43 (diff)
Encapsulate in a context
Diffstat (limited to 'indexinglanguage')
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java4
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java6
2 files changed, 6 insertions, 4 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 043a30ce66d..66d912cd987 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -40,7 +40,9 @@ public class EmbedExpression extends Expression {
@Override
protected void doExecute(ExecutionContext context) {
StringFieldValue input = (StringFieldValue) context.getValue();
- Tensor tensor = embedder.embed(input.getString(), context.getLanguage(), destination, targetType);
+ Tensor tensor = embedder.embed(input.getString(),
+ new Embedder.Context(destination).setLanguage(context.getLanguage()),
+ targetType);
context.setValue(new TensorFieldValue(tensor));
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index e0c0a9faba8..f193ac1a4c8 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -135,13 +135,13 @@ public class ScriptTestCase {
}
@Override
- public List<Integer> embed(String text, Language language, String destination) {
+ public List<Integer> embed(String text, Embedder.Context context) {
return null;
}
@Override
- public Tensor embed(String text, Language language, String destination, TensorType tensorType) {
- assertEquals(expectedDestination, destination);
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ assertEquals(expectedDestination, context.getDestination());
return Tensor.from(tensorType, "[7,3,0,0]");
}