aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java
blob: f5804419387596fd5dac2975852aa95e07fcc399 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.language.huggingface;

import com.yahoo.api.annotations.Beta;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

/**
 * @author bjorncs
 */
@Beta
public record Encoding(
        List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask,
        List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) {

    public record CharSpan(int start, int end) {
        public static final CharSpan NONE = new CharSpan(-1, -1);
        static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) {
            if (s == null) return NONE;
            return new CharSpan(s.getStart(), s.getEnd());
        }
        public boolean isNone() { return this.equals(NONE); }
    }

    public Encoding {
        ids = List.copyOf(ids);
        typeIds = List.copyOf(typeIds);
        tokens = List.copyOf(tokens);
        wordIds = List.copyOf(wordIds);
        attentionMask = List.copyOf(attentionMask);
        specialTokenMask = List.copyOf(specialTokenMask);
        charTokenSpans = List.copyOf(charTokenSpans);
        overflowing = List.copyOf(overflowing);
    }

    static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) {
        return new Encoding(
                toList(e.getIds()),
                toList(e.getTypeIds()),
                List.of(e.getTokens()),
                toList(e.getWordIds()),
                toList(e.getAttentionMask()),
                toList(e.getSpecialTokenMask()),
                Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(),
                Arrays.stream(e.getOverflowing()).map(Encoding::from).toList());
    }

    private static List<Long> toList(long[] array) {
        if (array == null) return List.of();
        var list = new ArrayList<Long>(array.length);
        for (long e : array) list.add(e);
        return list;
    }
}