From ac2519a8842a6397e4abd434439e9dddd2924394 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 1 Oct 2021 11:09:08 +0200 Subject: Encapsulate in a context --- .../language/sentencepiece/SentencePieceEmbedder.java | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) (limited to 'linguistics-components/src/main/java') diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index 1e120969a59..3f4e8ee3462 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -72,18 +72,17 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids. * * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. 
- * @param language the model to use, or Language.UNKNOWN to use the default model if any - * @param destination ignored + * @param context the context which specifies the language used to select a model * @return the list of zero or more token ids resulting from segmenting the input text */ @Override - public List embed(String rawInput, Language language, String destination) { + public List embed(String rawInput, Embedder.Context context) { var resultBuilder = new ResultBuilder>(new ArrayList<>()) { public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { result().add(segmentEnds[segmentEnd].id); } }; - segment(normalize(rawInput), language, resultBuilder); + segment(normalize(rawInput), context.getLanguage(), resultBuilder); Collections.reverse(resultBuilder.result()); return resultBuilder.result(); } @@ -101,15 +100,14 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { *

If the tensor is of any other type, an IllegalArgumentException is thrown.

* * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. - * @param language the model to use, or Language.UNKNOWN to use the default model if any - * @param destination ignored + * @param context the context which specifies the language used to select a model * @return the list of zero or more token ids resulting from segmenting the input text */ @Override - public Tensor embed(String rawInput, Language language, String destination, TensorType type) { + public Tensor embed(String rawInput, Embedder.Context context, TensorType type) { if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { // Build to a list first since we can't reverse a tensor builder - List values = embed(rawInput, language, destination); + List values = embed(rawInput, context); long maxSize = values.size(); if (type.dimensions().get(0).size().isPresent()) @@ -122,7 +120,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { } else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) { // Build to a list first since we can't reverse a tensor builder - List values = segment(rawInput, language); + List values = segment(rawInput, context.getLanguage()); Tensor.Builder builder = Tensor.Builder.of(type); for (int i = 0; i < values.size(); i++) -- cgit v1.2.3