diff options
author | Jon Bratseth <bratseth@oath.com> | 2021-09-27 23:09:03 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-27 23:09:03 +0200 |
commit | 2df97d23d9f25ae60f010a2e9f273cb5b38e049b (patch) | |
tree | d2923a45682e91d80e7011c60cfb301e05acead3 /linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java | |
parent | 037f756caf4cfb99bcd988174839d7bc385267b9 (diff) | |
parent | 8f3fb1a105ded07144f6de527266a438e48a1766 (diff) |
Merge pull request #19294 from vespa-engine/bratseth/linguistics-componentsv7.473.17
Bratseth/linguistics components
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java new file mode 100644 index 00000000000..74f300057dc --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java @@ -0,0 +1,60 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import com.yahoo.io.IOUtils; +import com.yahoo.language.Language; +import sentencepiece.SentencepieceModel; + +import java.io.IOException; +import java.nio.file.Path; + +/** + * A SentencePiece model + * + * @author bratseth + */ +final class Model { + + final Path source; + final Language language; + final float minScore; + final float maxScore; + final Trie tokens = new Trie(); + + Model(Language language, Path path) { + try { + this.source = path; + this.language = language; + var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); + float minScore = Float.MAX_VALUE; + float maxScore = Float.MIN_VALUE; + for (int i = 0; i < sp.getPiecesCount(); i++) { + var piece = sp.getPieces(i); + tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); + minScore = Math.min(piece.getScore(), minScore); + maxScore = Math.max(piece.getScore(), maxScore); + } + this.minScore = minScore; + this.maxScore = maxScore; + } catch (IOException e) { + throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); + } + } + + private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) { + switch (type) { + case USER_DEFINED : return TokenType.userDefined; + case UNKNOWN : return TokenType.unknown; + case NORMAL : return TokenType.text; + case CONTROL : return TokenType.control; + case UNUSED : return TokenType.unused; + default : throw new IllegalArgumentException("Unknkown token type " + type); + } + } + + @Override + public String toString() { + return "SentencePiece model for " + language + ": '" + source + "'"; + } + +} |