diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-10-01 11:09:08 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-10-01 11:09:08 +0200 |
commit | ac2519a8842a6397e4abd434439e9dddd2924394 (patch) | |
tree | 792275efbb88966a27a7ce54cc31465b563d7ad0 | |
parent | 380b9fa780ead9bcce0e824f7b6ee305e37dec43 (diff) |
Encapsulate in a context
9 files changed, 93 insertions, 38 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index 05befb24da0..76fb28ebb34 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.query.profile.types; +import com.yahoo.language.process.Embedder; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -52,7 +53,11 @@ public class TensorFieldType extends FieldType { if ( ! s.endsWith(")")) throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); String text = s.substring("embed(".length(), s.length() - 1); - return context.embedder().embed(text, context.language(), context.destination(), type); + return context.embedder().embed(text, toEmbedderContext(context), type); + } + + private Embedder.Context toEmbedderContext(ConversionContext context) { + return new Embedder.Context(context.destination()).setLanguage(context.language()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index e63c7711ff2..f11e5614635 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -747,16 +747,16 @@ public class QueryProfileTypeTestCase { } @Override - public 
List<Integer> embed(String text, Language language, String destination) { + public List<Integer> embed(String text, Embedder.Context context) { fail("Unexpected call"); return null; } @Override - public Tensor embed(String text, Language language, String destination, TensorType tensorType) { + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedText, text); - assertEquals(expectedLanguage, language); - assertEquals(expectedDestination, destination); + assertEquals(expectedLanguage, context.getLanguage()); + assertEquals(expectedDestination, context.getDestination()); assertEquals(tensorToReturn.type(), tensorType); return tensorToReturn; } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 043a30ce66d..66d912cd987 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -40,7 +40,9 @@ public class EmbedExpression extends Expression { @Override protected void doExecute(ExecutionContext context) { StringFieldValue input = (StringFieldValue) context.getValue(); - Tensor tensor = embedder.embed(input.getString(), context.getLanguage(), destination, targetType); + Tensor tensor = embedder.embed(input.getString(), + new Embedder.Context(destination).setLanguage(context.getLanguage()), + targetType); context.setValue(new TensorFieldValue(tensor)); } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index e0c0a9faba8..f193ac1a4c8 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ 
b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -135,13 +135,13 @@ public class ScriptTestCase { } @Override - public List<Integer> embed(String text, Language language, String destination) { + public List<Integer> embed(String text, Embedder.Context context) { return null; } @Override - public Tensor embed(String text, Language language, String destination, TensorType tensorType) { - assertEquals(expectedDestination, destination); + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + assertEquals(expectedDestination, context.getDestination()); return Tensor.from(tensorType, "[7,3,0,0]"); } diff --git a/linguistics-components/abi-spec.json b/linguistics-components/abi-spec.json index ebd7457dc71..28025d84f25 100644 --- a/linguistics-components/abi-spec.json +++ b/linguistics-components/abi-spec.json @@ -180,8 +180,8 @@ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)", "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)", "public java.util.List segment(java.lang.String, com.yahoo.language.Language)", - "public java.util.List embed(java.lang.String, com.yahoo.language.Language, java.lang.String)", - "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, java.lang.String, com.yahoo.tensor.TensorType)", + "public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)", + "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)", "public java.lang.String normalize(java.lang.String)" ], "fields": [] diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index 1e120969a59..3f4e8ee3462 100644 --- 
a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -72,18 +72,17 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids. * * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. - * @param language the model to use, or Language.UNKNOWN to use the default model if any - * @param destination ignored + * @param context the context which specifies the language used to select a model * @return the list of zero or more token ids resulting from segmenting the input text */ @Override - public List<Integer> embed(String rawInput, Language language, String destination) { + public List<Integer> embed(String rawInput, Embedder.Context context) { var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) { public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { result().add(segmentEnds[segmentEnd].id); } }; - segment(normalize(rawInput), language, resultBuilder); + segment(normalize(rawInput), context.getLanguage(), resultBuilder); Collections.reverse(resultBuilder.result()); return resultBuilder.result(); } @@ -101,15 +100,14 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> * * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. 
- * @param language the model to use, or Language.UNKNOWN to use the default model if any - * @param destination ignored + * @param context the context which specifies the language used to select a model * @return the list of zero or more token ids resulting from segmenting the input text */ @Override - public Tensor embed(String rawInput, Language language, String destination, TensorType type) { + public Tensor embed(String rawInput, Embedder.Context context, TensorType type) { if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { // Build to a list first since we can't reverse a tensor builder - List<Integer> values = embed(rawInput, language, destination); + List<Integer> values = embed(rawInput, context); long maxSize = values.size(); if (type.dimensions().get(0).size().isPresent()) @@ -122,7 +120,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { } else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) { // Build to a list first since we can't reverse a tensor builder - List<String> values = segment(rawInput, language); + List<String> values = segment(rawInput, context.getLanguage()); Tensor.Builder builder = Tensor.Builder.of(type); for (int i = 0; i < values.size(); i++) diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java index c6aa8fdd370..4dae53c60df 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java @@ -4,6 +4,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -29,13 +30,13 @@ class SentencePieceTester { } public void 
assertEmbedded(String input, Integer... expectedCodes) { - assertArrayEquals(expectedCodes, embedder.embed(input, Language.UNKNOWN, null).toArray()); + assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray()); } public void assertEmbedded(String input, String tensorType, String tensor) { TensorType type = TensorType.fromSpec(tensorType); Tensor expected = Tensor.from(type, tensor); - assertEquals(expected, embedder.embed(input, Language.UNKNOWN, null, type)); + assertEquals(expected, embedder.embed(input, new Embedder.Context("test"), type)); } public void assertSegmented(String input, String... expectedSegments) { diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 5865c28bbb6..31612bea983 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -328,6 +328,21 @@ ], "fields": [] }, + "com.yahoo.language.process.Embedder$Context": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(java.lang.String)", + "public com.yahoo.language.Language getLanguage()", + "public com.yahoo.language.process.Embedder$Context setLanguage(com.yahoo.language.Language)", + "public java.lang.String getDestination()", + "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)" + ], + "fields": [] + }, "com.yahoo.language.process.Embedder$FailingEmbedder": { "superClass": "java.lang.Object", "interfaces": [ @@ -338,8 +353,8 @@ ], "methods": [ "public void <init>()", - "public java.util.List embed(java.lang.String, com.yahoo.language.Language, java.lang.String)", - "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, java.lang.String, com.yahoo.tensor.TensorType)" + "public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)", + "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, 
com.yahoo.tensor.TensorType)" ], "fields": [] }, @@ -352,8 +367,8 @@ "abstract" ], "methods": [ - "public abstract java.util.List embed(java.lang.String, com.yahoo.language.Language, java.lang.String)", - "public abstract com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, java.lang.String, com.yahoo.tensor.TensorType)" + "public abstract java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)", + "public abstract com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)" ], "fields": [ "public static final com.yahoo.language.process.Embedder throwsOnUse" diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index 1f4473220d7..17ee0419cea 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -21,39 +21,73 @@ public interface Embedder { * Converts text into a list of token id's (a vector embedding) * * @param text the text to embed - * @param language the language of the text, or UNKNOWN to use language independent embedding - * @param destination the name of the recipient of this tensor, either a query feature name - * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). - * This is useful for embedder components that alters behavior depending on the receiver. + * @param context the context which may influence an embedder's behavior * @return the text embedded as a list of token ids * @throws IllegalArgumentException if the language is not supported by this embedder */ - List<Integer> embed(String text, Language language, String destination); + List<Integer> embed(String text, Context context); /** * Converts text into tokens in a tensor. * The information contained in the embedding may depend on the tensor type. 
* * @param text the text to embed - * @param language the language of the text, or UNKNOWN to use language independent embedding - * @param destination the name of the recipient of this tensor, either a query feature name - * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). - * This is useful for embedder components that alters behavior depending on the receiver. + * @param context the context which may influence an embedder's behavior * @param tensorType the type of the tensor to be returned * @return the tensor embedding of the text, as the specified tensor type * @throws IllegalArgumentException if the language or tensor type is not supported by this embedder */ - Tensor embed(String text, Language language, String destination, TensorType tensorType); + Tensor embed(String text, Context context, TensorType tensorType); + + class Context { + + private Language language = Language.UNKNOWN; + private String destination; + + public Context(String destination) { + this.destination = destination; + } + + /** Returns the language of the text, or UNKNOWN (default) to use a language independent embedding */ + public Language getLanguage() { return language; } + + /** Sets the language of the text, or UNKNOWN to use language independent embedding */ + public Context setLanguage(Language language) { + this.language = language; + return this; + } + + /** + * Returns the name of the recipient of this tensor. + * + * This is either a query feature name + * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). + * This cannot be null. + */ + public String getDestination() { return destination; } + + /** + * Sets the name of the recipient of this tensor. + * + * This is either a query feature name + * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). 
+ */ + public Context setDestination(String destination) { + this.destination = destination; + return this; + } + + } class FailingEmbedder implements Embedder { @Override - public List<Integer> embed(String text, Language language, String destination) { + public List<Integer> embed(String text, Context context) { throw new IllegalStateException("No embedder has been configured"); } @Override - public Tensor embed(String text, Language language, String destination, TensorType tensorType) { + public Tensor embed(String text, Context context, TensorType tensorType) { throw new IllegalStateException("No embedder has been configured"); } |