aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
blob: 28cbf4ef64b426ef11f149b7e9d8a5d5df83c021 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;

import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.util.Optional;

/**
 * Class used by clients for serializing a Tensor object into binary format or
 * de-serializing binary data into a Tensor object.
 *
 * The actual binary format used is not a concern for the client and
 * is hidden in this class and in the binary data.
 *
 * @author geirst
 */
public class TypedBinaryFormat {

    private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
    private static final int DENSE_BINARY_FORMAT_TYPE = 2;
    private static final int MIXED_BINARY_FORMAT_TYPE = 3;
    private static final int SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5;
    private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6;
    private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7;

    private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
    private static final int FLOAT_VALUE_TYPE = 1;
    private static final int BFLOAT16_VALUE_TYPE = 2;
    private static final int INT8_VALUE_TYPE = 3;

    public static byte[] encode(Tensor tensor) {
        GrowableByteBuffer buffer = new GrowableByteBuffer();
        return asByteArray(encode(tensor, buffer));
    }
    public static GrowableByteBuffer encode(Tensor tensor, GrowableByteBuffer buffer) {
        BinaryFormat encoder = getFormatEncoder(buffer, tensor);
        encoder.encode(buffer, tensor);
        return buffer;
    }

    /**
     * Decode some data to a tensor
     *
     * @param type the type to decode and validate to, or empty to use the type given in the data
     * @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array
     * @return the resulting tensor
     * @throws IllegalArgumentException if the tensor data was invalid
     */
    public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
        BinaryFormat decoder = getFormatDecoder(buffer);
        return decoder.decode(type, buffer);
    }

    private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
        boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
        boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
        boolean isMixed = hasMappedDimensions && hasIndexedDimensions;

        // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead
        if (tensor instanceof MixedTensor && ! isMixed && hasIndexedDimensions)
            isMixed = true;

        if (isMixed && tensor.type().valueType() == TensorType.Value.DOUBLE) {
            encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE);
            return new MixedBinaryFormat();
        }
        else if (isMixed) {
            encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE);
            encodeValueType(buffer, tensor.type().valueType());
            return new MixedBinaryFormat(tensor.type().valueType());
        }
        else if (hasIndexedDimensions && tensor.type().valueType() == TensorType.Value.DOUBLE) {
            encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE);
            return new DenseBinaryFormat();
        }
        else if (hasIndexedDimensions) {
            encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE);
            encodeValueType(buffer, tensor.type().valueType());
            return new DenseBinaryFormat(tensor.type().valueType());
        }
        else if (tensor.type().valueType() == TensorType.Value.DOUBLE) {
            encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE);
            return new SparseBinaryFormat();
        }
        else {
            encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
            encodeValueType(buffer, tensor.type().valueType());
            return new SparseBinaryFormat(tensor.type().valueType());
        }
    }

    private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) {
        int formatType = decodeFormatType(buffer);
        switch (formatType) {
            case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat();
            case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat();
            case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat();
            case SPARSE_BINARY_FORMAT_WITH_CELLTYPE: return new SparseBinaryFormat(decodeValueType(buffer));
            case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(decodeValueType(buffer));
            case MIXED_BINARY_FORMAT_WITH_CELLTYPE: return new MixedBinaryFormat(decodeValueType(buffer));
        }
        throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
    }

    private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) {
        buffer.putInt1_4Bytes(formatType);
    }

    private static int decodeFormatType(GrowableByteBuffer buffer) {
        return buffer.getInt1_4Bytes();
    }

    private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
        switch (valueType) {
            case DOUBLE -> buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE);
            case FLOAT -> buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE);
            case BFLOAT16 -> buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE);
            case INT8 -> buffer.putInt1_4Bytes(INT8_VALUE_TYPE);
            default -> throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
        }
    }

    private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) {
        int valueType = buffer.getInt1_4Bytes();
        switch (valueType) {
            case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE;
            case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT;
            case BFLOAT16_VALUE_TYPE: return TensorType.Value.BFLOAT16;
            case INT8_VALUE_TYPE: return TensorType.Value.INT8;
        }
        throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. " +
                "Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal.");
    }

    private static byte[] asByteArray(GrowableByteBuffer buffer) {
        buffer.flip();
        byte[] result = new byte[buffer.remaining()];
        buffer.get(result);
        return result;
    }

    static short bFloat16BitsFromFloat(float val) {
        return (short) (Float.floatToRawIntBits(val) >>> 16);
    }

    static float floatFromBFloat16Bits(short bits) {
        return Float.intBitsToFloat(bits << 16);
    }

}