aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-10-01 11:09:08 +0200
committerJon Bratseth <bratseth@gmail.com>2021-10-01 11:09:08 +0200
commitac2519a8842a6397e4abd434439e9dddd2924394 (patch)
tree792275efbb88966a27a7ce54cc31465b563d7ad0
parent380b9fa780ead9bcce0e824f7b6ee305e37dec43 (diff)
Encapsulate in a context
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java7
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java8
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java4
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java6
-rw-r--r--linguistics-components/abi-spec.json4
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java16
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java5
-rw-r--r--linguistics/abi-spec.json23
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Embedder.java58
9 files changed, 93 insertions, 38 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index 05befb24da0..76fb28ebb34 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.query.profile.types;
+import com.yahoo.language.process.Embedder;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -52,7 +53,11 @@ public class TensorFieldType extends FieldType {
if ( ! s.endsWith(")"))
throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
String text = s.substring("embed(".length(), s.length() - 1);
- return context.embedder().embed(text, context.language(), context.destination(), type);
+ return context.embedder().embed(text, toEmbedderContext(context), type);
+ }
+
+ private Embedder.Context toEmbedderContext(ConversionContext context) {
+ return new Embedder.Context(context.destination()).setLanguage(context.language());
}
public static TensorFieldType fromTypeString(String s) {
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
index e63c7711ff2..f11e5614635 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
@@ -747,16 +747,16 @@ public class QueryProfileTypeTestCase {
}
@Override
- public List<Integer> embed(String text, Language language, String destination) {
+ public List<Integer> embed(String text, Embedder.Context context) {
fail("Unexpected call");
return null;
}
@Override
- public Tensor embed(String text, Language language, String destination, TensorType tensorType) {
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
assertEquals(expectedText, text);
- assertEquals(expectedLanguage, language);
- assertEquals(expectedDestination, destination);
+ assertEquals(expectedLanguage, context.getLanguage());
+ assertEquals(expectedDestination, context.getDestination());
assertEquals(tensorToReturn.type(), tensorType);
return tensorToReturn;
}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 043a30ce66d..66d912cd987 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -40,7 +40,9 @@ public class EmbedExpression extends Expression {
@Override
protected void doExecute(ExecutionContext context) {
StringFieldValue input = (StringFieldValue) context.getValue();
- Tensor tensor = embedder.embed(input.getString(), context.getLanguage(), destination, targetType);
+ Tensor tensor = embedder.embed(input.getString(),
+ new Embedder.Context(destination).setLanguage(context.getLanguage()),
+ targetType);
context.setValue(new TensorFieldValue(tensor));
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index e0c0a9faba8..f193ac1a4c8 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -135,13 +135,13 @@ public class ScriptTestCase {
}
@Override
- public List<Integer> embed(String text, Language language, String destination) {
+ public List<Integer> embed(String text, Embedder.Context context) {
return null;
}
@Override
- public Tensor embed(String text, Language language, String destination, TensorType tensorType) {
- assertEquals(expectedDestination, destination);
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ assertEquals(expectedDestination, context.getDestination());
return Tensor.from(tensorType, "[7,3,0,0]");
}
diff --git a/linguistics-components/abi-spec.json b/linguistics-components/abi-spec.json
index ebd7457dc71..28025d84f25 100644
--- a/linguistics-components/abi-spec.json
+++ b/linguistics-components/abi-spec.json
@@ -180,8 +180,8 @@
"public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
"public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)",
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
- "public java.util.List embed(java.lang.String, com.yahoo.language.Language, java.lang.String)",
- "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, java.lang.String, com.yahoo.tensor.TensorType)",
+ "public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
+ "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)",
"public java.lang.String normalize(java.lang.String)"
],
"fields": []
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 1e120969a59..3f4e8ee3462 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -72,18 +72,17 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
* Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids.
*
* @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
- * @param language the model to use, or Language.UNKNOWN to use the default model if any
- * @param destination ignored
+ * @param context the context which specifies the language used to select a model
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
- public List<Integer> embed(String rawInput, Language language, String destination) {
+ public List<Integer> embed(String rawInput, Embedder.Context context) {
var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
result().add(segmentEnds[segmentEnd].id);
}
};
- segment(normalize(rawInput), language, resultBuilder);
+ segment(normalize(rawInput), context.getLanguage(), resultBuilder);
Collections.reverse(resultBuilder.result());
return resultBuilder.result();
}
@@ -101,15 +100,14 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
* <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
*
* @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
- * @param language the model to use, or Language.UNKNOWN to use the default model if any
- * @param destination ignored
+ * @param context the context which specifies the language used to select a model
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
- public Tensor embed(String rawInput, Language language, String destination, TensorType type) {
+ public Tensor embed(String rawInput, Embedder.Context context, TensorType type) {
if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
// Build to a list first since we can't reverse a tensor builder
- List<Integer> values = embed(rawInput, language, destination);
+ List<Integer> values = embed(rawInput, context);
long maxSize = values.size();
if (type.dimensions().get(0).size().isPresent())
@@ -122,7 +120,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
}
else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) {
// Build to a list first since we can't reverse a tensor builder
- List<String> values = segment(rawInput, language);
+ List<String> values = segment(rawInput, context.getLanguage());
Tensor.Builder builder = Tensor.Builder.of(type);
for (int i = 0; i < values.size(); i++)
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
index c6aa8fdd370..4dae53c60df 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -4,6 +4,7 @@
package com.yahoo.language.sentencepiece;
import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -29,13 +30,13 @@ class SentencePieceTester {
}
public void assertEmbedded(String input, Integer... expectedCodes) {
- assertArrayEquals(expectedCodes, embedder.embed(input, Language.UNKNOWN, null).toArray());
+ assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray());
}
public void assertEmbedded(String input, String tensorType, String tensor) {
TensorType type = TensorType.fromSpec(tensorType);
Tensor expected = Tensor.from(type, tensor);
- assertEquals(expected, embedder.embed(input, Language.UNKNOWN, null, type));
+ assertEquals(expected, embedder.embed(input, new Embedder.Context("test"), type));
}
public void assertSegmented(String input, String... expectedSegments) {
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 5865c28bbb6..31612bea983 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -328,6 +328,21 @@
],
"fields": []
},
+ "com.yahoo.language.process.Embedder$Context": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(java.lang.String)",
+ "public com.yahoo.language.Language getLanguage()",
+ "public com.yahoo.language.process.Embedder$Context setLanguage(com.yahoo.language.Language)",
+ "public java.lang.String getDestination()",
+ "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)"
+ ],
+ "fields": []
+ },
"com.yahoo.language.process.Embedder$FailingEmbedder": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -338,8 +353,8 @@
],
"methods": [
"public void <init>()",
- "public java.util.List embed(java.lang.String, com.yahoo.language.Language, java.lang.String)",
- "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, java.lang.String, com.yahoo.tensor.TensorType)"
+ "public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
+ "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)"
],
"fields": []
},
@@ -352,8 +367,8 @@
"abstract"
],
"methods": [
- "public abstract java.util.List embed(java.lang.String, com.yahoo.language.Language, java.lang.String)",
- "public abstract com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, java.lang.String, com.yahoo.tensor.TensorType)"
+ "public abstract java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
+ "public abstract com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)"
],
"fields": [
"public static final com.yahoo.language.process.Embedder throwsOnUse"
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index 1f4473220d7..17ee0419cea 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -21,39 +21,73 @@ public interface Embedder {
* Converts text into a list of token ids (a vector embedding)
*
* @param text the text to embed
- * @param language the language of the text, or UNKNOWN to use language independent embedding
- * @param destination the name of the recipient of this tensor, either a query feature name
- * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field").
- * This is useful for embedder components that alters behavior depending on the receiver.
+ * @param context the context which may influence an embedder's behavior
* @return the text embedded as a list of token ids
* @throws IllegalArgumentException if the language is not supported by this embedder
*/
- List<Integer> embed(String text, Language language, String destination);
+ List<Integer> embed(String text, Context context);
/**
* Converts text into tokens in a tensor.
* The information contained in the embedding may depend on the tensor type.
*
* @param text the text to embed
- * @param language the language of the text, or UNKNOWN to use language independent embedding
- * @param destination the name of the recipient of this tensor, either a query feature name
- * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field").
- * This is useful for embedder components that alters behavior depending on the receiver.
+ * @param context the context which may influence an embedder's behavior
* @param tensorType the type of the tensor to be returned
* @return the tensor embedding of the text, as the specified tensor type
* @throws IllegalArgumentException if the language or tensor type is not supported by this embedder
*/
- Tensor embed(String text, Language language, String destination, TensorType tensorType);
+ Tensor embed(String text, Context context, TensorType tensorType);
+
+ class Context {
+
+ private Language language = Language.UNKNOWN;
+ private String destination;
+
+ public Context(String destination) {
+ this.destination = destination;
+ }
+
+ /** Returns the language of the text, or UNKNOWN (default) to use a language independent embedding */
+ public Language getLanguage() { return language; }
+
+ /** Sets the language of the text, or UNKNOWN to use language independent embedding */
+ public Context setLanguage(Language language) {
+ this.language = language;
+ return this;
+ }
+
+ /**
+ * Returns the name of the recipient of this tensor.
+ *
+ * This is either a query feature name
+ * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field").
+ * This cannot be null.
+ */
+ public String getDestination() { return destination; }
+
+ /**
+ * Sets the name of the recipient of this tensor.
+ *
+ * This is either a query feature name
+ * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field").
+ */
+ public Context setDestination(String destination) {
+ this.destination = destination;
+ return this;
+ }
+
+ }
class FailingEmbedder implements Embedder {
@Override
- public List<Integer> embed(String text, Language language, String destination) {
+ public List<Integer> embed(String text, Context context) {
throw new IllegalStateException("No embedder has been configured");
}
@Override
- public Tensor embed(String text, Language language, String destination, TensorType tensorType) {
+ public Tensor embed(String text, Context context, TensorType tensorType) {
throw new IllegalStateException("No embedder has been configured");
}