aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-12-16 18:35:11 +0100
committerJon Bratseth <bratseth@gmail.com>2021-12-16 18:35:11 +0100
commit767cb63af0f530605180f5438767406e1db27520 (patch)
treec0ea9e8ec4ded2dea6064a45334e6f8a1408f7b8 /linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java
parent1eefb9854bcd7ca264889239b32e7fc8c8830672 (diff)
Add a BERT embedder
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java43
1 files changed, 43 insertions, 0 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java
new file mode 100644
index 00000000000..401347cc452
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java
@@ -0,0 +1,43 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.tools;
+
+import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.List;
+
+/**
+ * Component internal helpers for embedding
+ *
+ * @author bratseth
+ */
+public class Embed {
+
+ /**
+ * Convenience function which embeds the given string into the given tensor type (if possible),
+ * using the given embedder.
+ */
+ public static Tensor asTensor(String text,
+ Embedder embedder,
+ Embedder.Context context,
+ TensorType type) {
+ if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
+ // Build to a list first since we can't reverse a tensor builder
+ List<Integer> values = embedder.embed(text, context);
+
+ long maxSize = values.size();
+ if (type.dimensions().get(0).size().isPresent())
+ maxSize = Math.min(maxSize, type.dimensions().get(0).size().get());
+
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ for (int i = 0; i < maxSize; i++)
+ builder.cell(values.get(i), i);
+ return builder.build();
+ }
+ else {
+ throw new IllegalArgumentException("Don't know how to embed into " + type);
+ }
+ }
+
+}