about | summary | refs | log | tree | commit | diff | stats
path: root/linguistics-components/src/main/java/com/yahoo/language/sentencepiece
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@gmail.com> 2021-10-01 11:09:08 +0200
committer Jon Bratseth <bratseth@gmail.com> 2021-10-01 11:09:08 +0200
commit ac2519a8842a6397e4abd434439e9dddd2924394 (patch)
tree 792275efbb88966a27a7ce54cc31465b563d7ad0 /linguistics-components/src/main/java/com/yahoo/language/sentencepiece
parent 380b9fa780ead9bcce0e824f7b6ee305e37dec43 (diff)
Encapsulate in a context
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/sentencepiece')
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java 16
1 file changed, 7 insertions, 9 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 1e120969a59..3f4e8ee3462 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -72,18 +72,17 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
* Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids.
*
* @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
- * @param language the model to use, or Language.UNKNOWN to use the default model if any
- * @param destination ignored
+ * @param context the context which specifies the language used to select a model
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
- public List<Integer> embed(String rawInput, Language language, String destination) {
+ public List<Integer> embed(String rawInput, Embedder.Context context) {
var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
result().add(segmentEnds[segmentEnd].id);
}
};
- segment(normalize(rawInput), language, resultBuilder);
+ segment(normalize(rawInput), context.getLanguage(), resultBuilder);
Collections.reverse(resultBuilder.result());
return resultBuilder.result();
}
@@ -101,15 +100,14 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
* <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
*
* @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
- * @param language the model to use, or Language.UNKNOWN to use the default model if any
- * @param destination ignored
+ * @param context the context which specifies the language used to select a model
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
- public Tensor embed(String rawInput, Language language, String destination, TensorType type) {
+ public Tensor embed(String rawInput, Embedder.Context context, TensorType type) {
if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
// Build to a list first since we can't reverse a tensor builder
- List<Integer> values = embed(rawInput, language, destination);
+ List<Integer> values = embed(rawInput, context);
long maxSize = values.size();
if (type.dimensions().get(0).size().isPresent())
@@ -122,7 +120,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
}
else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) {
// Build to a list first since we can't reverse a tensor builder
- List<String> values = segment(rawInput, language);
+ List<String> values = segment(rawInput, context.getLanguage());
Tensor.Builder builder = Tensor.Builder.of(type);
for (int i = 0; i < values.size(); i++)