diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-12-16 18:35:11 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-12-16 18:35:11 +0100 |
commit | 767cb63af0f530605180f5438767406e1db27520 (patch) | |
tree | c0ea9e8ec4ded2dea6064a45334e6f8a1408f7b8 /linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java | |
parent | 1eefb9854bcd7ca264889239b32e7fc8c8830672 (diff) |
Add a BERT embedder
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java new file mode 100644 index 00000000000..401347cc452 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java @@ -0,0 +1,43 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.tools; + +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.List; + +/** + * Component internal helpers for embedding + * + * @author bratseth + */ +public class Embed { + + /** + * Convenience function which embeds the given string into the given tensor type (if possible), + * using the given embedder. + */ + public static Tensor asTensor(String text, + Embedder embedder, + Embedder.Context context, + TensorType type) { + if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { + // Build to a list first since we can't reverse a tensor builder + List<Integer> values = embedder.embed(text, context); + + long maxSize = values.size(); + if (type.dimensions().get(0).size().isPresent()) + maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < maxSize; i++) + builder.cell(values.get(i), i); + return builder.build(); + } + else { + throw new IllegalArgumentException("Don't know how to embed into " + type); + } + } + +} |