aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 17:25:55 +0200
committerGitHub <noreply@github.com>2023-06-12 17:25:55 +0200
commitfb5d1bf9f451fbeb4a40d7f73fa856ef81bd77ed (patch)
tree1f8ab291370e84407a827f9a80bcf943f522ca29 /linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
parent0647b650c3334ff86d50431e78549e25dc46caf9 (diff)
parent4f722322cc9f8df5146ffb27d74239b3b4f2d634 (diff)
Merge pull request #27387 from vespa-engine/bjorncs/hfv8.176.13
Prefer truncation configuration from tokenizer model
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java41
1 files changed, 41 insertions, 0 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
new file mode 100644
index 00000000000..4b30b1f0435
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
@@ -0,0 +1,41 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import java.util.Arrays;
+
+/**
+ * @author bjorncs
+ */
+public record ModelInfo(
+ TruncationStrategy truncation, PaddingStrategy padding, int maxLength, int stride, int padToMultipleOf) {
+
+ public enum TruncationStrategy {
+ LONGEST_FIRST,
+ ONLY_FIRST,
+ ONLY_SECOND,
+ DO_NOT_TRUNCATE;
+
+ public static TruncationStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST_FIRST;
+ else if ("false".equals(v)) return DO_NOT_TRUNCATE;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+
+ public enum PaddingStrategy {
+ LONGEST,
+ MAX_LENGTH,
+ DO_NOT_PAD;
+
+ public static PaddingStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST;
+ else if ("false".equals(v)) return DO_NOT_PAD;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+}