summaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java45
1 files changed, 8 insertions, 37 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 3afc85300d4..ff7f4ae42bc 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -3,17 +3,17 @@ package com.yahoo.language.sentencepiece;
import com.yahoo.api.annotations.Beta;
import com.google.inject.Inject;
+import com.yahoo.language.tools.Embed;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -38,7 +38,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
}
public SentencePieceEmbedder(Builder builder) {
- algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring());
+ algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring());
models = builder.getModels().entrySet()
.stream()
@@ -94,9 +94,6 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
* they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
* it will be truncated.</p>
*
- * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token
- * position as value.</p>
- *
* <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
*
* @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
@@ -105,40 +102,15 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
*/
@Override
public Tensor embed(String rawInput, Embedder.Context context, TensorType type) {
- if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
- // Build to a list first since we can't reverse a tensor builder
- List<Integer> values = embed(rawInput, context);
-
- long maxSize = values.size();
- if (type.dimensions().get(0).size().isPresent())
- maxSize = Math.min(maxSize, type.dimensions().get(0).size().get());
-
- Tensor.Builder builder = Tensor.Builder.of(type);
- for (int i = 0; i < maxSize; i++)
- builder.cell(values.get(i), i);
- return builder.build();
- }
- else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) {
- // Build to a list first since we can't reverse a tensor builder
- List<String> values = segment(rawInput, context.getLanguage());
-
- Tensor.Builder builder = Tensor.Builder.of(type);
- for (int i = 0; i < values.size(); i++)
- builder.cell(TensorAddress.ofLabels(values.get(i)), i);
- return builder.build();
- }
- else {
- throw new IllegalArgumentException("Don't know how to embed with SentencePiece into " + type);
- }
+ return Embed.asTensor(rawInput, this, context, type);
}
private <RESULTTYPE> void segment(String input, Language language,
ResultBuilder<RESULTTYPE> resultBuilder) {
- Model model = resolveFrom(language);
- algorithm.segment(input, resultBuilder, model);
+ algorithm.segment(input, resultBuilder, resolveModelFrom(language));
}
- private Model resolveFrom(Language language) {
+ private Model resolveModelFrom(Language language) {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
if (models.containsKey(language)) return models.get(language);
@@ -166,7 +138,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
public static class Builder {
- private final Map<Language, Path> models = new HashMap<>();
+ private final Map<Language, Path> models = new EnumMap<>(Language.class);
private boolean collapseUnknowns = true;
private Scoring scoring = Scoring.fewestSegments;
@@ -177,9 +149,8 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
collapseUnknowns = config.collapseUnknowns();
scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments
: Scoring.highestScore;
- for (SentencePieceConfig.Model model : config.model()) {
+ for (SentencePieceConfig.Model model : config.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
- }
}
public void addModel(Language language, Path model) {