summaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java')
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java59
1 files changed, 59 insertions, 0 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
new file mode 100644
index 00000000000..9599e60e556
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
@@ -0,0 +1,59 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.tools;
+
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.process.Segmenter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tester of embedders.
+ *
+ * @author bratseth
+ */
+public class EmbedderTester {
+
+ private final Embedder embedder;
+
+ public EmbedderTester(Embedder embedder) {
+ this.embedder = embedder;
+ }
+
+ /**
+ * Tests both embedding to a list of id's and encoding the same ids to a vector of the given type.
+ *
+ * @param expectedCodes all the expected codes of the given input, not including any trailing 0-paddings
+ * required for the tensor only
+ */
+ public void assertEmbedded(String input, String tensorType, Integer... expectedCodes) {
+ TensorType type = TensorType.fromSpec(tensorType);
+ assertEquals(1, type.dimensions().size());
+ assertTrue(type.dimensions().get(0).isIndexed());
+
+ int tensorSize = type.dimensions().get(0).size().get().intValue();
+
+ assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray());
+
+ var builder = Tensor.Builder.of(type);
+ for (int i = 0; i < tensorSize; i++)
+ builder.cell(i < expectedCodes.length ? expectedCodes[i] : 0, i);
+ assertEquals(builder.build(), embedder.embed(input, new Embedder.Context("destination"), type));
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, ((Segmenter)embedder).segment(input, language).toArray());
+ }
+
+}