summaryrefslogtreecommitdiffstats
path: root/linguistics-components
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-09-30 14:21:25 +0200
committerJon Bratseth <bratseth@gmail.com>2021-09-30 14:21:25 +0200
commit380b9fa780ead9bcce0e824f7b6ee305e37dec43 (patch)
tree1a1d9a79910fdfc976643f02f1735d939ab689bf /linguistics-components
parent5007f6edcb5a2a461e859a10198f02171eab5516 (diff)
Update linguisticvs-components
Diffstat (limited to 'linguistics-components')
-rw-r--r--linguistics-components/abi-spec.json4
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java12
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java4
3 files changed, 13 insertions, 7 deletions
diff --git a/linguistics-components/abi-spec.json b/linguistics-components/abi-spec.json
index 808ec3af082..ebd7457dc71 100644
--- a/linguistics-components/abi-spec.json
+++ b/linguistics-components/abi-spec.json
@@ -180,8 +180,8 @@
"public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
"public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)",
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
- "public java.util.List embed(java.lang.String, com.yahoo.language.Language)",
- "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, com.yahoo.tensor.TensorType)",
+ "public java.util.List embed(java.lang.String, com.yahoo.language.Language, java.lang.String)",
+ "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, java.lang.String, com.yahoo.tensor.TensorType)",
"public java.lang.String normalize(java.lang.String)"
],
"fields": []
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 116dd15f563..1e120969a59 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -73,10 +73,11 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
*
* @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
* @param language the model to use, or Language.UNKNOWN to use the default model if any
+ * @param destination ignored
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
- public List<Integer> embed(String rawInput, Language language) {
+ public List<Integer> embed(String rawInput, Language language, String destination) {
var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
result().add(segmentEnds[segmentEnd].id);
@@ -98,12 +99,17 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
* position as value.</p>
*
* <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
+ *
+ * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
+ * @param language the model to use, or Language.UNKNOWN to use the default model if any
+ * @param destination ignored
+ * @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
- public Tensor embed(String rawInput, Language language, TensorType type) {
+ public Tensor embed(String rawInput, Language language, String destination, TensorType type) {
if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
// Build to a list first since we can't reverse a tensor builder
- List<Integer> values = embed(rawInput, language);
+ List<Integer> values = embed(rawInput, language, destination);
long maxSize = values.size();
if (type.dimensions().get(0).size().isPresent())
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
index c4cb13a3d23..c6aa8fdd370 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -29,13 +29,13 @@ class SentencePieceTester {
}
public void assertEmbedded(String input, Integer... expectedCodes) {
- assertArrayEquals(expectedCodes, embedder.embed(input, Language.UNKNOWN).toArray());
+ assertArrayEquals(expectedCodes, embedder.embed(input, Language.UNKNOWN, null).toArray());
}
public void assertEmbedded(String input, String tensorType, String tensor) {
TensorType type = TensorType.fromSpec(tensorType);
Tensor expected = Tensor.from(type, tensor);
- assertEquals(expected, embedder.embed(input, Language.UNKNOWN, type));
+ assertEquals(expected, embedder.embed(input, Language.UNKNOWN, null, type));
}
public void assertSegmented(String input, String... expectedSegments) {