diff options
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java | 45 |
1 files changed, 8 insertions, 37 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index 3afc85300d4..ff7f4ae42bc 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -3,17 +3,17 @@ package com.yahoo.language.sentencepiece; import com.yahoo.api.annotations.Beta; import com.google.inject.Inject; +import com.yahoo.language.tools.Embed; import com.yahoo.language.Language; import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; +import java.util.EnumMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -38,7 +38,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { } public SentencePieceEmbedder(Builder builder) { - algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring()); + algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring()); models = builder.getModels().entrySet() .stream() @@ -94,9 +94,6 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small * it will be truncated.</p> * - * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token - * position as value.</p> - * * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> * * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. @@ -105,40 +102,15 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { */ @Override public Tensor embed(String rawInput, Embedder.Context context, TensorType type) { - if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { - // Build to a list first since we can't reverse a tensor builder - List<Integer> values = embed(rawInput, context); - - long maxSize = values.size(); - if (type.dimensions().get(0).size().isPresent()) - maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); - - Tensor.Builder builder = Tensor.Builder.of(type); - for (int i = 0; i < maxSize; i++) - builder.cell(values.get(i), i); - return builder.build(); - } - else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) { - // Build to a list first since we can't reverse a tensor builder - List<String> values = segment(rawInput, context.getLanguage()); - - Tensor.Builder builder = Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cell(TensorAddress.ofLabels(values.get(i)), i); - return builder.build(); - } - else { - throw new IllegalArgumentException("Don't know how to embed with SentencePiece into " + type); - } + return Embed.asTensor(rawInput, this, context, type); } private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) { - Model model = resolveFrom(language); - algorithm.segment(input, resultBuilder, model); + algorithm.segment(input, resultBuilder, resolveModelFrom(language)); } - private Model resolveFrom(Language language) { + private Model resolveModelFrom(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); if (models.containsKey(language)) return models.get(language); @@ -166,7 +138,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { public static class Builder { - private final Map<Language, Path> models = new HashMap<>(); + private final Map<Language, Path> models = new EnumMap<>(Language.class); private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; @@ -177,9 +149,8 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { collapseUnknowns = config.collapseUnknowns(); scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments : Scoring.highestScore; - for (SentencePieceConfig.Model model : config.model()) { + for (SentencePieceConfig.Model model : config.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); - } } public void addModel(Language language, Path model) { |